File size: 4,751 Bytes
1f9a45b
 
 
 
 
 
9274ae9
c48e567
1f9a45b
5f0ae39
1f9a45b
 
 
 
 
 
 
 
5f0ae39
2b22bca
9274ae9
2b22bca
0bb757b
 
2b22bca
 
1f9a45b
2b22bca
9274ae9
2b22bca
0bb757b
 
2b22bca
 
1f9a45b
2b22bca
9274ae9
2b22bca
0bb757b
 
2b22bca
 
1f9a45b
2b22bca
9274ae9
2b22bca
0bb757b
 
2b22bca
 
1f9a45b
2c867d4
9274ae9
0bb757b
2c867d4
c48e567
2c867d4
0bb757b
2c867d4
c48e567
2b22bca
 
 
 
 
 
c48e567
 
 
2b22bca
 
 
 
c48e567
1f9a45b
c48e567
 
 
 
 
 
 
 
1f9a45b
c48e567
9274ae9
 
1f9a45b
 
 
 
1408ebf
 
2c867d4
 
 
 
 
 
 
 
 
1408ebf
 
 
2c867d4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures

# Load the shared image processor and the four classifiers a single time per
# Streamlit session: once the "models_loaded" flag is set in st.session_state,
# later reruns of this script skip the downloads entirely.
# NOTE(review): st.session_state is per browser session, so every new session
# loads its own copy of all models; st.cache_resource (Streamlit >= 1.18) would
# share one copy across sessions — confirm the deployed Streamlit version first.
if "models_loaded" not in st.session_state:
    st.session_state["image_processor"] = AutoImageProcessor.from_pretrained(
        "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"
    )
    # Loaded in the same order as the session-state keys below.
    for _state_key, _repo_id in (
        ("top_wear_model", "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear"),
        ("pattern_model", "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb"),
        ("print_model", "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print"),
        ("sleeve_length_model", "vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length"),
    ):
        st.session_state[_state_key] = AutoModelForImageClassification.from_pretrained(_repo_id)
    st.session_state["models_loaded"] = True

# Define image processing and classification functions
def topwear(encoding, top_wear_model):
    with torch.no_grad():
        outputs = top_wear_model(**encoding)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
    return top_wear_model.config.id2label[predicted_class_idx]

def patterns(encoding, pattern_model):
    with torch.no_grad():
        outputs = pattern_model(**encoding)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    st.write(f"Pattern: {pattern_model.config.id2label[predicted_class_idx]}")
    return pattern_model.config.id2label[predicted_class_idx]

def prints(encoding, print_model):
    with torch.no_grad():
        outputs = print_model(**encoding)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    st.write(f"Print: {print_model.config.id2label[predicted_class_idx]}")
    return print_model.config.id2label[predicted_class_idx]

def sleevelengths(encoding, sleeve_length_model):
    with torch.no_grad():
        outputs = sleeve_length_model(**encoding)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    st.write(f"Sleeve Length: {sleeve_length_model.config.id2label[predicted_class_idx]}")
    return sleeve_length_model.config.id2label[predicted_class_idx]

def imageprocessing(image):
    encoding = st.session_state.image_processor(images=image, return_tensors="pt")
    return encoding

# Run all models concurrently using threading
def pipes(image):
    # Process the image once and reuse the encoding
    encoding = imageprocessing(image)
    
    # Access models from session state before threading
    top_wear_model = st.session_state.top_wear_model
    pattern_model = st.session_state.pattern_model
    print_model = st.session_state.print_model
    sleeve_length_model = st.session_state.sleeve_length_model
    
    # Define functions to run the models in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(topwear, encoding, top_wear_model): "topwear",
            executor.submit(patterns, encoding, pattern_model): "patterns",
            executor.submit(prints, encoding, print_model): "prints",
            executor.submit(sleevelengths, encoding, sleeve_length_model): "sleeve_length"
        }

        results = {}
        for future in concurrent.futures.as_completed(futures):
            model_name = futures[future]
            try:
                results[model_name] = future.result()
            except Exception as e:
                st.error(f"Error in {model_name}: {str(e)}")
                results[model_name] = None

    # Display the results
    st.write(results)
    return results

# Streamlit app UI
st.title("Clothing Classification Pipeline")

url = st.text_input("Paste image URL here...")
if url:
    try:
        # requests never times out by default; bound the download so an
        # unresponsive host cannot block this script run indefinitely.
        response = requests.get(url, timeout=15)
        if response.status_code == 200:
            # Normalise to 3-channel RGB: a URL may serve RGBA, palette or
            # grayscale images. NOTE(review): assumes the image processor and
            # models expect RGB input, as the checkpoint names suggest.
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            # perf_counter is the monotonic clock intended for elapsed time.
            start_time = time.perf_counter()

            result = pipes(image)
            st.write("Classification Results (JSON):")
            st.json(result)  # Display results in JSON format
            st.write(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
        else:
            st.error("Failed to load image from URL. Please check the URL.")
    except Exception as e:
        # UI boundary: surface any download, decode or inference failure
        # (including a request timeout) to the user instead of crashing.
        st.error(f"Error processing the image: {str(e)}")