CullerWhale committed on
Commit
41b0032
·
verified ·
1 Parent(s): bdf89a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py CHANGED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import PIL.Image
3
+ import pandas as pd
4
+ import numpy as np
5
+ import boto3
6
+ from io import BytesIO, StringIO
7
+ from fastai.vision.all import *
8
+
9
def get_x(r):
    """Return the image path stored in a row's 'Image Path' field."""
    return r['Image Path']
10
def get_y(r):
    """Return the target label stored in a row's 'Survived' field."""
    return r['Survived']
11
+
12
def ProjectReportSplitter(df, valid_pct=0.2, seed=None):
    """Split a DataFrame into train/valid index lists by whole 'Project Report' groups.

    All rows belonging to the same project report land on the same side of the
    split, which prevents near-duplicate photos from the same report leaking
    between train and validation.

    Args:
        df: DataFrame with a 'Project Report' column.
        valid_pct: Fraction of *unique reports* (not rows) held out for
            validation. Was hard-coded to 0.2; now a parameter with the same
            default, so existing callers are unaffected.
        seed: Optional RNG seed for a reproducible split; None keeps the
            original non-deterministic behavior.

    Returns:
        (train_idx, valid_idx): two disjoint lists of DataFrame index labels
        that together cover every row of df.
    """
    rng = np.random.default_rng(seed)
    unique_reports = df['Project Report'].unique()
    # int() truncates, matching the original behavior: with very few reports
    # the validation set can be empty (e.g. 4 reports at 0.2 -> 0 held out).
    n_valid = int(len(unique_reports) * valid_pct)
    valid_reports = rng.choice(unique_reports, n_valid, replace=False)
    valid_mask = df['Project Report'].isin(valid_reports)
    valid_idx = df.index[valid_mask].tolist()
    train_idx = df.index[~valid_mask].tolist()
    return train_idx, valid_idx
19
+
20
# Use a function to resolve path
def get_x_transformed(r):
    """Resolve a row's image path and load the image from S3.

    NOTE(review): open_image_from_s3 is not defined anywhere in this file —
    calling this will raise NameError unless it is provided elsewhere; confirm.
    """
    path = get_x(r)
    return open_image_from_s3(path)
22
+
23
# fastai DataBlock describing how training data is assembled.
# NOTE(review): dblock is never used below — inference relies solely on the
# loaded learner. Presumably left over from training; confirm before removing.
dblock = DataBlock(
    # Inputs are PIL images; targets are a single category (e.g. survived-or-not).
    blocks=(ImageBlock(cls=PILImage), CategoryBlock),
    # Group-aware split: whole project reports go to train or valid, never both.
    splitter=ProjectReportSplitter,
    # get_x loads each image via S3; get_y reads the 'Survived' column.
    get_x=get_x_transformed,
    get_y=get_y,
    # Per-item: pad (not crop) to 460x460 with black borders, preserving content.
    item_tfms=Resize(460, method='pad', pad_mode='zeros'),
    # Per-batch augmentation: flips, up to 20° rotation, slight zoom and warp.
    batch_tfms=aug_transforms(mult=2, do_flip=True, max_rotate=20, max_zoom=1.1, max_warp=0.2)
)
31
+
32
# Load your model
# Deserializes a fastai Learner exported earlier; the .pkl must sit next to
# this script. load_learner unpickles the file — only load trusted artifacts.
learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")

# Print the vocabulary of the model
# Logged at startup so the label order used by predict() is visible in the logs.
print("Model Vocabulary:", learn.dls.vocab)
37
+
38
# Ordered class labels; probs from learn.predict follow this same order.
labels = learn.dls.vocab

def predict(img):
    """Classify an uploaded image and map each class label to its probability.

    Returns a dict {label: probability} suitable for a gr.Label output.
    """
    image = PILImage.create(img)
    _, _, probs = learn.predict(image)
    return {label: float(p) for label, p in zip(labels, probs)}
43
+
44
+ # def predict(img):
45
+ # img = PILImage.create(img)
46
+ # pred, pred_idx, probs = learn.predict(img)
47
+ # results = {labels[i]: float(probs[i]) for i in range(len(labels))}
48
+ # # Adjust results to highlight when 'Survived' meets the 75% threshold
49
+ # if results['Survived'] >= 0.75:
50
+ # results['Survived'] = 1.0 # Indicating high confidence of survival
51
+ # else:
52
+ # results['Survived'] = 0.0 # Indicating it did not meet the threshold
53
+ # return results
54
+ # def predict(img):
55
+ # img = PILImage.create(img)
56
+ # pred, pred_idx, probs = learn.predict(img)
57
+ # results = {labels[i]: float(probs[i]) for i in range(len(labels))}
58
+ # # Adjusting to display survival status based on the threshold
59
+ # survival_status = 'Survived' if results['Survived'] >= 0.75 else 'Not Survived'
60
+ # results['Survival Status'] = survival_status
61
+ # return results
62
+
63
+
64
# Gradio interface setup
title = "Photo Culling AI"
description = "Upload your photo to check if it survives culling."
article = "This interface uses a model trained to predict whether a photo is relevant for a project report."
# Launches the web UI at import time (blocking); show_error surfaces Python
# exceptions in the browser. NOTE(review): share=True opens a public tunnel —
# unnecessary (and ignored) when hosted on Hugging Face Spaces; confirm intent.
gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=2), title=title, description=description, article=article).launch(share=True,show_error=True)
69
+
70
+
71
+
72
+
73
+ # import gradio as gr
74
+ # import PIL.Image
75
+ # import pandas as pd
76
+ # import boto3
77
+ # from io import BytesIO, StringIO
78
+ # from fastai.vision.all import *
79
+
80
+
81
+ # def get_x(r): return r['Image Path']
82
+
83
+ # def get_y(r): return r['Survived']
84
+
85
+ # def ProjectReportSplitter(df):
86
+ # valid_pct = 0.2
87
+ # unique_reports = df['Project Report'].unique()
88
+ # valid_reports = np.random.choice(unique_reports, int(len(unique_reports) * valid_pct), replace=False)
89
+ # valid_idx = df.index[df['Project Report'].isin(valid_reports)].tolist()
90
+ # train_idx = df.index[~df.index.isin(valid_idx)].tolist()
91
+ # return train_idx, valid_idx
92
+
93
+
94
+ # # Use a function to resolve path
95
+ # def get_x_transformed(r): return open_image_from_s3(get_x(r))
96
+
97
+ # dblock = DataBlock(
98
+ # blocks=(ImageBlock(cls=PILImage), CategoryBlock),
99
+ # splitter=ProjectReportSplitter,
100
+ # get_x=get_x_transformed,
101
+ # get_y=get_y,
102
+ # item_tfms=Resize(460, method='pad', pad_mode='zeros'),
103
+ # batch_tfms=aug_transforms(mult=2, do_flip=True, max_rotate=20, max_zoom=1.1, max_warp=0.2)
104
+ # )
105
+
106
+ # # Load your model
107
+ # learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")
108
+
109
+ # # Print the vocabulary of the model
110
+ # print("Model Vocabulary:", learn.dls.vocab)
111
+
112
+
113
+ # # Update prediction function to directly read from S3
114
+ # def predict(img_path):
115
+ # pred, pred_idx, probs = learn.predict(img_path)
116
+ # return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
117
+
118
+ # # Gradio interface setup
119
+ # title = "Photo Culling AI"
120
+ # description = "Upload your photo to check if it survives culling."
121
+ # article = "This interface uses a model trained to predict whether a photo is relevant for a project report."
122
+ # gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=2), title=title, description=description, article=article).launch(share=True)