# Page header residue (not code): author Boboiazumi; last commit "Update app.py"
# (60e4e39, verified); page links: raw, history, blame; file size 1.46 kB.
import gradio as gr
import torch
import torchvision.models as models
from torchvision import transforms
from torch import nn
from PIL import Image
# Preprocessing applied to every input image: resize to 128x128, convert to a
# float tensor, then normalize per channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# MobileNetV3-Large whose last classifier layer is replaced by a 2-way head.
# weights=None: load_state_dict below (strict by default) overwrites every
# parameter with the fine-tuned checkpoint, so fetching ImageNet weights via the
# deprecated `pretrained=True` only cost a network download at startup.
model = models.mobilenet_v3_large(weights=None)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 2)
model = model.to("cpu")
# weights_only=True restricts unpickling to tensors/primitive containers.
model.load_state_dict(
    torch.load("cnn_model.pth", weights_only=True, map_location="cpu")
)
model.eval()  # evaluation mode: dropout off, BatchNorm uses running statistics

# Class names indexed by the model's output position: 0 -> "nsfw", 1 -> "safe".
label = ["nsfw", "safe"]
def inference(image):
    """Classify a doll image as NSFW or safe and return a verdict message.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is pressed before any image is uploaded.

    Returns:
        A user-facing (Indonesian) message with the predicted class name and
        its softmax probability, or a prompt to upload an image if none given.
    """
    # Gradio passes None when nothing is uploaded; transform(None) would raise.
    if image is None:
        return "Silakan unggah gambar terlebih dahulu."
    # Normalize uses 3-channel mean/std, so make sure RGBA/grayscale/palette
    # images are turned into 3-channel RGB first.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)
    predicted_class = torch.argmax(probs, dim=1).item()
    # .item() yields a Python float; formatting the raw tensor would render
    # as "tensor(0.9876)" in the UI.
    score = probs[0, predicted_class].item()
    tag = f"[{label[predicted_class]}:{score:.4f}]"
    if label[predicted_class] == "nsfw":
        return f"Boneka ini terlalu seksi dan tidak aman dilihat anak kecil (NSFW) {tag}"
    return f"Boneka ini aman (SAFE) {tag}"
# Web UI: image upload on the left; the "Cek" button and result box on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inputs = gr.Image(type="pil")
        with gr.Column():
            btn = gr.Button("Cek")
            pred = gr.Text(label="Prediction")
    # Pressing "Cek" sends the uploaded image to `inference` and shows its message.
    btn.click(inference, inputs=inputs, outputs=pred)

# queue() configures request queuing on `demo` in place; then start the server.
demo.queue()
demo.launch()