|
import gradio as gr |
|
import torch |
|
import torchvision.models as models |
|
from torchvision import transforms |
|
from torch import nn |
|
from PIL import Image |
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model:
# resize to exactly 128x128, convert to a float tensor, then normalize each
# of the 3 channels with the usual ImageNet mean/std values (presumably the
# same preprocessing used when cnn_model.pth was trained -- confirm).
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
# Build a MobileNetV3-Large with a 2-way head and load the fine-tuned weights
# from cnn_model.pth.
# weights=None replaces the deprecated pretrained=True and skips downloading
# the ImageNet checkpoint. That download was wasted work: load_state_dict
# below is strict by default, so every parameter and buffer is overwritten by
# the checkpoint anyway.
model = models.mobilenet_v3_large(weights=None)

# Swap the final Linear layer for one producing 2 logits, one per entry of
# `label`.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 2)

model = model.to("cpu")
model.load_state_dict(
    torch.load("cnn_model.pth", weights_only=True, map_location="cpu")
)

# Evaluation mode (affects layers such as dropout / batch-norm).
model.eval()
|
|
|
# Class names indexed by the model's output logit position:
# inference() looks up label[argmax(logits)].
label = [
    "nsfw",  # logit 0
    "safe",  # logit 1
]
|
|
|
def inference(image):
    """Classify a doll image as NSFW or safe and return a user-facing verdict.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked without an uploaded image.

    Returns:
        A message (in Indonesian) with the verdict, followed by the predicted
        label and its softmax probability, e.g. ``"... [safe:0.9876]"``.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded. Answer politely
        # instead of crashing inside `transform`.
        return "Silakan unggah gambar boneka terlebih dahulu."

    # Normalize in `transform` expects exactly 3 channels. Uploads may be
    # RGBA/L/P, so force RGB first (a no-op copy for RGB images).
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)

    predicted_class = torch.argmax(probs, dim=1).item()
    predicted_label = label[predicted_class]
    # .item() turns the 0-d tensor into a Python float, so the message shows
    # "0.9876" rather than "tensor(0.9876)".
    score = probs[0, predicted_class].item()

    if predicted_label == "nsfw":
        return f'Boneka ini terlalu seksi dan tidak aman dilihat anak kecil (NSFW) [{predicted_label}:{score:.4f}]'
    return f'Boneka ini aman (SAFE) [{predicted_label}:{score:.4f}]'
|
|
|
# Web UI: image upload on the left, check button and result text on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the doll image, handed to inference() as a PIL image.
        with gr.Column():
            inputs = gr.Image(type="pil")
        # Right column: trigger button plus the textual verdict.
        with gr.Column():
            btn = gr.Button("Cek")
            pred = gr.Text(label="Prediction")

    # Pressing "Cek" runs inference() on the image and writes its result to `pred`.
    btn.click(fn=inference, inputs=inputs, outputs=pred)

# Turn on Gradio's queue for this app, then start the server.
demo.queue()
demo.launch()