lixin4ever committed on
Commit
593a1aa
·
verified ·
1 Parent(s): 0911cc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -11,19 +11,21 @@ HEADER = """
11
  """
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  class VideoLLaMA3GradioInterface(object):
15
 
16
- def __init__(self, model_name, device="cpu", example_dir=None, **server_kwargs):
17
  self.device = device
18
- self.model = AutoModelForCausalLM.from_pretrained(
19
- model_name,
20
- trust_remote_code=True,
21
- torch_dtype=torch.bfloat16,
22
- attn_implementation="flash_attention_2",
23
- )
24
- self.model.to(self.device)
25
- self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
26
-
27
  self.server_kwargs = server_kwargs
28
 
29
  self.image_formats = ("png", "jpg", "jpeg")
@@ -138,7 +140,7 @@ class VideoLLaMA3GradioInterface(object):
138
  "max_new_tokens": max_new_tokens
139
  }
140
 
141
- inputs = self.processor(
142
  conversation=new_messages,
143
  add_system_prompt=True,
144
  add_generation_prompt=True,
@@ -148,14 +150,14 @@ class VideoLLaMA3GradioInterface(object):
148
  if "pixel_values" in inputs:
149
  inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
150
 
151
- streamer = TextIteratorStreamer(self.processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
152
  generation_kwargs = {
153
  **inputs,
154
  **generation_config,
155
  "streamer": streamer,
156
  }
157
 
158
- thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
159
  thread.start()
160
 
161
  messages.append({"role": "assistant", "content": ""})
@@ -169,8 +171,7 @@ class VideoLLaMA3GradioInterface(object):
169
 
170
  if __name__ == "__main__":
171
  interface = VideoLLaMA3GradioInterface(
172
- model_name="DAMO-NLP-SG/VideoLLaMA3-7B",
173
- device="cuda",
174
  example_dir="./examples",
175
  )
176
  interface.launch()
 
11
  """
12
 
13
 
14
+ device = "cuda"
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ "DAMO-NLP-SG/VideoLLaMA3-7B",
17
+ trust_remote_code=True,
18
+ torch_dtype=torch.bfloat16,
19
+ attn_implementation="flash_attention_2",
20
+ )
21
+ model.to(device)
22
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
23
+
24
+
25
  class VideoLLaMA3GradioInterface(object):
26
 
27
+ def __init__(self, device="cpu", example_dir=None, **server_kwargs):
28
  self.device = device
 
 
 
 
 
 
 
 
 
29
  self.server_kwargs = server_kwargs
30
 
31
  self.image_formats = ("png", "jpg", "jpeg")
 
140
  "max_new_tokens": max_new_tokens
141
  }
142
 
143
+ inputs = processor(
144
  conversation=new_messages,
145
  add_system_prompt=True,
146
  add_generation_prompt=True,
 
150
  if "pixel_values" in inputs:
151
  inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
152
 
153
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
154
  generation_kwargs = {
155
  **inputs,
156
  **generation_config,
157
  "streamer": streamer,
158
  }
159
 
160
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
161
  thread.start()
162
 
163
  messages.append({"role": "assistant", "content": ""})
 
171
 
172
  if __name__ == "__main__":
173
  interface = VideoLLaMA3GradioInterface(
174
+ device=device,
 
175
  example_dir="./examples",
176
  )
177
  interface.launch()