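"""Streamlit app that classifies a clothing image fetched from a URL using four
fine-tuned ConvNeXt V2 models (topwear type, pattern, print, sleeve length),
running the four predictions in parallel threads."""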
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import threading
import time

# Load models and processor only once using session state
if 'models_loaded' not in st.session_state:
    # Image processor
    st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Topwear model
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Pattern model
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    # Print model
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    # Sleeve length model
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    st.session_state.models_loaded = True

# Functions for predictions
def topwear(encoding):
    outputs = st.session_state.top_wear_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.top_wear_model.config.id2label[predicted_class_idx]

def patterns(encoding):
    outputs = st.session_state.pattern_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.pattern_model.config.id2label[predicted_class_idx]

def prints(encoding):
    outputs = st.session_state.print_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.print_model.config.id2label[predicted_class_idx]

def sleevelengths(encoding):
    outputs = st.session_state.sleeve_length_model(**encoding)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]

def imageprocessing(url):
    # Download the image and prepare model inputs with the shared image processor
    response = requests.get(url)
    response.raise_for_status()  # surface download errors instead of failing later in Image.open
    image = Image.open(BytesIO(response.content))
    encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
    return encoding, image

def pipes(imagepath):
    encoding, image = imageprocessing(imagepath)

    # Run the four classifiers in parallel threads for faster results
    results = [None] * 4
    threads = [
        threading.Thread(target=lambda: results.__setitem__(0, topwear(encoding))),
        threading.Thread(target=lambda: results.__setitem__(1, patterns(encoding))),
        threading.Thread(target=lambda: results.__setitem__(2, prints(encoding))),
        threading.Thread(target=lambda: results.__setitem__(3, sleevelengths(encoding))),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    dicts = {"top": results[0], "pattern": results[1], "print": results[2], "sleeve_length": results[3]}
    return dicts, image

# Streamlit app UI
st.title("Clothing Classification Pipeline")
image_url = st.text_input("Enter Image URL")

if image_url:
    start_time = time.time()
    results, img = pipes(image_url)
    st.image(img.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

    # Display results
    st.write("Classification Results:")
    st.write(f"Topwear: {results['top']}")
    st.write(f"Pattern: {results['pattern']}")
    st.write(f"Print: {results['print']}")
    st.write(f"Sleeve Length: {results['sleeve_length']}")
    st.write(f"Time taken: {time.time() - start_time:.2f} seconds")