File size: 4,751 Bytes
1f9a45b 9274ae9 c48e567 1f9a45b 5f0ae39 1f9a45b 5f0ae39 2b22bca 9274ae9 2b22bca 0bb757b 2b22bca 1f9a45b 2b22bca 9274ae9 2b22bca 0bb757b 2b22bca 1f9a45b 2b22bca 9274ae9 2b22bca 0bb757b 2b22bca 1f9a45b 2b22bca 9274ae9 2b22bca 0bb757b 2b22bca 1f9a45b 2c867d4 9274ae9 0bb757b 2c867d4 c48e567 2c867d4 0bb757b 2c867d4 c48e567 2b22bca c48e567 2b22bca c48e567 1f9a45b c48e567 1f9a45b c48e567 9274ae9 1f9a45b 1408ebf 2c867d4 1408ebf 2c867d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures
# Load the image processor and the four classifiers exactly once per session.
# Streamlit reruns this script on every interaction, so the `models_loaded`
# flag in session_state guards against repeated (slow) checkpoint downloads.
if 'models_loaded' not in st.session_state:
    _REPO_PREFIX = "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-"
    st.session_state.image_processor = AutoImageProcessor.from_pretrained(
        _REPO_PREFIX + "topwear"
    )
    # session_state key -> checkpoint suffix for each classification head.
    _checkpoints = {
        "top_wear_model": "topwear",
        "pattern_model": "pattern-rgb",
        "print_model": "print",
        "sleeve_length_model": "sleeve-length",
    }
    for _key, _suffix in _checkpoints.items():
        st.session_state[_key] = AutoModelForImageClassification.from_pretrained(
            _REPO_PREFIX + _suffix
        )
    st.session_state.models_loaded = True
# Define image processing and classification functions
def topwear(encoding, top_wear_model):
    """Classify the top-wear category of a preprocessed image.

    Args:
        encoding: Tensor batch produced by the image processor
            (``return_tensors="pt"``), unpacked into the model call.
        top_wear_model: Fine-tuned image-classification model whose
            ``config.id2label`` maps class indices to label strings.

    Returns:
        str: The predicted top-wear label (also echoed to the UI).
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = top_wear_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look up the label once and reuse it for both display and return.
    label = top_wear_model.config.id2label[predicted_class_idx]
    st.write(f"Top Wear: {label}")
    return label
def patterns(encoding, pattern_model):
    """Classify the fabric pattern of a preprocessed image.

    Args:
        encoding: Tensor batch produced by the image processor
            (``return_tensors="pt"``), unpacked into the model call.
        pattern_model: Fine-tuned image-classification model whose
            ``config.id2label`` maps class indices to label strings.

    Returns:
        str: The predicted pattern label (also echoed to the UI).
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = pattern_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look up the label once and reuse it for both display and return.
    label = pattern_model.config.id2label[predicted_class_idx]
    st.write(f"Pattern: {label}")
    return label
def prints(encoding, print_model):
    """Classify the print style of a preprocessed image.

    Args:
        encoding: Tensor batch produced by the image processor
            (``return_tensors="pt"``), unpacked into the model call.
        print_model: Fine-tuned image-classification model whose
            ``config.id2label`` maps class indices to label strings.

    Returns:
        str: The predicted print label (also echoed to the UI).
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = print_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look up the label once and reuse it for both display and return.
    label = print_model.config.id2label[predicted_class_idx]
    st.write(f"Print: {label}")
    return label
def sleevelengths(encoding, sleeve_length_model):
    """Classify the sleeve length of a preprocessed image.

    Args:
        encoding: Tensor batch produced by the image processor
            (``return_tensors="pt"``), unpacked into the model call.
        sleeve_length_model: Fine-tuned image-classification model whose
            ``config.id2label`` maps class indices to label strings.

    Returns:
        str: The predicted sleeve-length label (also echoed to the UI).
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = sleeve_length_model(**encoding).logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look up the label once and reuse it for both display and return.
    label = sleeve_length_model.config.id2label[predicted_class_idx]
    st.write(f"Sleeve Length: {label}")
    return label
def imageprocessing(image):
    """Run the session's shared image processor over *image* and return the
    resulting PyTorch tensor batch (``return_tensors="pt"``)."""
    return st.session_state.image_processor(images=image, return_tensors="pt")
# Run all models concurrently using threading
def pipes(image):
    """Run all four classifiers concurrently on one image.

    Returns a dict mapping model name ("topwear", "patterns", "prints",
    "sleeve_length") to the predicted label, or None for a model that
    raised. The dict is also written to the UI.
    """
    # Preprocess once; the same tensor batch feeds every model.
    encoding = imageprocessing(image)

    # Pull model handles out of session_state on the main thread before
    # fanning out to workers.
    session = st.session_state
    tasks = [
        ("topwear", topwear, session.top_wear_model),
        ("patterns", patterns, session.pattern_model),
        ("prints", prints, session.print_model),
        ("sleeve_length", sleevelengths, session.sleeve_length_model),
    ]

    # NOTE(review): the worker functions call st.write from threads; newer
    # Streamlit versions may require add_script_run_ctx for that — confirm.
    predictions = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = {
            pool.submit(fn, encoding, model): name for name, fn, model in tasks
        }
        for done in concurrent.futures.as_completed(pending):
            name = pending[done]
            try:
                predictions[name] = done.result()
            except Exception as e:
                st.error(f"Error in {name}: {str(e)}")
                predictions[name] = None

    # Echo the aggregated results to the UI.
    st.write(predictions)
    return predictions
# Streamlit app UI: accept an image URL, fetch it, classify it, and show
# the per-model labels plus total wall-clock time.
st.title("Clothing Classification Pipeline")

url = st.text_input("Paste image URL here...")
if url:
    try:
        # Bounded timeout so an unreachable host can't hang the script
        # run forever (requests.get has no timeout by default).
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            # NOTE(review): use_column_width is deprecated in newer
            # Streamlit (replaced by use_container_width) — confirm the
            # installed version before switching.
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            start_time = time.time()
            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.time() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        st.error(f"Error processing the image: {str(e)}")
|