# mvp2 / app.py (author: CullerWhale)
# Last commit: 41b0032 "Update app.py" (verified)
import gradio as gr
import PIL.Image
import pandas as pd
import numpy as np
import boto3
from io import BytesIO, StringIO
from fastai.vision.all import *
def get_x(r):
    """Return the image path stored under the ``'Image Path'`` key of row *r*."""
    image_path = r['Image Path']
    return image_path
def get_y(r):
    """Return the label stored under the ``'Survived'`` key of row *r*."""
    label = r['Survived']
    return label
def ProjectReportSplitter(df, valid_pct=0.2, seed=None):
    """Split *df* into train/validation row positions, grouped by project report.

    Every row sharing a ``'Project Report'`` value lands on the same side of
    the split, so images from one report never appear in both sets.

    Args:
        df: DataFrame with a ``'Project Report'`` column (the items a fastai
            DataBlock passes to its splitter).
        valid_pct: Fraction of distinct reports assigned to validation. The
            count is truncated with ``int()``, so it can be 0 when there are
            only a few reports (yielding an empty validation set).
        seed: Optional seed for a reproducible split. ``None`` (the default)
            draws from NumPy's global random state, as the original did.

    Returns:
        ``(train_idx, valid_idx)``: two lists of 0-based *positional* row
        indices.
    """
    reports = df['Project Report']
    unique_reports = reports.unique()
    n_valid = int(len(unique_reports) * valid_pct)
    # Unseeded calls keep using the global RNG so np.random.seed()/set_seed()
    # still control the split; a seed gives an independent, reproducible draw.
    rng = np.random if seed is None else np.random.default_rng(seed)
    valid_reports = rng.choice(unique_reports, n_valid, replace=False)
    # Work with positions, not index labels: labels only coincide with
    # positions for a default RangeIndex.
    is_valid = reports.isin(valid_reports).to_numpy()
    valid_idx = np.flatnonzero(is_valid).tolist()
    train_idx = np.flatnonzero(~is_valid).tolist()
    return train_idx, valid_idx
# Resolve a data row to an opened image; used as the DataBlock's get_x.
def get_x_transformed(r):
    """Open the image whose path ``get_x`` extracts from row *r*.

    NOTE(review): ``open_image_from_s3`` is not defined or imported anywhere in
    this file, so calling this function raises NameError. Confirm where it is
    meant to come from. The function presumably needs to stay defined at module
    level so the pickled learner can be loaded; verify before removing.
    """
    image_path = get_x(r)
    return open_image_from_s3(image_path)
# Data pipeline definition. It is not referenced below; inference uses the
# learner loaded from the .pkl file instead.
dblock = DataBlock(
    blocks=(ImageBlock(cls=PILImage), CategoryBlock),
    # Inputs come from each row's image path; labels from its 'Survived' field.
    get_x=get_x_transformed,
    get_y=get_y,
    # Keep every project report entirely in train or entirely in validation.
    splitter=ProjectReportSplitter,
    # Pad with zeros to 460x460 instead of cropping, so no part of the photo is lost.
    item_tfms=Resize(460, method='pad', pad_mode='zeros'),
    batch_tfms=aug_transforms(
        mult=2,
        do_flip=True,
        max_rotate=20,
        max_zoom=1.1,
        max_warp=0.2,
    ),
)
# Load the exported fastai learner and expose its class labels for predict().
learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")
labels = learn.dls.vocab
# Log the vocabulary at startup so the label order is visible.
print("Model Vocabulary:", labels)
def predict(img):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        img: The value Gradio passes from ``gr.Image()`` (typically a NumPy
            array), or anything else ``PILImage.create`` accepts. Gradio
            passes ``None`` when the user submits without an upload.

    Returns:
        Dict mapping each label in ``labels`` (the model vocabulary) to its
        probability as a plain float, the mapping ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was provided. This shows a readable message in
            the UI instead of an opaque failure from ``PILImage.create(None)``.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    img = PILImage.create(img)
    # Only the probability vector is needed; it is ordered like `labels`.
    _, _, probs = learn.predict(img)
    return {label: float(prob) for label, prob in zip(labels, probs)}
# def predict(img):
# img = PILImage.create(img)
# pred, pred_idx, probs = learn.predict(img)
# results = {labels[i]: float(probs[i]) for i in range(len(labels))}
# # Adjust results to highlight when 'Survived' meets the 75% threshold
# if results['Survived'] >= 0.75:
# results['Survived'] = 1.0 # Indicating high confidence of survival
# else:
# results['Survived'] = 0.0 # Indicating it did not meet the threshold
# return results
# def predict(img):
# img = PILImage.create(img)
# pred, pred_idx, probs = learn.predict(img)
# results = {labels[i]: float(probs[i]) for i in range(len(labels))}
# # Adjusting to display survival status based on the threshold
# survival_status = 'Survived' if results['Survived'] >= 0.75 else 'Not Survived'
# results['Survival Status'] = survival_status
# return results
# Gradio interface setup
title = "Photo Culling AI"
description = "Upload your photo to check if it survives culling."
article = "This interface uses a model trained to predict whether a photo is relevant for a project report."

# Build the interface once and keep a module-level handle to it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    article=article,
)

# Start the server only when this file is run as a script, not as a side
# effect of importing the module.
if __name__ == "__main__":
    # share=True requests a public share link; show_error=True surfaces
    # exceptions raised in predict() in the UI.
    demo.launch(share=True, show_error=True)
# import gradio as gr
# import PIL.Image
# import pandas as pd
# import boto3
# from io import BytesIO, StringIO
# from fastai.vision.all import *
# def get_x(r): return r['Image Path']
# def get_y(r): return r['Survived']
# def ProjectReportSplitter(df):
# valid_pct = 0.2
# unique_reports = df['Project Report'].unique()
# valid_reports = np.random.choice(unique_reports, int(len(unique_reports) * valid_pct), replace=False)
# valid_idx = df.index[df['Project Report'].isin(valid_reports)].tolist()
# train_idx = df.index[~df.index.isin(valid_idx)].tolist()
# return train_idx, valid_idx
# # Use a function to resolve path
# def get_x_transformed(r): return open_image_from_s3(get_x(r))
# dblock = DataBlock(
# blocks=(ImageBlock(cls=PILImage), CategoryBlock),
# splitter=ProjectReportSplitter,
# get_x=get_x_transformed,
# get_y=get_y,
# item_tfms=Resize(460, method='pad', pad_mode='zeros'),
# batch_tfms=aug_transforms(mult=2, do_flip=True, max_rotate=20, max_zoom=1.1, max_warp=0.2)
# )
# # Load your model
# learn = load_learner("templateClassifierDATAhalfEPOCHoneVISION.pkl")
# # Print the vocabulary of the model
# print("Model Vocabulary:", learn.dls.vocab)
# # Update prediction function to directly read from S3
# def predict(img_path):
# pred, pred_idx, probs = learn.predict(img_path)
# return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
# # Gradio interface setup
# title = "Photo Culling AI"
# description = "Upload your photo to check if it survives culling."
# article = "This interface uses a model trained to predict whether a photo is relevant for a project report."
# gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=2), title=title, description=description, article=article).launch(share=True)