# Author: vishalkatheriya (Hugging Face)
# Last commit: 2b22bca "Update app.py" (verified)
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures
# Hugging Face Hub checkpoints. A single image processor (taken from the
# top-wear checkpoint) preprocesses the image for every classifier.
_PROCESSOR_REPO = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
# Session-state attribute name -> checkpoint repo, in the original load order.
_MODEL_REPOS = {
    "top_wear_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear",
    "pattern_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb",
    "print_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print",
    "sleeve_length_model": "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length",
}


@st.cache_resource
def _load_models():
    """Load the shared image processor and the four clothing classifiers.

    ``st.cache_resource`` makes this run once per server process and hands the
    same objects to every browser session. Storing freshly loaded models in
    ``st.session_state`` alone (which is per session) would reload all of them
    for each new visitor/tab.

    NOTE(review): the model objects are shared across sessions; inference
    here only calls them under ``torch.no_grad()`` and never mutates them —
    keep it that way.

    Returns:
        dict mapping the session-state attribute names used by the rest of
        the app (``image_processor``, ``top_wear_model``, ...) to the loaded
        objects.
    """
    loaded = {"image_processor": AutoImageProcessor.from_pretrained(_PROCESSOR_REPO)}
    for attr_name, repo_id in _MODEL_REPOS.items():
        loaded[attr_name] = AutoModelForImageClassification.from_pretrained(repo_id)
    return loaded


# Publish the cached objects into this session's state, which is where the
# rest of the app looks them up.
if 'models_loaded' not in st.session_state:
    for _attr_name, _obj in _load_models().items():
        st.session_state[_attr_name] = _obj
    st.session_state.models_loaded = True
# Define image processing and classification functions
def topwear(encoding, top_wear_model):
    """Predict the garment type ("top wear") label for a preprocessed image.

    Args:
        encoding: mapping of model inputs as returned by the image processor;
            it is unpacked into the model call (``model(**encoding)``).
        top_wear_model: image-classification model whose output has
            ``.logits`` and whose ``config.id2label`` maps class index to label.

    Returns:
        The label of the highest-scoring class. ``.item()`` assumes a batch of
        exactly one image.

    Note:
        ``pipes`` runs this inside a ThreadPoolExecutor worker. Streamlit
        commands issued from such threads have no ScriptRunContext and are not
        rendered, so this function no longer calls ``st.write``; it only
        returns the label and the caller displays it.
    """
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = top_wear_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return top_wear_model.config.id2label[predicted_class_idx]
def patterns(encoding, pattern_model):
    """Predict the pattern label for a preprocessed image.

    Args:
        encoding: mapping of model inputs as returned by the image processor;
            it is unpacked into the model call (``model(**encoding)``).
        pattern_model: image-classification model whose output has ``.logits``
            and whose ``config.id2label`` maps class index to label.

    Returns:
        The label of the highest-scoring class. ``.item()`` assumes a batch of
        exactly one image.

    Note:
        ``pipes`` runs this inside a ThreadPoolExecutor worker. Streamlit
        commands issued from such threads have no ScriptRunContext and are not
        rendered, so this function no longer calls ``st.write``; it only
        returns the label and the caller displays it.
    """
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = pattern_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return pattern_model.config.id2label[predicted_class_idx]
def prints(encoding, print_model):
    """Predict the print label for a preprocessed image.

    Args:
        encoding: mapping of model inputs as returned by the image processor;
            it is unpacked into the model call (``model(**encoding)``).
        print_model: image-classification model whose output has ``.logits``
            and whose ``config.id2label`` maps class index to label.

    Returns:
        The label of the highest-scoring class. ``.item()`` assumes a batch of
        exactly one image.

    Note:
        ``pipes`` runs this inside a ThreadPoolExecutor worker. Streamlit
        commands issued from such threads have no ScriptRunContext and are not
        rendered, so this function no longer calls ``st.write``; it only
        returns the label and the caller displays it.
    """
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = print_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return print_model.config.id2label[predicted_class_idx]
def sleevelengths(encoding, sleeve_length_model):
    """Predict the sleeve-length label for a preprocessed image.

    Args:
        encoding: mapping of model inputs as returned by the image processor;
            it is unpacked into the model call (``model(**encoding)``).
        sleeve_length_model: image-classification model whose output has
            ``.logits`` and whose ``config.id2label`` maps class index to label.

    Returns:
        The label of the highest-scoring class. ``.item()`` assumes a batch of
        exactly one image.

    Note:
        ``pipes`` runs this inside a ThreadPoolExecutor worker. Streamlit
        commands issued from such threads have no ScriptRunContext and are not
        rendered, so this function no longer calls ``st.write``; it only
        returns the label and the caller displays it.
    """
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = sleeve_length_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    return sleeve_length_model.config.id2label[predicted_class_idx]
def imageprocessing(image):
    """Convert *image* into model-ready PyTorch tensors.

    Delegates to the shared image processor kept in ``st.session_state`` and
    requests PyTorch output (``return_tensors="pt"``). The returned mapping is
    meant to be unpacked into a model call (``model(**encoding)``).
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
# Run all models concurrently using threading
def pipes(image):
    """Run every clothing classifier on *image* and return their labels.

    The image is preprocessed once, and the same encoding is submitted to the
    four classifiers on a thread pool so their forward passes can overlap.

    Args:
        image: input image, passed unchanged to ``imageprocessing``.

    Returns:
        dict with keys ``"topwear"``, ``"patterns"``, ``"prints"`` and
        ``"sleeve_length"`` (always in this order). Each maps to the value
        returned by that classifier, or ``None`` if the classifier raised (the
        error is shown with ``st.error``).
    """
    # Process the image once and reuse the encoding
    encoding = imageprocessing(image)

    # Read the models from st.session_state here, on the script thread, before
    # handing work to the pool.
    tasks = {
        "topwear": (topwear, st.session_state.top_wear_model),
        "patterns": (patterns, st.session_state.pattern_model),
        "prints": (prints, st.session_state.print_model),
        "sleeve_length": (sleevelengths, st.session_state.sleeve_length_model),
    }

    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {
            name: executor.submit(func, encoding, model)
            for name, (func, model) in tasks.items()
        }
        # All tasks are already running; collect them in submission order
        # (rather than as_completed order) so the result keys -- and the JSON
        # shown to the user -- come out in the same order on every run.
        for model_name, future in futures.items():
            try:
                results[model_name] = future.result()
            except Exception as e:
                # Reported from the script thread, where st.* output is rendered.
                st.error(f"Error in {model_name}: {str(e)}")
                results[model_name] = None

    # Display the results
    st.write(results)
    return results
# Streamlit app UI
# Upper bound, in seconds, on waiting for the remote image host.
_URL_TIMEOUT_SECONDS = 15

st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")
if url:
    try:
        # NOTE(review): the server fetches an arbitrary user-supplied URL
        # (SSRF risk if deployed next to internal services) -- consider an
        # allow-list or blocking private address ranges.
        # Without a timeout, requests.get can block this script run forever.
        response = requests.get(url, timeout=_URL_TIMEOUT_SECONDS)
        if response.status_code == 200:
            # The classifiers take 3-channel input; RGBA, palette and
            # grayscale images (common for PNGs) are normalised to RGB first.
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            # perf_counter is monotonic, unlike time.time, so it is the right
            # clock for measuring an elapsed interval.
            start_time = time.perf_counter()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # showing a traceback.
        st.error(f"Error processing the image: {str(e)}")