vishalkatheriya18 committed
Commit 9274ae9 · verified · 1 Parent(s): 5f69c43

Update app.py

Files changed (1)
  app.py +27 -40
app.py CHANGED
@@ -3,8 +3,8 @@ from transformers import AutoModelForImageClassification, AutoImageProcessor
 from PIL import Image
 import requests
 from io import BytesIO
-import threading
 import time
+import torch
 
 # Load models and processor only once using Streamlit session state
 if 'models_loaded' not in st.session_state:
@@ -17,74 +17,61 @@ if 'models_loaded' not in st.session_state:
 
 # Define image processing and classification functions
 def topwear(encoding):
-    outputs = st.session_state.top_wear_model(**encoding)
+    with torch.no_grad():
+        outputs = st.session_state.top_wear_model(**encoding)
     logits = outputs.logits
     predicted_class_idx = logits.argmax(-1).item()
-    st.write(st.session_state.top_wear_model.config.id2label[predicted_class_idx])
+    st.write(f"Top Wear: {st.session_state.top_wear_model.config.id2label[predicted_class_idx]}")
     return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
 
 def patterns(encoding):
-    outputs = st.session_state.pattern_model(**encoding)
+    with torch.no_grad():
+        outputs = st.session_state.pattern_model(**encoding)
     logits = outputs.logits
     predicted_class_idx = logits.argmax(-1).item()
-    st.write(st.session_state.pattern_model.config.id2label[predicted_class_idx])
+    st.write(f"Pattern: {st.session_state.pattern_model.config.id2label[predicted_class_idx]}")
    return st.session_state.pattern_model.config.id2label[predicted_class_idx]
 
 def prints(encoding):
-    outputs = st.session_state.print_model(**encoding)
+    with torch.no_grad():
+        outputs = st.session_state.print_model(**encoding)
     logits = outputs.logits
     predicted_class_idx = logits.argmax(-1).item()
-    st.write(st.session_state.print_model.config.id2label[predicted_class_idx])
+    st.write(f"Print: {st.session_state.print_model.config.id2label[predicted_class_idx]}")
     return st.session_state.print_model.config.id2label[predicted_class_idx]
 
 def sleevelengths(encoding):
-    outputs = st.session_state.sleeve_length_model(**encoding)
+    with torch.no_grad():
+        outputs = st.session_state.sleeve_length_model(**encoding)
     logits = outputs.logits
     predicted_class_idx = logits.argmax(-1).item()
-    st.write(st.session_state.sleeve_length_model.config.id2label[predicted_class_idx])
+    st.write(f"Sleeve Length: {st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]}")
     return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
 
 def imageprocessing(image):
-    encoding = st.session_state.image_processor(image.convert("RGB"), return_tensors="pt")
+    encoding = st.session_state.image_processor(images=image, return_tensors="pt")
     return encoding
 
-# Define the function that will be used in each thread
-def call_model(func, encoding, results, index):
-    results[index] = func(encoding)
-
-# Run all models in parallel
+# Run all models sequentially
 def pipes(image):
     # Process the image once and reuse the encoding
     encoding = imageprocessing(image)
 
-    # Prepare a list to store the results from each thread
-    results = [None] * 4
-
-    # Create threads for each function call
-    threads = [
-        threading.Thread(target=call_model, args=(topwear, encoding, results, 0)),
-        threading.Thread(target=call_model, args=(patterns, encoding, results, 1)),
-        threading.Thread(target=call_model, args=(prints, encoding, results, 2)),
-        threading.Thread(target=call_model, args=(sleevelengths, encoding, results, 3)),
-    ]
-
-    # Start all threads
-    for thread in threads:
-        thread.start()
-
-    # Wait for all threads to finish
-    for thread in threads:
-        thread.join()
+    # Get results from each model
+    topwear_result = topwear(encoding)
+    pattern_result = patterns(encoding)
+    print_result = prints(encoding)
+    sleeve_length_result = sleevelengths(encoding)
 
     # Combine the results into a dictionary
-    dicts = {
-        "top": results[0],
-        "pattern": results[1],
-        "print": results[2],
-        "sleeve_length": results[3]
+    results = {
+        "top": topwear_result,
+        "pattern": pattern_result,
+        "print": print_result,
+        "sleeve_length": sleeve_length_result
     }
-    st.write(dicts)
-    return dicts
+    st.write(results)
+    return results
 
 # Streamlit app UI
 st.title("Clothing Classification Pipeline")
 