import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch
import gc
import json
import os
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from huggingface_hub import snapshot_download
import psutil

# Cap intra-op CPU threads for predictable performance on shared hardware
torch.set_num_threads(4)

device = "cpu"

def get_free_memory():
    """Get available memory in GB."""
    return psutil.virtual_memory().available / (1024 * 1024 * 1024)
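# Note: psutil's "available" figure counts memory that can actually be handed
# to processes (including reclaimable caches), which is the right number for
# deciding whether another checkpoint shard will fit.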
def load_model_in_chunks(model_path):
    """Load the model shard by shard to keep peak memory low."""
    # Build the architecture on the meta device (no weight allocation);
    # AutoConfig avoids materializing the full model just to read its config
    config = AutoConfig.from_pretrained(model_path)
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    # Resolve the shard list from the safetensors index if one exists,
    # rather than hardcoding the shard count
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        checkpoint_files = [
            os.path.join(model_path, name)
            for name in sorted(set(weight_map.values()))
        ]
    else:
        checkpoint_files = [os.path.join(model_path, "pytorch_model.bin")]

    # Load each shard, collecting garbage when memory runs low
    for checkpoint in checkpoint_files:
        if get_free_memory() < 2:  # less than 2 GB free
            gc.collect()
        load_checkpoint_in_model(empty_model, checkpoint)
        gc.collect()

    return empty_model
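# A minimal usage sketch (hypothetical local directory; assumes the standard
# Hugging Face layout of model.safetensors.index.json plus numbered shards,
# or a single pytorch_model.bin):
#
#     model = load_model_in_chunks("model_cache/my_checkpoint")
#     print(sum(p.numel() for p in model.parameters()), "parameters")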
def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"

    print("Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )

    print("Loading model in chunks...")
    # load_model_in_chunks expects a local directory, so fetch the repo first
    model_path = snapshot_download(model_name, cache_dir="model_cache")
    model = load_model_in_chunks(model_path)

    print("Optimizing model for CPU...")
    # Convert to float32, then quantize dynamically; quantize_dynamic only
    # supports nn.Linear (not Conv2d), so only Linear layers are listed
    model = model.to(torch.float32)
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )

    return model, tokenizer, processor
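# Dynamic quantization stores Linear weights as int8 and dequantizes on the
# fly, so those layers shrink roughly 4x versus float32. A quick sanity
# check (module path as in current PyTorch releases; adjust if yours differs):
#
#     n = sum(isinstance(m, torch.ao.nn.quantized.dynamic.Linear)
#             for m in model.modules())
#     print(f"{n} Linear layers quantized")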
# Create cache directories
os.makedirs("model_cache", exist_ok=True)
os.makedirs("offload", exist_ok=True)

print(f"Available memory before loading: {get_free_memory():.2f} GB")

# Initialize model, tokenizer, and processor globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded successfully!")
    print(f"Available memory after loading: {get_free_memory():.2f} GB")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
def analyze_image(image, instruction):
    try:
        gc.collect()

        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."

        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

        # Free up memory before the heavy processing step
        if get_free_memory() < 2:
            gc.collect()

        # tokenize=False so the chat template returns a prompt string the
        # processor can pair with the image (the default returns token ids)
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            images=image,
            text=prompt,
            return_tensors="pt"
        )

        # Greedy decoding with minimal memory usage; sampling knobs such as
        # temperature/min_p are omitted because do_sample=False ignores them
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1,
                do_sample=False
            )

        # Decode only the newly generated tokens, not the echoed prompt
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        del outputs, inputs
        gc.collect()

        return response
    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # Medical Image Analysis Assistant
    Upload a medical image and receive a professional description from an AI radiographer.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image"
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)

    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )
    gr.Markdown("""
    ### Notes:
    - The model runs on CPU and may take several minutes to process each image
    - For best results, upload images smaller than 1 MP
    - Initial loading may take some time
    - Please be patient during processing
    """)
if __name__ == "__main__":
    demo.launch()