Filip committed

Commit 56d8f41 · 1 Parent(s): 57a1258

update torch

Files changed (1)

app.py +43 -28
app.py CHANGED

@@ -2,41 +2,56 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextStreamer
 import torch
 import gc
+import os
 
-# Configure torch to use CPU
+# Enable better CPU performance
+torch.set_num_threads(4)  # Adjust based on available CPU cores
 device = "cpu"
-torch.set_default_device(device)
 
-# Load model and tokenizer
 def load_model():
     model_name = "forestav/unsloth_vision_radiography_finetune"
 
-    # Load with 8-bit quantization and CPU optimization settings
+    # Load tokenizer and processor first to free up memory
+    print("Loading tokenizer and processor...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    print("Loading model...")
+    # Load model with CPU optimizations
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         device_map="cpu",
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float32,  # Use float32 for CPU
+        low_cpu_mem_usage=True,
+        offload_folder="offload",  # Enable disk offloading
+        offload_state_dict=True  # Offload state dict to disk
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    processor = AutoProcessor.from_pretrained(model_name)
+
+    # Quantize the model for CPU
+    print("Quantizing model...")
+    model = torch.quantization.quantize_dynamic(
+        model,
+        {torch.nn.Linear},  # Quantize linear layers
+        dtype=torch.qint8
+    )
+
     return model, tokenizer, processor
 
+# Create offload directory if it doesn't exist
+os.makedirs("offload", exist_ok=True)
+
 # Initialize model and tokenizer globally
-print("Loading model...")
+print("Starting model initialization...")
 try:
     model, tokenizer, processor = load_model()
-    print("Model loaded successfully!")
+    print("Model loaded and quantized successfully!")
 except Exception as e:
     print(f"Error loading model: {str(e)}")
     raise
 
 def analyze_image(image, instruction):
     try:
-        # Clear CUDA cache and garbage collect
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # Clear memory
         gc.collect()
 
         if instruction.strip() == "":
@@ -57,32 +72,28 @@ def analyze_image(image, instruction):
             return_tensors="pt"
         )
 
-        # Generate the response
-        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
-
-        # Generate with lower resource settings
+        # Generate with conservative settings for CPU
        with torch.no_grad():
             outputs = model.generate(
                 **inputs,
                 max_new_tokens=128,
-                temperature=1.2,
+                temperature=1.0,
                 min_p=0.1,
                 use_cache=True,
-                streamer=text_streamer
+                pad_token_id=tokenizer.eos_token_id,
+                num_beams=1  # Reduce beam search to save memory
             )
 
         # Decode the response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Clear memory
+        # Clean up
         del outputs
         gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
         return response
     except Exception as e:
-        return f"Error processing image: {str(e)}"
+        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
@@ -93,7 +104,11 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            image_input = gr.Image(type="pil", label="Upload Medical Image")
+            image_input = gr.Image(
+                type="pil",
+                label="Upload Medical Image",
+                max_pixels=1500000  # Limit image size
+            )
             instruction_input = gr.Textbox(
                 label="Custom Instruction (optional)",
                 placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
@@ -113,9 +128,9 @@ with gr.Blocks() as demo:
 
     gr.Markdown("""
     ### Notes:
-    - The model runs on CPU and may take a few moments to process each image
-    - For best results, upload clear, high-quality medical images
-    - Default instruction will be used if none is provided
+    - The model runs on CPU and may take several moments to process each image
+    - For best results, upload images smaller than 1.5MP
+    - Please be patient during processing
     """)
 
 # Launch the app
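
The central change here swaps `load_in_8bit=True` (which relies on bitsandbytes and generally assumes a CUDA device, something a CPU Space doesn't have) for PyTorch's built-in dynamic quantization applied after a float32 load. A minimal sketch of what `torch.quantization.quantize_dynamic` does, using a toy module rather than the actual checkpoint (the layer sizes below are placeholders):

import torch
import torch.nn as nn

# Toy stand-in for the transformer; dynamic quantization rewrites nn.Linear
# modules, so any network containing them demonstrates the transformation.
net = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# Weights are converted to int8 once; activations are quantized on the fly
# at inference time. No calibration data is needed, which is why this is
# practical to run once at startup.
qnet = torch.quantization.quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = qnet(torch.randn(1, 512))
print(out.shape)   # torch.Size([1, 10]) -- same interface as the fp32 model
print(type(qnet[0]))  # a dynamically quantized Linear has replaced nn.Linear

One caveat: this only touches `nn.Linear`, so embeddings, attention internals, and convolutions stay float32, and the memory savings depend on how much of the model is linear layers.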
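
The new `torch.set_num_threads(4)` is hard-coded, with a comment telling readers to adjust it per machine. A small sketch of making that automatic, assuming `os.cpu_count()` reflects the cores actually granted to the container:

import os
import torch

# os.cpu_count() may return None, so fall back to a single thread.
torch.set_num_threads(max(1, os.cpu_count() or 1))
print(f"intra-op threads: {torch.get_num_threads()}")

On shared hosts, oversubscribing threads can slow inference down rather than speed it up, so capping at the container's actual CPU quota is the safer choice than the physical core count.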
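
One flag on the Gradio change: `max_pixels` does not appear among `gr.Image`'s documented constructor arguments, so depending on the installed Gradio version it may be ignored rather than enforced. A version-independent way to apply the same 1.5 MP budget is to downscale inside `analyze_image` before building the prompt (a sketch; `cap_image_size` is a hypothetical helper and `MAX_PIXELS` mirrors the value used above):

from PIL import Image

MAX_PIXELS = 1_500_000  # same 1.5 MP budget as the gr.Image change above

def cap_image_size(image: Image.Image) -> Image.Image:
    """Downscale so width * height <= MAX_PIXELS, preserving aspect ratio."""
    w, h = image.size
    if w * h <= MAX_PIXELS:
        return image
    scale = (MAX_PIXELS / (w * h)) ** 0.5
    return image.resize((max(1, int(w * scale)), max(1, int(h * scale))))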