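# Streamlit app: classify a clothing image from a URL into top-wear type,
# pattern, print, and sleeve length using four fine-tuned ConvNeXt V2 models.
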
import time
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load models and processor only once using Streamlit session state
if 'models_loaded' not in st.session_state:
    st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    st.session_state.models_loaded = True
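
# st.session_state persists across Streamlit reruns within a browser session, so the
# downloads and model initialization above happen only on the first run. Note that a
# single processor (from the topwear checkpoint) is reused for all four models, which
# assumes the other checkpoints were fine-tuned with the same preprocessing settings.
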
# Define image processing and classification functions
def topwear(encoding):
    with torch.no_grad():
        outputs = st.session_state.top_wear_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    label = st.session_state.top_wear_model.config.id2label[predicted_class_idx]
    st.write(f"Top Wear: {label}")
    return label

def patterns(encoding):
    with torch.no_grad():
        outputs = st.session_state.pattern_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    label = st.session_state.pattern_model.config.id2label[predicted_class_idx]
    st.write(f"Pattern: {label}")
    return label

def prints(encoding):
    with torch.no_grad():
        outputs = st.session_state.print_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    label = st.session_state.print_model.config.id2label[predicted_class_idx]
    st.write(f"Print: {label}")
    return label

def sleevelengths(encoding):
    with torch.no_grad():
        outputs = st.session_state.sleeve_length_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    label = st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
    st.write(f"Sleeve Length: {label}")
    return label

def imageprocessing(image):
    encoding = st.session_state.image_processor(images=image, return_tensors="pt")
    return encoding
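
# The processor returns a dict-like BatchFeature holding a "pixel_values" tensor,
# which each classifier above unpacks with **encoding.
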
# Run all models sequentially
def pipes(image):
    # Process the image once and reuse the encoding for all four models
    encoding = imageprocessing(image)

    # Get results from each model
    topwear_result = topwear(encoding)
    pattern_result = patterns(encoding)
    print_result = prints(encoding)
    sleeve_length_result = sleevelengths(encoding)

    # Combine the results into a dictionary
    results = {
        "top": topwear_result,
        "pattern": pattern_result,
        "print": print_result,
        "sleeve_length": sleeve_length_result,
    }
    st.write(results)
    return results

# Streamlit app UI
st.title("Clothing Classification Pipeline")
url = st.text_input("Paste image URL here...")
if url:
    try:
        response = requests.get(url, timeout=10)  # timeout so a slow host cannot hang the app
        if response.status_code == 200:
            # Convert to RGB so RGBA or grayscale images do not break the processor
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            start_time = time.time()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        st.error(f"Error processing the image: {str(e)}")