# tiya1012's picture
# Update app.py
# 1bdac09 verified
import logging

import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub repository that supplies both the image processor and the
# classification model, so the two are always loaded from the same checkpoint.
MODEL_ID = "tiya1012/vit-accident-image"

# Load once at import time; every request below reuses these two objects.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# Display text for each raw class name returned by ``model.config.id2label``.
# The keys assume the checkpoint uses the generic ``LABEL_<i>`` names
# (TODO confirm against the model's config.json).
label_mapping = dict(
    LABEL_0="No Accident",
    LABEL_1="Accident Detected",
)
# Define the classification function
def classify_accident_image(image):
# Ensure the image is provided
if image is None:
return "No image uploaded"
# Preprocess the image
inputs = processor(images=image, return_tensors="pt")
# Perform inference
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Debug: Print logits for analysis
print("Logits:", logits)
# Get the predicted class index and label
probabilities = torch.softmax(logits, dim=1)[0] # Softmax to get probabilities
predicted_class_idx = torch.argmax(probabilities).item()
print("Predicted Class Index:", predicted_class_idx)
print("Probabilities:", probabilities)
# Map the model's label to human-readable label using label_mapping
predicted_label_key = model.config.id2label[predicted_class_idx]
predicted_label = label_mapping.get(predicted_label_key, "Unknown")
# Get the confidence score
confidence = probabilities[predicted_class_idx].item() * 100
# Format the result
result = f"Prediction: {predicted_label}\nConfidence: {confidence:.2f}%"
return result
# Gradio UI: one PIL image upload in, the formatted prediction text out.
upload_component = gr.Image(type="pil", label="Upload Accident Image")
result_component = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=classify_accident_image,
    inputs=upload_component,
    outputs=result_component,
    title="Accident Image Classifier",
    description="Upload an image to classify whether it depicts an accident or not.",
)

# Serve the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()