"""Streamlit app: classify a clothing image (top-wear type, pattern, print,
sleeve length) using four fine-tuned ConvNeXtV2-tiny models, given an image URL.
"""

import time
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Streamlit re-executes this script on every UI interaction, so cache the
# processor and all four models in session_state to avoid reloading weights.
if 'models_loaded' not in st.session_state:
    st.session_state.image_processor = AutoImageProcessor.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    st.session_state.models_loaded = True


def _classify(model, encoding, label):
    """Run `model` on a pre-computed `encoding`, write "<label>: <class>" to
    the page, and return the predicted class name (str).

    Shared helper for the four per-attribute classifiers below, which were
    previously copy-paste duplicates of one another.
    """
    with torch.no_grad():  # inference only — no gradients needed
        logits = model(**encoding).logits
    predicted_class = model.config.id2label[logits.argmax(-1).item()]
    st.write(f"{label}: {predicted_class}")
    return predicted_class


def topwear(encoding):
    """Classify the top-wear category; returns the predicted label."""
    return _classify(st.session_state.top_wear_model, encoding, "Top Wear")


def patterns(encoding):
    """Classify the fabric pattern; returns the predicted label."""
    return _classify(st.session_state.pattern_model, encoding, "Pattern")


def prints(encoding):
    """Classify the print type; returns the predicted label.

    Note: the original f-string here was broken across a physical line
    (a SyntaxError); the message is now a single-line "Print: <class>".
    """
    return _classify(st.session_state.print_model, encoding, "Print")


def sleevelengths(encoding):
    """Classify the sleeve length; returns the predicted label."""
    return _classify(st.session_state.sleeve_length_model, encoding, "Sleeve Length")


def imageprocessing(image):
    """Preprocess a PIL image into model-ready tensors (a BatchFeature)."""
    return st.session_state.image_processor(images=image, return_tensors="pt")


def pipes(image):
    """Run all four classifiers on `image` and return a dict of results.

    The image is preprocessed once and the same encoding is reused for every
    model. Results are also written to the page via st.write.
    """
    encoding = imageprocessing(image)

    results = {
        "top": topwear(encoding),
        "pattern": patterns(encoding),
        "print": prints(encoding),
        "sleeve_length": sleevelengths(encoding),
    }
    st.write(results)
    return results


# ---- Streamlit UI ---------------------------------------------------------
st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")

if url:
    try:
        # timeout prevents the app from hanging forever on a dead host
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            st.image(image.resize((200, 200)), caption="Uploaded Image",
                     use_column_width=False)

            start_time = time.time()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # crashing the app.
        st.error(f"Error processing the image: {str(e)}")