# Hugging Face Space app (Space status when this copy was captured: Sleeping)
import logging

import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub checkpoint used for both preprocessing and classification;
# kept in one place so the processor and model can never drift apart.
MODEL_ID = "tiya1012/vit-accident-image"

# Load the pre-trained image processor and model once, at module import time,
# so every call to the classifier reuses the same instances.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# Display text for the label names found in model.config.id2label
# (the generic LABEL_0 / LABEL_1 names).
label_mapping = dict(
    LABEL_0="No Accident",
    LABEL_1="Accident Detected",
)
def classify_accident_image(image):
    """Classify an uploaded image as showing an accident or not.

    Args:
        image: The image supplied by the Gradio ``gr.Image(type="pil")``
            input (a PIL image), or ``None`` when nothing was uploaded.

    Returns:
        str: A two-line string with the predicted label and its softmax
        confidence as a percentage, or ``"No image uploaded"`` when
        *image* is ``None``.
    """
    if image is None:
        return "No image uploaded"

    # Uploads such as RGBA/palette PNGs or grayscale images are not 3-channel;
    # convert them so the processor always receives plain RGB input.
    mode = getattr(image, "mode", None)
    if mode is not None and mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis; [0] selects the single image we passed in.
    probabilities = torch.softmax(logits, dim=-1)[0]
    predicted_class_idx = torch.argmax(probabilities).item()

    # Diagnostics at debug level instead of unconditional prints to stdout.
    logger = logging.getLogger(__name__)
    logger.debug("Logits: %s", logits)
    logger.debug("Predicted Class Index: %s", predicted_class_idx)
    logger.debug("Probabilities: %s", probabilities)

    # Translate the config's label name to display text. If the config
    # already uses a name that label_mapping does not know, show that name
    # rather than an uninformative "Unknown".
    predicted_label_key = model.config.id2label[predicted_class_idx]
    predicted_label = label_mapping.get(predicted_label_key, predicted_label_key)

    confidence = probabilities[predicted_class_idx].item() * 100
    return f"Prediction: {predicted_label}\nConfidence: {confidence:.2f}%"
def _build_interface():
    """Assemble the Gradio UI: one PIL image in, one text result out."""
    image_input = gr.Image(type="pil", label="Upload Accident Image")
    result_output = gr.Textbox(label="Classification Result")
    return gr.Interface(
        fn=classify_accident_image,
        inputs=image_input,
        outputs=result_output,
        title="Accident Image Classifier",
        description="Upload an image to classify whether it depicts an accident or not.",
    )


# Built at import time; the server is only started when run as a script.
iface = _build_interface()

if __name__ == "__main__":
    iface.launch()