import streamlit as st
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import threading
import time

# Load the processor and models only once per session using session state.
# (An alternative caching approach with st.cache_resource is sketched at the end of this file.)
if 'models_loaded' not in st.session_state:
    # Image processor (shared by all four classifiers)
    st.session_state.image_processor = AutoImageProcessor.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Topwear model
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Pattern model
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    # Print model
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    # Sleeve-length model
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    st.session_state.models_loaded = True

# Prediction helpers: run a model on the processed image and map the
# highest-scoring logit back to its human-readable label.
def topwear(encoding):
    with torch.no_grad():  # inference only, no gradients needed
        outputs = st.session_state.top_wear_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.top_wear_model.config.id2label[predicted_class_idx]

def patterns(encoding):
    with torch.no_grad():
        outputs = st.session_state.pattern_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.pattern_model.config.id2label[predicted_class_idx]

def prints(encoding):
    with torch.no_grad():
        outputs = st.session_state.print_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.print_model.config.id2label[predicted_class_idx]

def sleevelengths(encoding):
    with torch.no_grad():
        outputs = st.session_state.sleeve_length_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]

def imageprocessing(url):
    # Download the image and convert it to model-ready tensors.
    response = requests.get(url, timeout=10)
    response.raise_for_status()  # fail early on a bad URL or HTTP error
    image = Image.open(BytesIO(response.content))
    encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
    return encoding, image

def pipes(imagepath):
    encoding, image = imageprocessing(imagepath)

    # Run the four classifiers in parallel threads; PyTorch releases the GIL
    # during heavy ops, so this gives a modest speed-up over sequential calls.
    results = [None] * 4
    threads = [
        threading.Thread(target=lambda: results.__setitem__(0, topwear(encoding))),
        threading.Thread(target=lambda: results.__setitem__(1, patterns(encoding))),
        threading.Thread(target=lambda: results.__setitem__(2, prints(encoding))),
        threading.Thread(target=lambda: results.__setitem__(3, sleevelengths(encoding))),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    predictions = {
        "top": results[0],
        "pattern": results[1],
        "print": results[2],
        "sleeve_length": results[3],
    }
    return predictions, image

# Streamlit app UI
st.title("Clothing Classification Pipeline")
image_url = st.text_input("Enter Image URL")

if image_url:
    start_time = time.time()
    results, img = pipes(image_url)
    st.image(img.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

    # Display results
    st.write("Classification Results:")
    st.write(f"Topwear: {results['top']}")
    st.write(f"Pattern: {results['pattern']}")
    st.write(f"Print: {results['print']}")
    st.write(f"Sleeve Length: {results['sleeve_length']}")
    st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
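
# ---------------------------------------------------------------------------
# Optional alternative: cache models with st.cache_resource.
# A minimal sketch, assuming a Streamlit version (>= 1.18) that provides the
# st.cache_resource decorator; nothing above depends on it. `load_models` is a
# hypothetical helper name, the checkpoints are the same ones loaded above, and
# the imports at the top of this file are reused. A cached loader runs once per
# server process and is shared across sessions, whereas the session_state
# approach reloads the models for every new browser session.
# ---------------------------------------------------------------------------
@st.cache_resource
def load_models():
    """Load the shared image processor and the four classification models once."""
    processor = AutoImageProcessor.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    models = {
        "top": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"),
        "pattern": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb"),
        "print": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print"),
        "sleeve_length": AutoModelForImageClassification.from_pretrained(
            "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length"),
    }
    return processor, models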