import gradio as gr
import torch
import gc
import os
import json
import psutil
from transformers import AutoConfig, AutoTokenizer, AutoProcessor, MllamaForConditionalGeneration
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from huggingface_hub import snapshot_download

# Constrain intra-op parallelism for predictable CPU performance
torch.set_num_threads(4)
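# If the host has more physical cores, matching the thread count to them may
# help, e.g. torch.set_num_threads(psutil.cpu_count(logical=False) or 4).
# (A tuning suggestion, not part of the original app.)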
device = "cpu"

def get_free_memory():
    """Get available memory in GB"""
    return psutil.virtual_memory().available / (1024 * 1024 * 1024)

def load_model_in_chunks(model_path, chunk_size_gb=2):
    """Materialize the model shard by shard to keep peak memory low."""
    # Resolve the hub repo to a local snapshot so shard files can be read
    # directly from disk.
    local_path = snapshot_download(model_path, cache_dir="model_cache")

    # Load only the config here; instantiating the full model just to read
    # its config would defeat the purpose of chunked loading.
    config = AutoConfig.from_pretrained(local_path)

    # Build the model skeleton on the meta device: no weight memory is
    # allocated until a checkpoint shard fills the parameters in. The vision
    # checkpoint needs the conditional-generation class, not a causal-LM one.
    with init_empty_weights():
        model = MllamaForConditionalGeneration(config)

    # Collect the shard list from the safetensors index when present, rather
    # than hard-coding a shard count; otherwise fall back to a single file.
    index_path = os.path.join(local_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            shard_names = set(json.load(f)["weight_map"].values())
        checkpoint_files = sorted(
            os.path.join(local_path, name) for name in shard_names
        )
    else:
        checkpoint_files = [os.path.join(local_path, "pytorch_model.bin")]

    # Load each shard onto CPU, reclaiming memory between shards.
    for checkpoint in checkpoint_files:
        if get_free_memory() < chunk_size_gb:
            gc.collect()
        load_checkpoint_in_model(model, checkpoint, device_map={"": "cpu"})
        gc.collect()

    return model
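
# Simpler alternative when RAM allows: let transformers handle sharded
# loading itself. A sketch, not used by this app:
#
#   model = MllamaForConditionalGeneration.from_pretrained(
#       model_path, low_cpu_mem_usage=True, torch_dtype=torch.float32
#   )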

def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"
    
    print("Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    
    print("Loading model in chunks...")
    model = load_model_in_chunks(model_name)
    
    print("Optimizing model for CPU...")
    # Convert to float32 and quantize
    model = model.to(torch.float32)
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear, torch.nn.Conv2d},
        dtype=torch.qint8
    )
    
    return model, tokenizer, processor

# Create cache directories
os.makedirs("model_cache", exist_ok=True)
os.makedirs("offload", exist_ok=True)

print(f"Available memory before loading: {get_free_memory():.2f} GB")

# Initialize model and tokenizer globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded successfully!")
    print(f"Available memory after loading: {get_free_memory():.2f} GB")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise

def analyze_image(image, instruction):
    try:
        gc.collect()

        # Downscale very large uploads (roughly a 1 MP cap, matching the
        # guidance in the notes below) before running the processor.
        if image is not None and image.width * image.height > 1_000_000:
            image.thumbnail((1000, 1000))

        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
        
        # Process with memory checks
        if get_free_memory() < 2:
            gc.collect()
        
        # For vision models the chat template lives on the processor rather
        # than the tokenizer.
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(
            images=image,
            text=prompt,
            return_tensors="pt"
        )
        
        # Generate greedily with minimal memory usage; sampling knobs such as
        # temperature and min_p are omitted because do_sample=False ignores them.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1,
                do_sample=False
            )
        
        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        
        del outputs, inputs
        gc.collect()
            
        return response
    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # Medical Image Analysis Assistant
    Upload a medical image and receive a professional description from an AI radiographer.
    """)
    
    with gr.Row():
        with gr.Column():
            # Note: gr.Image does not take a max_pixels argument; oversized
            # uploads are downscaled inside analyze_image instead.
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image"
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")
        
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)
    
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )
    
    gr.Markdown("""
    ### Notes:
    - The model runs on CPU and may take several minutes to process each image
    - For best results, upload images smaller than 1MP
    - Initial loading may take some time
    - Please be patient during processing
    """)

if __name__ == "__main__":
    demo.launch()