Spaces:
Sleeping
Sleeping
File size: 2,035 Bytes
c4ba5f9 1bdac09 f0b590f c4ba5f9 f0b590f 8d94655 f0b590f 8d94655 1bdac09 f0b590f 8d94655 1bdac09 f0b590f c4ba5f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import logging

import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub ID of the fine-tuned image classifier. It is defined once so the
# image processor and the model are always loaded from the same checkpoint.
MODEL_ID = "tiya1012/vit-accident-image"

# Load the pre-trained image processor and model once at import time, so every
# request reuses them instead of reloading the weights.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# Human-readable names for the generic label strings found in
# model.config.id2label.
# NOTE(review): this assumes the checkpoint uses the default "LABEL_<i>" names and
# that index 1 is the accident class. Confirm against the model card / config.json.
label_mapping = {
    "LABEL_0": "No Accident",
    "LABEL_1": "Accident Detected",
}
# Module logger: per-request diagnostics go here instead of unconditional print().
logger = logging.getLogger(__name__)


def classify_accident_image(image):
    """Classify an uploaded image with the loaded accident classifier.

    Args:
        image: The value from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submitted without uploading an image.

    Returns:
        ``"No image uploaded"`` if ``image`` is ``None``. Otherwise, a two-line
        string with the predicted label and its softmax probability as a
        percentage, e.g. ``"Prediction: No Accident\\nConfidence: 97.31%"``.
    """
    # Gradio passes None when nothing was uploaded; don't run the model on it.
    if image is None:
        return "No image uploaded"

    # Convert the image into the tensor inputs the model expects.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension. [0] selects the first (and only) image
    # in the batch built above.
    probabilities = torch.softmax(logits, dim=-1)[0]
    predicted_class_idx = torch.argmax(probabilities).item()

    # Debug output is opt-in (enable DEBUG level) rather than printed on every
    # request; %-args keep formatting lazy when DEBUG is off.
    logger.debug("Logits: %s", logits)
    logger.debug("Predicted Class Index: %d", predicted_class_idx)
    logger.debug("Probabilities: %s", probabilities)

    # Translate the model's raw label into a readable name. If the checkpoint
    # uses label names not in label_mapping, show the raw label rather than
    # collapsing every prediction to "Unknown".
    predicted_label_key = model.config.id2label[predicted_class_idx]
    predicted_label = label_mapping.get(predicted_label_key, predicted_label_key)

    confidence = probabilities[predicted_class_idx].item() * 100
    return f"Prediction: {predicted_label}\nConfidence: {confidence:.2f}%"
# Gradio UI: a single PIL image goes in and a single text result comes out.
image_input = gr.Image(type="pil", label="Upload Accident Image")
result_output = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=classify_accident_image,
    inputs=image_input,
    outputs=result_output,
    title="Accident Image Classifier",
    description="Upload an image to classify whether it depicts an accident or not.",
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    iface.launch()
|