|
import streamlit as st |
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import time |
|
import torch |
|
|
|
|
|
# Hugging Face Hub checkpoints used by this app. The image processor comes from
# the top-wear checkpoint and is shared by all four classifiers.
_PROCESSOR_REPO = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
_CLASSIFIER_REPOS = {
    "top_wear_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear",
    "pattern_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb",
    "print_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print",
    "sleeve_length_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length",
}


@st.cache_resource
def _load_models():
    """Load the shared image processor and the four classifiers.

    Wrapped in ``st.cache_resource`` so the loading happens once per server
    process and every browser session reuses the same objects. Without it,
    each new session reloads (and holds its own copy of) every model.

    Returns:
        A ``(processor, models)`` tuple, where ``models`` maps the
        ``st.session_state`` attribute name to its loaded model.
    """
    processor = AutoImageProcessor.from_pretrained(_PROCESSOR_REPO)
    models = {
        key: AutoModelForImageClassification.from_pretrained(repo)
        for key, repo in _CLASSIFIER_REPOS.items()
    }
    return processor, models


# Publish the process-wide cached objects in this session's state under the
# names the inference helpers read (image_processor, top_wear_model, ...).
if 'models_loaded' not in st.session_state:
    _processor, _models = _load_models()
    st.session_state.image_processor = _processor
    for _key, _model in _models.items():
        st.session_state[_key] = _model
    st.session_state.models_loaded = True
|
|
|
|
|
def topwear(encoding):
    """Predict the top-wear (garment type) class for a preprocessed image.

    Args:
        encoding: Image-processor output, unpacked as keyword arguments into
            the model.

    Returns:
        The label that ``id2label`` maps the arg-max logit to. The same label
        is also written to the Streamlit page.
    """
    model = st.session_state.top_wear_model
    with torch.no_grad():
        logits = model(**encoding).logits
    label = model.config.id2label[logits.argmax(-1).item()]
    st.write(f"Top Wear: {label}")
    return label
|
|
|
def patterns(encoding):
    """Predict the pattern class for a preprocessed image.

    Args:
        encoding: Image-processor output, unpacked as keyword arguments into
            the model.

    Returns:
        The pattern model's ``id2label`` entry for the highest-scoring class.
        It is also written to the Streamlit page.
    """
    classifier = st.session_state.pattern_model
    with torch.no_grad():
        scores = classifier(**encoding).logits
    best = scores.argmax(-1).item()
    name = classifier.config.id2label[best]
    st.write(f"Pattern: {name}")
    return name
|
|
|
def prints(encoding):
    """Predict the print class for a preprocessed image.

    Args:
        encoding: Image-processor output, unpacked as keyword arguments into
            the model.

    Returns:
        The print model's ``id2label`` entry for the highest-scoring class.
        It is also written to the Streamlit page.
    """
    net = st.session_state.print_model
    with torch.no_grad():
        out = net(**encoding)
    top_index = out.logits.argmax(-1).item()
    prediction = net.config.id2label[top_index]
    st.write(f"Print: {prediction}")
    return prediction
|
|
|
def sleevelengths(encoding):
    """Predict the sleeve-length class for a preprocessed image.

    Args:
        encoding: Image-processor output, unpacked as keyword arguments into
            the model.

    Returns:
        The sleeve-length model's ``id2label`` entry for the highest-scoring
        class. It is also written to the Streamlit page.
    """
    sleeve_model = st.session_state.sleeve_length_model
    with torch.no_grad():
        raw = sleeve_model(**encoding).logits
    sleeve = sleeve_model.config.id2label[raw.argmax(-1).item()]
    st.write(f"Sleeve Length: {sleeve}")
    return sleeve
|
|
|
def imageprocessing(image):
    """Preprocess an image into the inputs the classifiers take.

    Args:
        image: The image to classify, typically a ``PIL.Image.Image``.

    Returns:
        The shared image processor's output (requested as PyTorch tensors via
        ``return_tensors="pt"``). Callers unpack it as ``**encoding`` into
        each model.
    """
    # The ConvNeXt V2 classifiers take 3-channel RGB input. Images fetched
    # from the web are often RGBA/palette PNGs, grayscale or CMYK, and would
    # reach preprocessing with the wrong channel count. Objects without a
    # PIL-style ``mode`` attribute (e.g. arrays) are passed through unchanged.
    mode = getattr(image, "mode", None)
    if mode is not None and mode != "RGB":
        image = image.convert("RGB")
    encoding = st.session_state.image_processor(images=image, return_tensors="pt")
    return encoding
|
|
|
|
|
def pipes(image):
    """Run every classifier on one image and collect their labels.

    The image is preprocessed once and the same encoding is fed to the
    top-wear, pattern, print and sleeve-length classifiers, in that order.
    Each classifier writes its own label to the page. The combined dict is
    then written as well.

    Args:
        image: The image to classify.

    Returns:
        A dict with keys ``"top"``, ``"pattern"``, ``"print"`` and
        ``"sleeve_length"`` mapping to each classifier's label.
    """
    encoding = imageprocessing(image)

    # Dict displays evaluate left to right, so the classifiers run (and write
    # to the page) in the order listed.
    results = {
        "top": topwear(encoding),
        "pattern": patterns(encoding),
        "print": prints(encoding),
        "sleeve_length": sleevelengths(encoding),
    }

    st.write(results)
    return results
|
|
|
|
|
# Seconds to wait on the remote host when fetching the image. requests applies
# no timeout by default, so a stalled server would hang this script run.
IMAGE_FETCH_TIMEOUT_S = 10

st.title("Clothing Classification Pipeline")

url = st.text_input("Paste image URL here...")

if url:
    # NOTE(review): the URL is user-supplied and fetched server-side with no
    # response-size limit or host allow-list. Consider restricting it if this
    # app is publicly deployed.
    try:
        response = requests.get(url, timeout=IMAGE_FETCH_TIMEOUT_S)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            # The classifiers expect 3-channel RGB input. Normalize other modes
            # (RGBA/P PNGs, grayscale, CMYK) before running the pipeline.
            if image.mode != "RGB":
                image = image.convert("RGB")

            # perf_counter is monotonic and intended for measuring durations.
            start_time = time.perf_counter()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: report any fetch, timeout, decode or
        # inference failure on the page instead of crashing the run.
        st.error(f"Error processing the image: {str(e)}")
|
|