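"""Eurybia Mini - Gradio app for marine object detection and analysis.

Detects objects in uploaded images with a FathomNet/MBARI YOLOv8 model, fetches
Wikipedia descriptions and images for the detected classes, answers questions about
them with a RoBERTa QA pipeline, estimates depth with Marigold, and upscales
detected crops with Real-ESRGAN.
"""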
import os

# Enable CPU fallback for operators unsupported on Apple MPS.
# This must be set before torch is imported to take effect.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import glob
import time
import threading
import base64
from io import BytesIO

import requests
import wikipedia
import torch
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from ultralytics import YOLO
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from diffusers import MarigoldDepthPipeline
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
# Initialize Models
def initialize_models():
    models = {}

    # Device detection
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    models['device'] = device
    print(f"Using device: {device}")

    # Initialize the RoBERTa model for question answering
    try:
        models['qa_pipeline'] = pipeline(
            "question-answering", model="deepset/roberta-base-squad2",
            device=0 if device == 'cuda' else -1)
        print("RoBERTa QA pipeline initialized.")
    except Exception as e:
        print(f"Error initializing the RoBERTa model: {e}")
        models['qa_pipeline'] = None

    # Initialize the Gemma model
    try:
        models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
        models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            device_map="auto",
            torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
        )
        print("Gemma model initialized.")
    except Exception as e:
        print(f"Error initializing the Gemma model: {e}")
        models['gemma_model'] = None
    # Initialize the depth estimation model (MarigoldDepthPipeline)
    try:
        if device == 'cuda':
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                variant="fp16",
                torch_dtype=torch.float16
            ).to('cuda')
        else:
            # For CPU or MPS devices, keep the pipeline on 'cpu' to avoid unsupported operators
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                torch_dtype=torch.float32
            ).to('cpu')
        print("Depth estimation model initialized.")
    except Exception as e:
        error_message = f"Error initializing the depth estimation model: {e}"
        print(error_message)
        models['depth_pipe'] = None
        models['depth_init_error'] = error_message  # Store the error message for the UI
    # Initialize the upscaling model
    try:
        upscaler_model_path = 'weights/RealESRGAN_x4plus.pth'  # Ensure this path is correct
        if not os.path.exists(upscaler_model_path):
            print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
            models['upscaler'] = None
        else:
            # Define the RRDBNet architecture used by RealESRGAN_x4plus
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)
            # Initialize RealESRGANer
            models['upscaler'] = RealESRGANer(
                scale=4,
                model_path=upscaler_model_path,
                model=model,
                pre_pad=0,
                half=(device == 'cuda'),
                device=device
            )
            print("Real-ESRGAN upscaler initialized.")
    except Exception as e:
        print(f"Error initializing the upscaling model: {e}")
        models['upscaler'] = None

    # Initialize the YOLO model
    try:
        # Hardcoded local path to the FathomNet MBARI 315K YOLOv8 weights;
        # update this to wherever the weights live in your deployment.
        source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
        if not os.path.exists(source_weights_path):
            print(f"YOLO weights not found at {source_weights_path}. Please download them.")
            models['yolo_model'] = None
        else:
            models['yolo_model'] = YOLO(source_weights_path)
            print("YOLO model initialized.")
    except Exception as e:
        print(f"Error initializing YOLO model: {e}")
        models['yolo_model'] = None

    return models
models = initialize_models()

# Utility Functions
def search_class_description(class_name):
    """Fetch a description of the class from its Wikipedia page."""
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    description = ""
    try:
        page = wikipedia.page(class_name)
        if page:
            description = page.content[:5000]  # Take a generous chunk of the article
    except Exception as e:
        print(f"Error fetching description for {class_name}: {e}")
    return description

def search_class_image(class_name):
    """Return the URL of the first suitable image on the class's Wikipedia page."""
    wikipedia.set_lang("en")
    wikipedia.set_rate_limiting(True)
    img_url = ""
    try:
        page = wikipedia.page(class_name)
        if page:
            for img in page.images:
                if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    img_url = img
                    break
    except Exception as e:
        print(f"Error fetching image for {class_name}: {e}")
    return img_url
def process_image(image):
    if models['yolo_model'] is None:
        return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None
    try:
        if image is None:
            return None, "No image uploaded.", "No image uploaded.", [], None

        # Convert the Gradio PIL image to OpenCV (BGR) format
        image_np = np.array(image)
        if image_np.dtype != np.uint8:
            image_np = image_np.astype(np.uint8)
        if len(image_np.shape) != 3 or image_np.shape[2] != 3:
            return None, "Invalid image format. Please upload an RGB image.", "Invalid image format. Please upload an RGB image.", [], None
        image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # Store the original image before drawing bounding boxes
        original_image_cv = image_cv.copy()
        original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB))

        # Perform YOLO prediction (low confidence threshold to catch faint detections)
        results = models['yolo_model'].predict(source=image_cv, conf=0.075)[0]

        bounding_boxes = []
        image_processed = image_cv.copy()
        if results.boxes is not None:
            for box in results.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                class_name = models['yolo_model'].names[int(box.cls)]
                confidence = box.conf.item() * 100  # Convert to percentage
                bounding_boxes.append({
                    "coords": (x1, y1, x2, y2),
                    "class_name": class_name,
                    "confidence": confidence
                })
                cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(image_processed, f'{class_name} {confidence:.2f}%',
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
                            0.9, (0, 0, 255), 2)

        # Convert back to a PIL Image
        processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB))

        # Prepare detection info
        if bounding_boxes:
            detection_info = "\n".join(
                [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes]
            )
        else:
            detection_info = "No detections found."

        # Prepare detection details as Markdown
        if bounding_boxes:
            details = ""
            for idx, box in enumerate(bounding_boxes):
                class_name = box['class_name']
                confidence = box['confidence']
                description = search_class_description(class_name)
                img_url = search_class_image(class_name)
                img_md = ""
                if img_url:
                    try:
                        headers = {
                            'User-Agent': 'MyApp/1.0 (https://example.com/contact; [email protected])'
                        }
                        response = requests.get(img_url, headers=headers, timeout=10)
                        img_data = response.content
                        img = Image.open(BytesIO(img_data)).convert("RGB")
                        img.thumbnail((400, 400))  # Resize for faster loading
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
                    except Exception as e:
                        print(f"Error fetching image for {class_name}: {e}")
                details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
                if description:
                    details += f"{description}\n\n"
                if img_md:
                    details += f"{img_md}\n\n"
            detection_details_md = details
        else:
            detection_details_md = "No detections to show."

        return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None
def ask_eurybia(question, state):
    if not question.strip():
        return "Please enter a valid question.", state
    if not state['bounding_boxes']:
        return "No detected objects to ask about.", state

    # Combine descriptions of all detected objects into a single QA context
    context = ""
    for box in state['bounding_boxes']:
        description = search_class_description(box['class_name'])
        if description:
            context += description + "\n"
    if not context.strip():
        return "Not enough context is available to answer the question.", state

    try:
        if models['qa_pipeline'] is None:
            return "QA pipeline is not initialized.", state
        answer = models['qa_pipeline'](question=question, context=context)
        answer_text = answer['answer'].strip()
        if not answer_text:
            return "I couldn't find an answer to that question based on the detected objects.", state
        return answer_text, state
    except Exception as e:
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", state
def enhance_image(cropped_image_pil):
    if models['upscaler'] is None:
        return None, "Upscaling model is not initialized."
    try:
        input_image = cropped_image_pil.convert("RGB")
        img = np.array(input_image)
        # Run Real-ESRGAN to upscale the crop 4x
        output, _ = models['upscaler'].enhance(img, outscale=4)
        enhanced_image = Image.fromarray(output)
        return enhanced_image, "Image enhanced successfully."
    except Exception as e:
        print(f"Error during image enhancement: {e}")
        return None, f"Error during image enhancement: {e}"
def run_depth_prediction(original_image):
    if models['depth_pipe'] is None:
        error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
        return None, error_msg
    try:
        if original_image is None:
            return None, "No image uploaded for depth prediction."
        # Prepare the image
        input_image = original_image.convert("RGB")
        # Run the Marigold depth pipeline and take the raw depth prediction
        result = models['depth_pipe'](input_image)
        depth_prediction = result.prediction
        # Visualize the depth map
        vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
        # visualize_depth returns a list of images; take the first one
        if isinstance(vis_depth, list) and len(vis_depth) > 0:
            vis_depth_image = vis_depth[0]
        else:
            vis_depth_image = vis_depth  # Fallback if not a list
        return vis_depth_image, "Depth prediction completed."
    except Exception as e:
        print(f"Error during depth prediction: {e}")
        return None, f"Error during depth prediction: {e}"
# Gradio Interface Components
with gr.Blocks() as demo:
    gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")

    with gr.Tab("Upload & Process"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                process_button = gr.Button("Process Image")
                clear_button = gr.Button("Clear")
            with gr.Column():
                processed_image = gr.Image(type="pil", label="Processed Image")
                detection_info = gr.Textbox(label="Detection Information", lines=10)

    with gr.Tab("Detection Details"):
        with gr.Accordion("Click to see detection details", open=False):
            detection_details_md = gr.Markdown("No detections to show.")

    with gr.Tab("Ask Eurybia"):
        with gr.Row():
            with gr.Column():
                question_input = gr.Textbox(label="Ask a question about the detected objects")
                ask_button = gr.Button("Ask Eurybia")
            with gr.Column():
                answer_output = gr.Markdown(label="Eurybia's Answer")

    with gr.Tab("Depth Estimation"):
        with gr.Row():
            with gr.Column():
                depth_button = gr.Button("Run Depth Prediction")
            with gr.Column():
                depth_output = gr.Image(type="pil", label="Depth Map")
                depth_status = gr.Textbox(label="Status", lines=2)
        # Display an error message if the depth estimation model failed to initialize
        if models.get('depth_init_error'):
            gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")
with gr.Tab("Enhance Detected Objects"): | |
if models['yolo_model'] is not None and models['upscaler'] is not None: | |
with gr.Row(): | |
detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True) | |
enhance_btn = gr.Button("Enhance Image") | |
with gr.Column(): | |
enhanced_image = gr.Image(type="pil", label="Enhanced Image") | |
enhance_status = gr.Textbox(label="Status", lines=2) | |
else: | |
gr.Markdown("**Warning:** YOLO model or Upscaling model is not initialized. Image enhancement functionality will be unavailable.") | |
with gr.Tab("Credits"): | |
gr.Markdown(""" | |
# Credits and Licensing Information | |
This project utilizes various open-source libraries, tools, pretrained models, and datasets. Below is the list of components used and their respective credits/licenses: | |
## Libraries | |
- **Python** - Python Software Foundation License (PSF License) | |
- **Gradio** - Licensed under the Apache License 2.0 | |
- **Torch (PyTorch)** - Licensed under the BSD 3-Clause License | |
- **OpenCV (cv2)** - Licensed under the Apache License 2.0 | |
- **NumPy** - Licensed under the BSD License | |
- **Pillow (PIL)** - Licensed under the HPND License | |
- **Requests** - Licensed under the Apache License 2.0 | |
- **Wikipedia API** - Licensed under the MIT License | |
- **Transformers** - Licensed under the Apache License 2.0 | |
- **Diffusers** - Licensed under the Apache License 2.0 | |
- **Real-ESRGAN** - Licensed under the MIT License | |
- **BasicSR** - Licensed under the Apache License 2.0 | |
- **Ultralytics YOLO** - Licensed under the GPL-3.0 License | |
## Pretrained Models | |
- **deepset/roberta-base-squad2 (RoBERTa)** - Model provided by Hugging Face under the Apache License 2.0. | |
- **google/gemma-2-2b-it** - Model provided by Hugging Face under the Apache License 2.0. | |
- **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0. | |
- **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the MIT License. | |
- **FathomNet MBARI 315K YOLOv8 Model**: | |
- **Dataset**: Sourced from [FathomNet](https://fathomnet.org). | |
- **Model**: Derived from MBARI’s curated dataset of 315,000 marine annotations. | |
- **License**: Dataset and models adhere to MBARI’s Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0). | |
## Datasets | |
- **FathomNet MBARI Dataset**: | |
- A large-scale dataset for marine biodiversity image annotations. | |
- All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). | |
## Acknowledgments | |
- **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection. | |
- **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery. | |
- **Gradio**: For providing an intuitive interface for machine learning applications. | |
- **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers). | |
- **Real-ESRGAN**: For image enhancement and upscaling models. | |
- **Wikipedia API**: For fetching object descriptions and images. | |
""") | |
    # Hidden state to store bounding boxes plus the original and processed images
    state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})

    # Event Handlers
    def on_process_image(image, state):
        processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
        if processed_img is not None:
            # Update the state with the new bounding boxes and images
            state['bounding_boxes'] = bounding_boxes
            state['last_image'] = processed_img
            state['original_image'] = original_image_pil
            # Update the dropdown choices for detected objects
            choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)"
                       for idx, box in enumerate(bounding_boxes)]
        else:
            choices = []
        return processed_img, info, details, gr.update(choices=choices), state

    process_button.click(
        on_process_image,
        inputs=[image_input, state],
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )
    def on_clear(state):
        state = {"bounding_boxes": [], "last_image": None, "original_image": None}
        return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state

    clear_button.click(
        on_clear,
        inputs=state,
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_ask_eurybia(question, state):
        answer, state = ask_eurybia(question, state)
        return answer, state

    ask_button.click(
        on_ask_eurybia,
        inputs=[question_input, state],
        outputs=[answer_output, state]
    )

    def on_depth_prediction(state):
        original_image = state.get('original_image')
        depth_img, status = run_depth_prediction(original_image)
        return depth_img, status

    depth_button.click(
        on_depth_prediction,
        inputs=state,
        outputs=[depth_output, depth_status]
    )
    def on_enhance_image(selected_object, state):
        if not selected_object:
            return None, "No object selected.", state
        try:
            # The dropdown label starts with "<index>." -- recover the bounding-box index
            idx = int(selected_object.split('.')[0]) - 1
            box = state['bounding_boxes'][idx]
            class_name = box['class_name']
            x1, y1, x2, y2 = box['coords']
            if state.get('last_image') is None:
                return None, "Processed image is not available.", state
            processed_img_pil = state['last_image']
            if not isinstance(processed_img_pil, Image.Image):
                return None, "Processed image is in an unsupported format.", state
            # Convert the processed image to OpenCV format with sanity checks
            processed_img_cv = np.array(processed_img_pil)
            if processed_img_cv.dtype != np.uint8:
                processed_img_cv = processed_img_cv.astype(np.uint8)
            if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
                return None, "Invalid processed image format.", state
            processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)
            # Crop the detected object from the processed image
            cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
            if cropped_img_cv.size == 0:
                return None, "Cropped image is empty.", state
            cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))
            # Enhance the cropped image
            enhanced_img, status = enhance_image(cropped_img_pil)
            return enhanced_img, status, state
        except Exception as e:
            return None, f"Error: {e}", state

    # The Enhance tab components are always created, so the handler can be wired
    # unconditionally; enhance_image() reports a missing upscaler at runtime.
    enhance_btn.click(
        on_enhance_image,
        inputs=[detected_objects, state],
        outputs=[enhanced_image, enhance_status, state]
    )
    # Optional: Add a note if the depth model isn't initialized and no error was shown above
    if models['depth_pipe'] is None and not models.get('depth_init_error'):
        gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")

    # Optional: Add a note if the upscaler isn't initialized
    if models['upscaler'] is None:
        gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
# Launch the Gradio app
demo.launch()