import gradio as gr
# MllamaForConditionalGeneration is the vision-capable class for Llama 3.2 Vision;
# AutoModelForCausalLM would build a text-only model that cannot consume images.
# TextStreamer was imported but never used, so it is dropped.
from transformers import AutoConfig, AutoTokenizer, AutoProcessor, MllamaForConditionalGeneration
import torch
import gc
import os
import json
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from huggingface_hub import snapshot_download
import psutil
# Enable better CPU performance
torch.set_num_threads(4)
device = "cpu"
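# Note: 4 threads is an arbitrary default sized for a small shared Space; on a
# dedicated machine, matching the available core count is usually better, e.g.:
#   torch.set_num_threads(os.cpu_count() or 4)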
def get_free_memory():
    """Get available memory in GB"""
    return psutil.virtual_memory().available / (1024 * 1024 * 1024)
def load_model_in_chunks(model_path, chunk_size_gb=2):
    """Load model shard by shard to manage memory"""
    # Resolve the Hub repo ID to a local snapshot first; os.path.join on a bare
    # repo ID would not point at real files.
    local_path = snapshot_download(model_path, cache_dir="model_cache")
    # Build the empty (meta-device) model from the config alone; instantiating
    # the full model just to read its config would defeat chunked loading.
    config = AutoConfig.from_pretrained(local_path)
    with init_empty_weights():
        empty_model = MllamaForConditionalGeneration(config)
    # Get checkpoint files from the safetensors index instead of assuming a
    # hard-coded shard count
    index_path = os.path.join(local_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            index = json.load(f)
        checkpoint_files = sorted(
            {os.path.join(local_path, shard) for shard in index["weight_map"].values()}
        )
    else:
        checkpoint_files = [os.path.join(local_path, "pytorch_model.bin")]
    # Load each shard; device_map is needed so weights materialize on CPU
    # rather than staying on the meta device. (torch.cuda.empty_cache() was
    # dropped: it is a no-op on this CPU-only setup.)
    for checkpoint in checkpoint_files:
        if get_free_memory() < 2:  # If less than 2GB free
            gc.collect()
        load_checkpoint_in_model(empty_model, checkpoint, device_map={"": "cpu"})
        gc.collect()
    return empty_model
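# A simpler alternative (a sketch, assuming enough disk space for offloading)
# is to let accelerate place and load everything in one call; local_path here
# stands for the snapshot directory returned by snapshot_download:
#
#   from accelerate import load_checkpoint_and_dispatch
#   model = load_checkpoint_and_dispatch(
#       empty_model,
#       local_path,
#       device_map={"": "cpu"},    # keep everything on CPU
#       offload_folder="offload",  # spill to disk if RAM runs out
#   )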
def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"
    print("Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    print("Loading model in chunks...")
    model = load_model_in_chunks(model_name)
    print("Optimizing model for CPU...")
    # Convert to float32, switch to inference mode, and quantize.
    # Dynamic quantization only supports a few module types (notably
    # nn.Linear); nn.Conv2d is not among them, so it is not listed here.
    model = model.to(torch.float32)
    model.eval()
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )
    return model, tokenizer, processor
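# For reference, the non-chunked path (a sketch; it needs roughly the full
# model size in free RAM at load time) would be:
#
#   model = MllamaForConditionalGeneration.from_pretrained(
#       "forestav/unsloth_vision_radiography_finetune",
#       torch_dtype=torch.float32,
#       low_cpu_mem_usage=True,  # stream weights instead of double-allocating
#       cache_dir="model_cache",
#   )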
# Create cache directories
os.makedirs("model_cache", exist_ok=True)
os.makedirs("offload", exist_ok=True)
print(f"Available memory before loading: {get_free_memory():.2f} GB")
# Initialize model and tokenizer globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded successfully!")
    print(f"Available memory after loading: {get_free_memory():.2f} GB")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
def analyze_image(image, instruction):
    try:
        gc.collect()
        if image is None:
            return "Please upload an image first."
        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        # Downscale large uploads in place (preserving aspect ratio); this
        # replaces the invalid max_pixels argument previously passed to
        # gr.Image.
        image.thumbnail((1024, 1024))
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
        # Process with memory checks
        if get_free_memory() < 2:
            gc.collect()
        # The chat template for Llama 3.2 Vision lives on the processor
        inputs = processor(
            images=image,
            text=processor.apply_chat_template(messages, add_generation_prompt=True),
            return_tensors="pt"
        )
        # Generate with minimal memory usage; greedy decoding (do_sample=False)
        # makes sampling knobs such as temperature and min_p irrelevant, so
        # they are dropped to avoid generation-config warnings
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1,
                do_sample=False  # Disable sampling to save memory
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        del outputs, inputs
        gc.collect()
        return response
    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."
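# Hypothetical local smoke test, assuming a sample.png sits next to this
# script (not part of the Space itself):
#
#   from PIL import Image
#   print(analyze_image(Image.open("sample.png"), ""))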
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # Medical Image Analysis Assistant
    Upload a medical image and receive a professional description from an AI radiographer.
    """)
    with gr.Row():
        with gr.Column():
            # Note: gr.Image accepts no max_pixels argument, which broke the
            # app at startup; oversized uploads are now downscaled inside
            # analyze_image instead.
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image"
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )
    gr.Markdown("""
    ### Notes:
    - The model runs on CPU and may take several minutes to process each image
    - For best results, upload images smaller than 1 MP
    - Initial loading may take some time
    - Please be patient during processing
    """)
if __name__ == "__main__":
    demo.launch()