lixin4ever committed
Commit 99b1651 · verified · 1 Parent(s): 674be41

Update app.py

Files changed (1)
  1. app.py +77 -88
app.py CHANGED
@@ -22,88 +22,84 @@ model.to(device)
 processor = AutoProcessor.from_pretrained("DAMO-NLP-SG/VideoLLaMA3-7B", trust_remote_code=True)
 
 
-class VideoLLaMA3GradioInterface(object):
-
-    def __init__(self, device="cpu", example_dir=None, **server_kwargs):
-        self.device = device
-        self.server_kwargs = server_kwargs
-
-        self.image_formats = ("png", "jpg", "jpeg")
-        self.video_formats = ("mp4",)
-
-        image_examples, video_examples = [], []
-        if example_dir is not None:
-            example_files = [
-                osp.join(example_dir, f) for f in os.listdir(example_dir)
-            ]
-            for example_file in example_files:
-                if example_file.endswith(self.image_formats):
-                    image_examples.append([example_file])
-                elif example_file.endswith(self.video_formats):
-                    video_examples.append([example_file])
-
-        with gr.Blocks() as self.interface:
-            gr.Markdown(HEADER)
-            with gr.Row():
-                chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=710)
-
-                with gr.Column():
-                    with gr.Tab(label="Input"):
-
-                        with gr.Row():
-                            input_video = gr.Video(sources=["upload"], label="Upload Video")
-                            input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
-
-                        if len(image_examples):
-                            gr.Examples(image_examples, inputs=[input_image], label="Example Images")
-                        if len(video_examples):
-                            gr.Examples(video_examples, inputs=[input_video], label="Example Videos")
-
-                        input_text = gr.Textbox(label="Input Text", placeholder="Type your message here and press enter to submit")
-
-                        submit_button = gr.Button("Generate")
-
-                    with gr.Tab(label="Configure"):
-                        with gr.Accordion("Generation Config", open=True):
-                            do_sample = gr.Checkbox(value=True, label="Do Sample")
-                            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
-                            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
-                            max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens")
-
-                        with gr.Accordion("Video Config", open=True):
-                            fps = gr.Slider(minimum=0.0, maximum=10.0, value=1, label="FPS")
-                            max_frames = gr.Slider(minimum=0, maximum=256, value=180, step=1, label="Max Frames")
-
-            input_video.change(self._on_video_upload, [chatbot, input_video], [chatbot, input_video])
-            input_image.change(self._on_image_upload, [chatbot, input_image], [chatbot, input_image])
-            input_text.submit(self._on_text_submit, [chatbot, input_text], [chatbot, input_text])
-            submit_button.click(
-                self._predict,
-                [
-                    chatbot, input_text, do_sample, temperature, top_p, max_new_tokens,
-                    fps, max_frames
-                ],
-                [chatbot],
-            )
-
-    def _on_video_upload(self, messages, video):
+example_dir = "./examples"
+image_formats = ("png", "jpg", "jpeg")
+video_formats = ("mp4",)
+
+image_examples, video_examples = [], []
+if example_dir is not None:
+    example_files = [
+        osp.join(example_dir, f) for f in os.listdir(example_dir)
+    ]
+    for example_file in example_files:
+        if example_file.endswith(image_formats):
+            image_examples.append([example_file])
+        elif example_file.endswith(video_formats):
+            video_examples.append([example_file])
+
+
+with gr.Blocks() as interface:
+    gr.Markdown(HEADER)
+    with gr.Row():
+        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=710)
+
+        with gr.Column():
+            with gr.Tab(label="Input"):
+
+                with gr.Row():
+                    input_video = gr.Video(sources=["upload"], label="Upload Video")
+                    input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
+
+                if len(image_examples):
+                    gr.Examples(image_examples, inputs=[input_image], label="Example Images")
+                if len(video_examples):
+                    gr.Examples(video_examples, inputs=[input_video], label="Example Videos")
+
+                input_text = gr.Textbox(label="Input Text", placeholder="Type your message here and press enter to submit")
+
+                submit_button = gr.Button("Generate")
+
+            with gr.Tab(label="Configure"):
+                with gr.Accordion("Generation Config", open=True):
+                    do_sample = gr.Checkbox(value=True, label="Do Sample")
+                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
+                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
+                    max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens")
+
+                with gr.Accordion("Video Config", open=True):
+                    fps = gr.Slider(minimum=0.0, maximum=10.0, value=1, label="FPS")
+                    max_frames = gr.Slider(minimum=0, maximum=256, value=180, step=1, label="Max Frames")
+
+    input_video.change(_on_video_upload, [chatbot, input_video], [chatbot, input_video])
+    input_image.change(_on_image_upload, [chatbot, input_image], [chatbot, input_image])
+    input_text.submit(_on_text_submit, [chatbot, input_text], [chatbot, input_text])
+    submit_button.click(
+        _predict,
+        [
+            chatbot, input_text, do_sample, temperature, top_p, max_new_tokens,
+            fps, max_frames
+        ],
+        [chatbot],
+    )
+
+    def _on_video_upload(messages, video):
         if video is not None:
             # messages.append({"role": "user", "content": gr.Video(video)})
             messages.append({"role": "user", "content": {"path": video}})
             return messages, None
-
-    def _on_image_upload(self, messages, image):
+
+    def _on_image_upload(messages, image):
         if image is not None:
             # messages.append({"role": "user", "content": gr.Image(image)})
             messages.append({"role": "user", "content": {"path": image}})
             return messages, None
-
-    def _on_text_submit(self, messages, text):
+
+    def _on_text_submit(messages, text):
         messages.append({"role": "user", "content": text})
         return messages, ""
-
+
     @spaces.GPU(duration=120)
-    def _predict(self, messages, input_text, do_sample, temperature, top_p, max_new_tokens,
+    def _predict(messages, input_text, do_sample, temperature, top_p, max_new_tokens,
                  fps, max_frames):
         if len(input_text) > 0:
             messages.append({"role": "user", "content": input_text})
@@ -120,58 +116,51 @@ class VideoLLaMA3GradioInterface(object):
                 contents.append(message["content"])
             else:
                 media_path = message["content"][0]
-                if media_path.endswith(self.video_formats):
+                if media_path.endswith(video_formats):
                     contents.append({"type": "video", "video": {"video_path": media_path, "fps": fps, "max_frames": max_frames}})
-                elif media_path.endswith(self.image_formats):
+                elif media_path.endswith(image_formats):
                     contents.append({"type": "image", "image": {"image_path": media_path}})
                 else:
                     raise ValueError(f"Unsupported media type: {media_path}")
-
+
         if len(contents):
             new_messages.append({"role": "user", "content": contents})
-
+
         if len(new_messages) == 0 or new_messages[-1]["role"] != "user":
            return messages
-
+
         generation_config = {
             "do_sample": do_sample,
             "temperature": temperature,
             "top_p": top_p,
             "max_new_tokens": max_new_tokens
         }
-
+
         inputs = processor(
             conversation=new_messages,
             add_system_prompt=True,
             add_generation_prompt=True,
             return_tensors="pt"
        )
-        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
         if "pixel_values" in inputs:
             inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
-
+
         streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {
             **inputs,
             **generation_config,
             "streamer": streamer,
         }
-
+
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
-
+
         messages.append({"role": "assistant", "content": ""})
         for token in streamer:
             messages[-1]['content'] += token
             yield messages
 
-    def launch(self):
-        self.interface.launch(**self.server_kwargs)
-
 
 if __name__ == "__main__":
-    interface = VideoLLaMA3GradioInterface(
-        device=device,
-        example_dir="./examples",
-    )
     interface.launch()
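Note on the rendered ordering: inside `with gr.Blocks() as interface:`, the `.change()` / `.submit()` / `.click()` calls reference `_on_video_upload`, `_on_image_upload`, `_on_text_submit`, and `_predict` before their `def` statements appear. Python resolves those handler names the moment each wiring line executes, so if the file really runs top to bottom as shown here, importing it would raise a NameError; the definitions need to precede the wiring. A minimal sketch of the define-then-wire ordering (the handler name and body below are illustrative, not the app's):

import gradio as gr

# Illustrative handler, defined before any event wiring so the name is
# already bound when the .submit() call below is evaluated.
def on_text_submit(messages, text):
    messages.append({"role": "user", "content": text})
    return messages, ""  # updated history, cleared textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    input_text = gr.Textbox(label="Input Text")
    # Safe: on_text_submit already exists at this point.
    input_text.submit(on_text_submit, [chatbot, input_text], [chatbot, input_text])

if __name__ == "__main__":
    demo.launch()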
 
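More broadly, the commit flattens the `VideoLLaMA3GradioInterface` class into module-level state and functions, so `@spaces.GPU(duration=120)` now decorates a plain function rather than a bound method. One plausible motivation (an assumption, not stated in the commit message): on ZeroGPU Spaces, `spaces.GPU` wraps the decorated callable to run it with a GPU attached, and a bound method drags `self` along as an argument; here `self` held the entire Gradio interface. A minimal sketch of the decorated streaming pattern that `_predict` uses, where `model` and `processor` stand in for the objects loaded at the top of app.py:

from threading import Thread

import spaces
from transformers import TextIteratorStreamer

@spaces.GPU(duration=120)
def predict_stream(inputs, generation_config):
    # model.generate() blocks, so it runs in a background thread while this
    # generator drains the streamer and yields partial text to the UI.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, **generation_config, "streamer": streamer}
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for token in streamer:
        text += token
        yield text  # each yield re-renders the Chatbot with the partial reply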