vishalkatheriya committed on
Commit 2b22bca · verified · 1 Parent(s): c48e567

Update app.py

Files changed (1)
  1. app.py +26 -20
app.py CHANGED
@@ -17,37 +17,37 @@ if 'models_loaded' not in st.session_state:
     st.session_state.models_loaded = True
 
 # Define image processing and classification functions
-def topwear(encoding):
+def topwear(encoding, top_wear_model):
     with torch.no_grad():
-        outputs = st.session_state.top_wear_model(**encoding)
+        outputs = top_wear_model(**encoding)
         logits = outputs.logits
         predicted_class_idx = logits.argmax(-1).item()
-        st.write(f"Top Wear: {st.session_state.top_wear_model.config.id2label[predicted_class_idx]}")
-        return st.session_state.top_wear_model.config.id2label[predicted_class_idx]
+        st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
+        return top_wear_model.config.id2label[predicted_class_idx]
 
-def patterns(encoding):
+def patterns(encoding, pattern_model):
     with torch.no_grad():
-        outputs = st.session_state.pattern_model(**encoding)
+        outputs = pattern_model(**encoding)
         logits = outputs.logits
         predicted_class_idx = logits.argmax(-1).item()
-        st.write(f"Pattern: {st.session_state.pattern_model.config.id2label[predicted_class_idx]}")
-        return st.session_state.pattern_model.config.id2label[predicted_class_idx]
+        st.write(f"Pattern: {pattern_model.config.id2label[predicted_class_idx]}")
+        return pattern_model.config.id2label[predicted_class_idx]
 
-def prints(encoding):
+def prints(encoding, print_model):
     with torch.no_grad():
-        outputs = st.session_state.print_model(**encoding)
+        outputs = print_model(**encoding)
         logits = outputs.logits
         predicted_class_idx = logits.argmax(-1).item()
-        st.write(f"Print: {st.session_state.print_model.config.id2label[predicted_class_idx]}")
-        return st.session_state.print_model.config.id2label[predicted_class_idx]
+        st.write(f"Print: {print_model.config.id2label[predicted_class_idx]}")
+        return print_model.config.id2label[predicted_class_idx]
 
-def sleevelengths(encoding):
+def sleevelengths(encoding, sleeve_length_model):
     with torch.no_grad():
-        outputs = st.session_state.sleeve_length_model(**encoding)
+        outputs = sleeve_length_model(**encoding)
         logits = outputs.logits
         predicted_class_idx = logits.argmax(-1).item()
-        st.write(f"Sleeve Length: {st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]}")
-        return st.session_state.sleeve_length_model.config.id2label[predicted_class_idx]
+        st.write(f"Sleeve Length: {sleeve_length_model.config.id2label[predicted_class_idx]}")
+        return sleeve_length_model.config.id2label[predicted_class_idx]
 
 def imageprocessing(image):
     encoding = st.session_state.image_processor(images=image, return_tensors="pt")
@@ -58,13 +58,19 @@ def pipes(image):
     # Process the image once and reuse the encoding
     encoding = imageprocessing(image)
 
+    # Access models from session state before threading
+    top_wear_model = st.session_state.top_wear_model
+    pattern_model = st.session_state.pattern_model
+    print_model = st.session_state.print_model
+    sleeve_length_model = st.session_state.sleeve_length_model
+
     # Define functions to run the models in parallel
     with concurrent.futures.ThreadPoolExecutor() as executor:
         futures = {
-            executor.submit(topwear, encoding): "topwear",
-            executor.submit(patterns, encoding): "patterns",
-            executor.submit(prints, encoding): "prints",
-            executor.submit(sleevelengths, encoding): "sleeve_length"
+            executor.submit(topwear, encoding, top_wear_model): "topwear",
+            executor.submit(patterns, encoding, pattern_model): "patterns",
+            executor.submit(prints, encoding, print_model): "prints",
+            executor.submit(sleevelengths, encoding, sleeve_length_model): "sleeve_length"
         }
 
         results = {}
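
For context, the edit moves every st.session_state lookup onto the main script thread and hands the models to the worker functions as arguments, so the ThreadPoolExecutor workers no longer read session state themselves. Below is a minimal sketch of that pattern; the classify and run_all helpers are illustrative and not part of app.py, while the session-state keys and model names are taken from the diff above.

import concurrent.futures

import streamlit as st
import torch


def classify(encoding, model, label):
    # Hypothetical worker: the model arrives as an argument rather than
    # being fetched from st.session_state inside the thread.
    with torch.no_grad():
        logits = model(**encoding).logits
    return label, model.config.id2label[logits.argmax(-1).item()]


def run_all(encoding):
    # Read every model out of session state on the main thread first,
    # mirroring the "Access models from session state before threading"
    # step added in this commit.
    models = {
        "topwear": st.session_state.top_wear_model,
        "patterns": st.session_state.pattern_model,
        "prints": st.session_state.print_model,
        "sleeve_length": st.session_state.sleeve_length_model,
    }
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(classify, encoding, model, label)
            for label, model in models.items()
        ]
        # Collect each prediction as its thread finishes.
        for future in concurrent.futures.as_completed(futures):
            label, prediction = future.result()
            results[label] = prediction
    return results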