aifeifei798 commited on
Commit
7324b5a
·
verified ·
1 Parent(s): 2b80d14

Upload feifeiflorencebase.py

Browse files
Files changed (1) hide show
  1. feifeilib/feifeiflorencebase.py +221 -0
feifeilib/feifeiflorencebase.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ import spaces
4
+
5
+ import requests
6
+ import copy
7
+
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import io
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+
13
+ import random
14
+ import numpy as np
15
+
16
+ import subprocess
17
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
18
+
19
+ models = {
20
+ 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval()
21
+ }
22
+
23
+ processors = {
24
+ 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
25
+ }
26
+
27
+
28
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
29
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
30
+
31
+ def fig_to_pil(fig):
32
+ buf = io.BytesIO()
33
+ fig.savefig(buf, format='png')
34
+ buf.seek(0)
35
+ return Image.open(buf)
36
+
37
+ @spaces.GPU
38
+ def run_example(task_prompt = "<MORE_DETAILED_CAPTION>", image = None, text_input = None, model_id='microsoft/Florence-2-base', progress=gr.Progress(track_tqdm=True)):
39
+ model = models[model_id]
40
+ processor = processors[model_id]
41
+ if text_input is None:
42
+ prompt = task_prompt
43
+ else:
44
+ prompt = task_prompt + text_input
45
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
46
+ generated_ids = model.generate(
47
+ input_ids=inputs["input_ids"],
48
+ pixel_values=inputs["pixel_values"],
49
+ max_new_tokens=1024,
50
+ early_stopping=False,
51
+ do_sample=False,
52
+ num_beams=3,
53
+ )
54
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
55
+ parsed_answer = processor.post_process_generation(
56
+ generated_text,
57
+ task=task_prompt,
58
+ image_size=(image.width, image.height)
59
+ )
60
+ return parsed_answer
61
+
62
+ def plot_bbox(image, data):
63
+ fig, ax = plt.subplots()
64
+ ax.imshow(image)
65
+ for bbox, label in zip(data['bboxes'], data['labels']):
66
+ x1, y1, x2, y2 = bbox
67
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
68
+ ax.add_patch(rect)
69
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
70
+ ax.axis('off')
71
+ return fig
72
+
73
+ def draw_polygons(image, prediction, fill_mask=False):
74
+
75
+ draw = ImageDraw.Draw(image)
76
+ scale = 1
77
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
78
+ color = random.choice(colormap)
79
+ fill_color = random.choice(colormap) if fill_mask else None
80
+ for _polygon in polygons:
81
+ _polygon = np.array(_polygon).reshape(-1, 2)
82
+ if len(_polygon) < 3:
83
+ print('Invalid polygon:', _polygon)
84
+ continue
85
+ _polygon = (_polygon * scale).reshape(-1).tolist()
86
+ if fill_mask:
87
+ draw.polygon(_polygon, outline=color, fill=fill_color)
88
+ else:
89
+ draw.polygon(_polygon, outline=color)
90
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
91
+ return image
92
+
93
+ def convert_to_od_format(data):
94
+ bboxes = data.get('bboxes', [])
95
+ labels = data.get('bboxes_labels', [])
96
+ od_results = {
97
+ 'bboxes': bboxes,
98
+ 'labels': labels
99
+ }
100
+ return od_results
101
+
102
+ def draw_ocr_bboxes(image, prediction):
103
+ scale = 1
104
+ draw = ImageDraw.Draw(image)
105
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
106
+ for box, label in zip(bboxes, labels):
107
+ color = random.choice(colormap)
108
+ new_box = (np.array(box) * scale).tolist()
109
+ draw.polygon(new_box, width=3, outline=color)
110
+ draw.text((new_box[0]+8, new_box[1]+2),
111
+ "{}".format(label),
112
+ align="right",
113
+ fill=color)
114
+ return image
115
+
116
+ def process_image(image, task_prompt = "More Detailed Caption", text_input=None, model_id='microsoft/Florence-2-base'):
117
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
118
+ if task_prompt == 'Caption':
119
+ task_prompt = '<CAPTION>'
120
+ results = run_example(task_prompt, image, model_id=model_id)
121
+ return results
122
+ elif task_prompt == 'Detailed Caption':
123
+ task_prompt = '<DETAILED_CAPTION>'
124
+ results = run_example(task_prompt, image, model_id=model_id)
125
+ return results
126
+ elif task_prompt == 'More Detailed Caption':
127
+ task_prompt = '<MORE_DETAILED_CAPTION>'
128
+ results = run_example(task_prompt, image, model_id=model_id)
129
+ results = results[task_prompt]
130
+ return results
131
+ elif task_prompt == 'Caption + Grounding':
132
+ task_prompt = '<CAPTION>'
133
+ results = run_example(task_prompt, image, model_id=model_id)
134
+ text_input = results[task_prompt]
135
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
136
+ results = run_example(task_prompt, image, text_input, model_id)
137
+ results['<CAPTION>'] = text_input
138
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
139
+ return results, fig_to_pil(fig)
140
+ elif task_prompt == 'Detailed Caption + Grounding':
141
+ task_prompt = '<DETAILED_CAPTION>'
142
+ results = run_example(task_prompt, image, model_id=model_id)
143
+ text_input = results[task_prompt]
144
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
145
+ results = run_example(task_prompt, image, text_input, model_id)
146
+ results['<DETAILED_CAPTION>'] = text_input
147
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
148
+ return results, fig_to_pil(fig)
149
+ elif task_prompt == 'More Detailed Caption + Grounding':
150
+ task_prompt = '<MORE_DETAILED_CAPTION>'
151
+ results = run_example(task_prompt, image, model_id=model_id)
152
+ text_input = results[task_prompt]
153
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
154
+ results = run_example(task_prompt, image, text_input, model_id)
155
+ results['<MORE_DETAILED_CAPTION>'] = text_input
156
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
157
+ return results, fig_to_pil(fig)
158
+ elif task_prompt == 'Object Detection':
159
+ task_prompt = '<OD>'
160
+ results = run_example(task_prompt, image, model_id=model_id)
161
+ fig = plot_bbox(image, results['<OD>'])
162
+ return results, fig_to_pil(fig)
163
+ elif task_prompt == 'Dense Region Caption':
164
+ task_prompt = '<DENSE_REGION_CAPTION>'
165
+ results = run_example(task_prompt, image, model_id=model_id)
166
+ fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
167
+ return results, fig_to_pil(fig)
168
+ elif task_prompt == 'Region Proposal':
169
+ task_prompt = '<REGION_PROPOSAL>'
170
+ results = run_example(task_prompt, image, model_id=model_id)
171
+ fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
172
+ return results, fig_to_pil(fig)
173
+ elif task_prompt == 'Caption to Phrase Grounding':
174
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
175
+ results = run_example(task_prompt, image, text_input, model_id)
176
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
177
+ return results, fig_to_pil(fig)
178
+ elif task_prompt == 'Referring Expression Segmentation':
179
+ task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
180
+ results = run_example(task_prompt, image, text_input, model_id)
181
+ output_image = copy.deepcopy(image)
182
+ output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
183
+ return results, output_image
184
+ elif task_prompt == 'Region to Segmentation':
185
+ task_prompt = '<REGION_TO_SEGMENTATION>'
186
+ results = run_example(task_prompt, image, text_input, model_id)
187
+ output_image = copy.deepcopy(image)
188
+ output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
189
+ return results, output_image
190
+ elif task_prompt == 'Open Vocabulary Detection':
191
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
192
+ results = run_example(task_prompt, image, text_input, model_id)
193
+ bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
194
+ fig = plot_bbox(image, bbox_results)
195
+ return results, fig_to_pil(fig)
196
+ elif task_prompt == 'Region to Category':
197
+ task_prompt = '<REGION_TO_CATEGORY>'
198
+ results = run_example(task_prompt, image, text_input, model_id)
199
+ return results
200
+ elif task_prompt == 'Region to Description':
201
+ task_prompt = '<REGION_TO_DESCRIPTION>'
202
+ results = run_example(task_prompt, image, text_input, model_id)
203
+ return results
204
+ elif task_prompt == 'OCR':
205
+ task_prompt = '<OCR>'
206
+ results = run_example(task_prompt, image, model_id=model_id)
207
+ return results
208
+ elif task_prompt == 'OCR with Region':
209
+ task_prompt = '<OCR_WITH_REGION>'
210
+ results = run_example(task_prompt, image, model_id=model_id)
211
+ output_image = copy.deepcopy(image)
212
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
213
+ return results, output_image
214
+ else:
215
+ return "", None # Return empty string and None for unknown task prompts
216
+
217
+ def update_task_dropdown(choice):
218
+ if choice == 'Cascased task':
219
+ return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
220
+ else:
221
+ return gr.Dropdown(choices=single_task_list, value='Caption')