|
import streamlit as st |
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import time |
|
import torch |
|
import concurrent.futures |
|
|
|
|
|
# Load the shared image processor and the four fine-tuned classifiers into
# this Streamlit session the first time the script runs for it; the
# ``models_loaded`` flag stops later reruns of the script from reloading them.
if 'models_loaded' not in st.session_state:
    _session = st.session_state

    # A single processor (taken from the top-wear checkpoint) is used to
    # preprocess images for every model below.
    _session.image_processor = AutoImageProcessor.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
    )

    _session.top_wear_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
    )
    _session.pattern_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb"
    )
    _session.print_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print"
    )
    _session.sleeve_length_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length"
    )

    # Set last, so the flag is only recorded once every load above succeeded.
    _session.models_loaded = True
|
|
|
|
|
def _predict_label(encoding, model):
    """Run *model* on the preprocessed *encoding* and return the top class label.

    Deliberately makes no Streamlit calls: the public wrappers below are run
    on ThreadPoolExecutor worker threads, which have no Streamlit
    ScriptRunContext, so ``st.*`` output issued from them is not rendered.
    Displaying results is left to the caller on the script thread.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoding).logits
    # argmax over the last (class) axis; ``.item()`` requires exactly one
    # prediction, i.e. a single-image batch.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def topwear(encoding, top_wear_model):
    """Return the top-wear label predicted by *top_wear_model* for *encoding*."""
    return _predict_label(encoding, top_wear_model)


def patterns(encoding, pattern_model):
    """Return the pattern label predicted by *pattern_model* for *encoding*."""
    return _predict_label(encoding, pattern_model)


def prints(encoding, print_model):
    """Return the print label predicted by *print_model* for *encoding*."""
    return _predict_label(encoding, print_model)


def sleevelengths(encoding, sleeve_length_model):
    """Return the sleeve-length label predicted by *sleeve_length_model* for *encoding*."""
    return _predict_label(encoding, sleeve_length_model)
|
|
|
def imageprocessing(image):
    """Preprocess *image* with the session's shared image processor.

    Returns the processor output as PyTorch tensors (``return_tensors="pt"``),
    in the form the classifiers consume via ``model(**encoding)``.
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
|
|
|
|
|
def pipes(image):
    """Classify *image* with all four attribute models and return the labels.

    The image is preprocessed once and the same encoding is fed to every
    model concurrently on a thread pool. Results are then collected on the
    calling (Streamlit script) thread in a fixed order, so the returned dict
    -- and the JSON shown to the user -- always has the same key order,
    independent of which model finishes first.

    Returns:
        dict mapping "topwear", "patterns", "prints" and "sleeve_length" to
        the predicted label, or to None when that model raised (the error is
        reported with ``st.error``).
    """
    encoding = imageprocessing(image)

    # (result key, classifier function, model) in display order.
    tasks = (
        ("topwear", topwear, st.session_state.top_wear_model),
        ("patterns", patterns, st.session_state.pattern_model),
        ("prints", prints, st.session_state.print_model),
        ("sleeve_length", sleevelengths, st.session_state.sleeve_length_model),
    )

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {
            name: executor.submit(classify, encoding, model)
            for name, classify, model in tasks
        }

    # Leaving the ``with`` block waits for every future, so each .result()
    # below returns (or raises) immediately. Streamlit calls stay on this
    # thread, which owns the script run context.
    results = {}
    for name, future in futures.items():
        try:
            results[name] = future.result()
        except Exception as e:
            st.error(f"Error in {name}: {str(e)}")
            results[name] = None

    st.write(results)

    return results
|
|
|
|
|
st.title("Clothing Classification Pipeline")

# Seconds to wait on the remote host. Without a timeout ``requests.get`` can
# block this script run indefinitely on an unresponsive server.
REQUEST_TIMEOUT_SECONDS = 10

url = st.text_input("Paste image URL here...")

if url:
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)

        if response.status_code == 200:
            # Normalise to 3-channel RGB. RGBA, palette ("P") and grayscale
            # images otherwise reach the image processor with a channel count
            # the models are not built for (assumed 3-channel RGB input).
            image = Image.open(BytesIO(response.content)).convert("RGB")

            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            start_time = time.time()

            result = pipes(image)

            st.write("Classification Results (JSON):")
            st.json(result)
            st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")

    except Exception as e:
        # Top-level UI boundary: surface any download/decode/inference error
        # to the user instead of crashing the app.
        st.error(f"Error processing the image: {str(e)}")
|
|