|
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
|
|
|
# Common prefix of the Hugging Face Hub repos holding the fine-tuned
# ConvNeXt V2 checkpoints used by this app.
_REPO_PREFIX = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-"


@st.cache_resource
def _load_models():
    """Load the image processor and the four classifiers once per server process.

    ``st.cache_resource`` keeps a single shared copy for every browser
    session; storing the objects only in ``st.session_state`` (as before)
    re-loaded all weights for each new session and multiplied memory use.

    Returns:
        Dict mapping the ``st.session_state`` attribute names the rest of the
        app reads to the loaded processor/model objects.
    """
    # NOTE(review): the processor from the topwear checkpoint is reused for
    # all four models; presumably they share the same preprocessing — confirm
    # if any checkpoint is replaced.
    return {
        "image_processor": AutoImageProcessor.from_pretrained(_REPO_PREFIX + "topwear"),
        "top_wear_model": AutoModelForImageClassification.from_pretrained(_REPO_PREFIX + "topwear"),
        "pattern_model": AutoModelForImageClassification.from_pretrained(_REPO_PREFIX + "pattern-rgb"),
        "print_model": AutoModelForImageClassification.from_pretrained(_REPO_PREFIX + "print"),
        "sleeve_length_model": AutoModelForImageClassification.from_pretrained(_REPO_PREFIX + "sleeve-length"),
    }


# Publish the shared objects under the same session_state names the
# classifier functions read, once per session.
if "models_loaded" not in st.session_state:
    for _attr_name, _loaded in _load_models().items():
        st.session_state[_attr_name] = _loaded
    st.session_state.models_loaded = True
|
|
|
|
|
def topwear(encoding):
    """Predict the topwear category for a preprocessed image.

    Must be called from the Streamlit script thread, because it reads the
    model from ``st.session_state``.

    Args:
        encoding: Output of the shared image processor (``return_tensors="pt"``);
            it is unpacked as keyword arguments into the model. Assumes a
            batch of one image, since ``.item()`` needs a single index.

    Returns:
        The ``id2label`` entry of the model config for the top-scoring class.
    """
    model = st.session_state.top_wear_model
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]
|
|
|
def patterns(encoding):
    """Predict the pattern class for a preprocessed image.

    Must be called from the Streamlit script thread, because it reads the
    model from ``st.session_state``.

    Args:
        encoding: Output of the shared image processor (``return_tensors="pt"``);
            it is unpacked as keyword arguments into the model. Assumes a
            batch of one image, since ``.item()`` needs a single index.

    Returns:
        The ``id2label`` entry of the model config for the top-scoring class.
    """
    model = st.session_state.pattern_model
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]
|
|
|
def prints(encoding):
    """Predict the print class for a preprocessed image.

    Must be called from the Streamlit script thread, because it reads the
    model from ``st.session_state``.

    Args:
        encoding: Output of the shared image processor (``return_tensors="pt"``);
            it is unpacked as keyword arguments into the model. Assumes a
            batch of one image, since ``.item()`` needs a single index.

    Returns:
        The ``id2label`` entry of the model config for the top-scoring class.
    """
    model = st.session_state.print_model
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]
|
|
|
def sleevelengths(encoding):
    """Predict the sleeve-length class for a preprocessed image.

    Must be called from the Streamlit script thread, because it reads the
    model from ``st.session_state``.

    Args:
        encoding: Output of the shared image processor (``return_tensors="pt"``);
            it is unpacked as keyword arguments into the model. Assumes a
            batch of one image, since ``.item()`` needs a single index.

    Returns:
        The ``id2label`` entry of the model config for the top-scoring class.
    """
    model = st.session_state.sleeve_length_model
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]
|
|
|
def imageprocessing(url, timeout=10):
    """Download the image at *url* and preprocess it for the classifiers.

    Args:
        url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait on the server (passed to ``requests.get``)
            before giving up; without it a stalled server blocks forever.

    Returns:
        ``(encoding, image)``. ``encoding`` is the shared image processor's
        output with PyTorch tensors, computed from an RGB copy of the image.
        ``image`` is the decoded PIL image in its original mode.

    Raises:
        requests.RequestException: Network failure, timeout, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: The response body is not a decodable
            image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here; otherwise an HTML error page reaches PIL and
    # surfaces as a misleading "cannot identify image file" error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
    return encoding, image
|
|
|
def _predict_label(model, encoding):
    """Run one classifier on a preprocessed batch and return its top label."""
    # Grad mode is thread-local in PyTorch, so no_grad must be entered in the
    # worker thread that actually runs the forward pass.
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]


def pipes(imagepath):
    """Fetch an image and classify it with all four models concurrently.

    Args:
        imagepath: URL of the image, forwarded to ``imageprocessing``.

    Returns:
        ``(dicts, image)``. ``dicts`` maps ``"top"``, ``"pattern"``,
        ``"print"`` and ``"sleeve_length"`` to predicted labels. ``image`` is
        the decoded PIL image.

    Raises:
        Any exception from downloading/decoding the image or from a model's
        forward pass. Errors in workers are re-raised here instead of
        leaving a ``None`` result.
    """
    encoding, image = imageprocessing(imagepath)

    # Resolve models here, in the script thread: custom threads have no
    # Streamlit ScriptRunContext, so st.session_state is not usable in them.
    state = st.session_state
    models = {
        "top": state.top_wear_model,
        "pattern": state.pattern_model,
        "print": state.print_model,
        "sleeve_length": state.sleeve_length_model,
    }

    with ThreadPoolExecutor(max_workers=len(models)) as executor:
        futures = {
            key: executor.submit(_predict_label, model, encoding)
            for key, model in models.items()
        }
        # .result() blocks until each worker finishes and re-raises its error.
        dicts = {key: future.result() for key, future in futures.items()}

    return dicts, image
|
|
|
|
|
st.title("Clothing Classification Pipeline")

# Strip whitespace so a stray space from copy-pasting doesn't break the URL.
image_url = st.text_input("Enter Image URL").strip()

if image_url:
    # perf_counter is monotonic and meant for measuring elapsed time.
    start_time = time.perf_counter()
    try:
        results, img = pipes(image_url)
    except (requests.RequestException, OSError) as exc:
        # Network errors, bad HTTP statuses and undecodable images (PIL raises
        # an OSError subclass) become a readable message, not a traceback.
        st.error(f"Could not load an image from that URL: {exc}")
    else:
        st.image(img.resize((200, 200)), caption="Uploaded Image")

        st.write("Classification Results:")
        st.write(f"Topwear: {results['top']}")
        st.write(f"Pattern: {results['pattern']}")
        st.write(f"Print: {results['print']}")
        st.write(f"Sleeve Length: {results['sleeve_length']}")

        st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
|
|