import os

# Enable PyTorch MPS fallback before torch is imported so unsupported
# operators fall back to the CPU on Apple Silicon.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import base64
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import wikipedia
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from diffusers import MarigoldDepthPipeline
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet


# Initialize Models
def initialize_models():
    models = {}

    # Device detection
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    models['device'] = device
    print(f"Using device: {device}")

    # Initialize the RoBERTa model for question answering
    try:
        models['qa_pipeline'] = pipeline(
            "question-answering",
            model="deepset/roberta-base-squad2",
            device=0 if device == 'cuda' else -1
        )
        print("RoBERTa QA pipeline initialized.")
    except Exception as e:
        print(f"Error initializing the RoBERTa model: {e}")
        models['qa_pipeline'] = None

    # Initialize the Gemma model
    try:
        models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
        models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            device_map="auto",
            torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
        )
        print("Gemma model initialized.")
    except Exception as e:
        print(f"Error initializing the Gemma model: {e}")
        models['gemma_model'] = None

    # Initialize the depth estimation model using MarigoldDepthPipeline
    try:
        if device == 'cuda':
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                variant="fp16",
                torch_dtype=torch.float16
            ).to('cuda')
        else:
            # For CPU or MPS devices, keep the pipeline on 'cpu' to avoid
            # unsupported operators.
            models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
                "prs-eth/marigold-depth-lcm-v1-0",
                torch_dtype=torch.float32
            ).to('cpu')
        print("Depth estimation model initialized.")
    except Exception as e:
        error_message = f"Error initializing the depth estimation model: {e}"
        print(error_message)
        models['depth_pipe'] = None
        models['depth_init_error'] = error_message  # Store the error message for the UI

    # Initialize the upscaling model
    try:
        upscaler_model_path = 'weights/RealESRGAN_x4plus.pth'  # Ensure this path is correct
        if not os.path.exists(upscaler_model_path):
            print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
            models['upscaler'] = None
        else:
            # Define the model architecture
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)
            # Initialize RealESRGANer
            models['upscaler'] = RealESRGANer(
                scale=4,
                model_path=upscaler_model_path,
                model=model,
                pre_pad=0,
                half=(device == 'cuda'),
                device=device
            )
            print("Real-ESRGAN upscaler initialized.")
    except Exception as e:
        print(f"Error initializing the upscaling model: {e}")
        models['upscaler'] = None

    # Initialize YOLO model
    try:
        source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
        if not os.path.exists(source_weights_path):
            print(f"YOLO weights not found at {source_weights_path}. Please download them.")
Please download them.") models['yolo_model'] = None else: models['yolo_model'] = YOLO(source_weights_path) print("YOLO model initialized.") except Exception as e: print(f"Error initializing YOLO model: {e}") models['yolo_model'] = None return models models = initialize_models() # Utility Functions def search_class_description(class_name): wikipedia.set_lang("en") wikipedia.set_rate_limiting(True) description = "" try: page = wikipedia.page(class_name) if page: description = page.content[:5000] # Get more content except Exception as e: print(f"Error fetching description for {class_name}: {e}") return description def search_class_image(class_name): wikipedia.set_lang("en") wikipedia.set_rate_limiting(True) img_url = "" try: page = wikipedia.page(class_name) if page: for img in page.images: if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')): img_url = img break except Exception as e: print(f"Error fetching image for {class_name}: {e}") return img_url def process_image(image): if models['yolo_model'] is None: return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None try: if image is None: return None, "No image uploaded.", "No image uploaded.", [], None # Convert Gradio Image to OpenCV format image_np = np.array(image) if image_np.dtype != np.uint8: image_np = image_np.astype(np.uint8) if len(image_np.shape) != 3 or image_np.shape[2] != 3: return None, "Invalid image format. Please upload a RGB image.", "Invalid image format. Please upload a RGB image.", [], None image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) # Store the original image before drawing bounding boxes original_image_cv = image_cv.copy() original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB)) # Perform YOLO prediction results = models['yolo_model'].predict( source=image_cv, conf=0.075)[0] # Lowered the threshold bounding_boxes = [] image_processed = image_cv.copy() if results.boxes is not None: for box in results.boxes: x1, y1, x2, y2 = map(int, box.xyxy[0].tolist()) class_name = models['yolo_model'].names[int(box.cls)] confidence = box.conf.item() * 100 # Convert to percentage bounding_boxes.append({ "coords": (x1, y1, x2, y2), "class_name": class_name, "confidence": confidence }) cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2) cv2.putText(image_processed, f'{class_name} {confidence:.2f}%', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2) # Convert back to PIL Image processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB)) # Prepare detection info if bounding_boxes: detection_info = "\n".join( [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes] ) else: detection_info = "No detections found." 
        # Prepare detection details as Markdown
        if bounding_boxes:
            details = ""
            for idx, box in enumerate(bounding_boxes):
                class_name = box['class_name']
                confidence = box['confidence']
                description = search_class_description(class_name)
                img_url = search_class_image(class_name)
                img_md = ""
                if img_url:
                    try:
                        headers = {
                            'User-Agent': 'MyApp/1.0 (https://example.com/contact; myemail@example.com)'
                        }
                        response = requests.get(img_url, headers=headers, timeout=10)
                        img_data = response.content
                        img = Image.open(BytesIO(img_data)).convert("RGB")
                        img.thumbnail((400, 400))  # Resize for faster loading
                        buffered = BytesIO()
                        img.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode()
                        img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
                    except Exception as e:
                        print(f"Error fetching image for {class_name}: {e}")
                details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
                if description:
                    details += f"{description}\n\n"
                if img_md:
                    details += f"{img_md}\n\n"
            detection_details_md = details
        else:
            detection_details_md = "No detections to show."

        return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None


def ask_eurybia(question, state):
    if not question.strip():
        return "Please enter a valid question.", state
    if not state['bounding_boxes']:
        return "No detected objects to ask about.", state

    # Combine descriptions of all detected objects as context
    context = ""
    for box in state['bounding_boxes']:
        description = search_class_description(box['class_name'])
        if description:
            context += description + "\n"
    if not context.strip():
        return "Not enough context is available to answer the question.", state

    try:
        if models['qa_pipeline'] is None:
            return "QA pipeline is not initialized.", state
        answer = models['qa_pipeline'](question=question, context=context)
        answer_text = answer['answer'].strip()
        if not answer_text:
            return "I couldn't find an answer to that question based on the detected objects.", state
        return answer_text, state
    except Exception as e:
        print(f"Error during question answering: {e}")
        return f"Error during question answering: {e}", state


def enhance_image(cropped_image_pil):
    if models['upscaler'] is None:
        return None, "Upscaling model is not initialized."
    try:
        input_image = cropped_image_pil.convert("RGB")
        img = np.array(input_image)
        # Run the model to enhance the image
        output, _ = models['upscaler'].enhance(img, outscale=4)
        enhanced_image = Image.fromarray(output)
        return enhanced_image, "Image enhanced successfully."
    except Exception as e:
        print(f"Error during image enhancement: {e}")
        return None, f"Error during image enhancement: {e}"


def run_depth_prediction(original_image):
    if models['depth_pipe'] is None:
        error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
        return None, error_msg
    try:
        if original_image is None:
            return None, "No image uploaded for depth prediction."
        # Prepare the image
        input_image = original_image.convert("RGB")
        # Run the depth pipeline
        result = models['depth_pipe'](input_image)
        # Access the depth prediction
        depth_prediction = result.prediction
        # Visualize the depth map
        vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
        # visualize_depth returns a list of PIL images; extract the first one
        if isinstance(vis_depth, list) and len(vis_depth) > 0:
            vis_depth_image = vis_depth[0]
        else:
            vis_depth_image = vis_depth  # Fallback if not a list
        return vis_depth_image, "Depth prediction completed."
    except Exception as e:
        print(f"Error during depth prediction: {e}")
        return None, f"Error during depth prediction: {e}"


# Gradio Interface Components
with gr.Blocks() as demo:
    gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")

    with gr.Tab("Upload & Process"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                process_button = gr.Button("Process Image")
                clear_button = gr.Button("Clear")
            with gr.Column():
                processed_image = gr.Image(type="pil", label="Processed Image")
                detection_info = gr.Textbox(label="Detection Information", lines=10)

    with gr.Tab("Detection Details"):
        with gr.Accordion("Click to see detection details", open=False):
            detection_details_md = gr.Markdown("No detections to show.")

    with gr.Tab("Ask Eurybia"):
        with gr.Row():
            with gr.Column():
                question_input = gr.Textbox(label="Ask a question about the detected objects")
                ask_button = gr.Button("Ask Eurybia")
            with gr.Column():
                answer_output = gr.Markdown(label="Eurybia's Answer")

    with gr.Tab("Depth Estimation"):
        with gr.Row():
            with gr.Column():
                depth_button = gr.Button("Run Depth Prediction")
            with gr.Column():
                depth_output = gr.Image(type="pil", label="Depth Map")
                depth_status = gr.Textbox(label="Status", lines=2)
        # Display error message if the depth estimation model failed to initialize
        if models.get('depth_init_error'):
            gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")

    with gr.Tab("Enhance Detected Objects"):
        # Create these components unconditionally so the event handlers below
        # can reference them even if the models failed to initialize.
        if models['yolo_model'] is None or models['upscaler'] is None:
            gr.Markdown("**Warning:** YOLO model or upscaling model is not initialized. Image enhancement functionality will be unavailable.")
        with gr.Row():
            detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True)
            enhance_btn = gr.Button("Enhance Image")
        with gr.Column():
            enhanced_image = gr.Image(type="pil", label="Enhanced Image")
            enhance_status = gr.Textbox(label="Status", lines=2)

    with gr.Tab("Credits"):
        gr.Markdown("""
# Credits and Licensing Information

This project utilizes various open-source libraries, tools, pretrained models, and datasets.
Below is the list of components used and their respective credits/licenses:

## Libraries

- **Python** - Python Software Foundation License (PSF License)
- **Gradio** - Licensed under the Apache License 2.0
- **Torch (PyTorch)** - Licensed under the BSD 3-Clause License
- **OpenCV (cv2)** - Licensed under the Apache License 2.0
- **NumPy** - Licensed under the BSD License
- **Pillow (PIL)** - Licensed under the HPND License
- **Requests** - Licensed under the Apache License 2.0
- **Wikipedia API** - Licensed under the MIT License
- **Transformers** - Licensed under the Apache License 2.0
- **Diffusers** - Licensed under the Apache License 2.0
- **Real-ESRGAN** - Licensed under the BSD 3-Clause License
- **BasicSR** - Licensed under the Apache License 2.0
- **Ultralytics YOLO** - Licensed under the AGPL-3.0 License

## Pretrained Models

- **deepset/roberta-base-squad2 (RoBERTa)** - Model provided by deepset on Hugging Face under the CC BY 4.0 License.
- **google/gemma-2-2b-it** - Model provided by Google on Hugging Face under the Gemma Terms of Use.
- **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0.
- **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the Real-ESRGAN project's BSD 3-Clause License.
- **FathomNet MBARI 315K YOLOv8 Model**:
  - **Dataset**: Sourced from [FathomNet](https://fathomnet.org).
  - **Model**: Derived from MBARI's curated dataset of 315,000 marine annotations.
  - **License**: Dataset and models adhere to MBARI's Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

## Datasets

- **FathomNet MBARI Dataset**:
  - A large-scale dataset for marine biodiversity image annotations.
  - All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).

## Acknowledgments

- **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection.
- **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery.
- **Gradio**: For providing an intuitive interface for machine learning applications.
- **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers).
- **Real-ESRGAN**: For image enhancement and upscaling models.
- **Wikipedia API**: For fetching object descriptions and images.
""")

    # Hidden state to store bounding boxes plus the original and processed images
    state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})

    # Event Handlers
    def on_process_image(image, state):
        processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
        if processed_img is not None:
            # Update the state with the new bounding boxes and images
            state['bounding_boxes'] = bounding_boxes
            state['last_image'] = processed_img
            state['original_image'] = original_image_pil
            # Update the dropdown choices for detected objects
            choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)"
                       for idx, box in enumerate(bounding_boxes)]
        else:
            choices = []
        return processed_img, info, details, gr.update(choices=choices), state

    process_button.click(
        on_process_image,
        inputs=[image_input, state],
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_clear(state):
        state = {"bounding_boxes": [], "last_image": None, "original_image": None}
        return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state

    clear_button.click(
        on_clear,
        inputs=state,
        outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
    )

    def on_ask_eurybia(question, state):
        answer, state = ask_eurybia(question, state)
        return answer, state

    ask_button.click(
        on_ask_eurybia,
        inputs=[question_input, state],
        outputs=[answer_output, state]
    )

    def on_depth_prediction(state):
        original_image = state.get('original_image')
        depth_img, status = run_depth_prediction(original_image)
        return depth_img, status

    depth_button.click(
        on_depth_prediction,
        inputs=state,
        outputs=[depth_output, depth_status]
    )

    def on_enhance_image(selected_object, state):
        if not selected_object:
            return None, "No object selected.", state
        try:
            # Dropdown entries look like "1. class_name (95.00%)"; recover the index
            idx = int(selected_object.split('.')[0]) - 1
            box = state['bounding_boxes'][idx]
            x1, y1, x2, y2 = box['coords']
            if state.get('last_image') is None:
                return None, "Processed image is not available.", state
            processed_img_pil = state['last_image']
            if not isinstance(processed_img_pil, Image.Image):
                return None, "Processed image is in an unsupported format.", state
            # Convert the processed image to OpenCV format with checks
            processed_img_cv = np.array(processed_img_pil)
            if processed_img_cv.dtype != np.uint8:
                processed_img_cv = processed_img_cv.astype(np.uint8)
            if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
                return None, "Invalid processed image format.", state
            processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)
            # Crop the detected object from the processed image
            cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
            if cropped_img_cv.size == 0:
                return None, "Cropped image is empty.", state
            cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))
            # Enhance the cropped image
            enhanced_img, status = enhance_image(cropped_img_pil)
            return enhanced_img, status, state
        except Exception as e:
            return None, f"Error: {e}", state

    enhance_btn.click(
        on_enhance_image,
        inputs=[detected_objects, state],
        outputs=[enhanced_image, enhance_status, state]
    )

    # Add a note if the depth model isn't initialized and no error was recorded
    if models['depth_pipe'] is None and not models.get('depth_init_error'):
        gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")

    # Add a note if the upscaler isn't initialized
    if models['upscaler'] is None:
        gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()