Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -11,19 +11,21 @@ HEADER = """
|
|
11 |
"""
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
class VideoLLaMA3GradioInterface(object):
|
15 |
|
16 |
-
def __init__(self,
|
17 |
self.device = device
|
18 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
19 |
-
model_name,
|
20 |
-
trust_remote_code=True,
|
21 |
-
torch_dtype=torch.bfloat16,
|
22 |
-
attn_implementation="flash_attention_2",
|
23 |
-
)
|
24 |
-
self.model.to(self.device)
|
25 |
-
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
26 |
-
|
27 |
self.server_kwargs = server_kwargs
|
28 |
|
29 |
self.image_formats = ("png", "jpg", "jpeg")
|
@@ -138,7 +140,7 @@ class VideoLLaMA3GradioInterface(object):
|
|
138 |
"max_new_tokens": max_new_tokens
|
139 |
}
|
140 |
|
141 |
-
inputs =
|
142 |
conversation=new_messages,
|
143 |
add_system_prompt=True,
|
144 |
add_generation_prompt=True,
|
@@ -148,14 +150,14 @@ class VideoLLaMA3GradioInterface(object):
|
|
148 |
if "pixel_values" in inputs:
|
149 |
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
|
150 |
|
151 |
-
streamer = TextIteratorStreamer(
|
152 |
generation_kwargs = {
|
153 |
**inputs,
|
154 |
**generation_config,
|
155 |
"streamer": streamer,
|
156 |
}
|
157 |
|
158 |
-
thread = Thread(target=
|
159 |
thread.start()
|
160 |
|
161 |
messages.append({"role": "assistant", "content": ""})
|
@@ -169,8 +171,7 @@ class VideoLLaMA3GradioInterface(object):
|
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
interface = VideoLLaMA3GradioInterface(
|
172 |
-
|
173 |
-
device="cuda",
|
174 |
example_dir="./examples",
|
175 |
)
|
176 |
interface.launch()
|
|
|
11 |
"""
|
12 |
|
13 |
|
14 |
+
device = "cuda"
|
15 |
+
model = AutoModelForCausalLM.from_pretrained(
|
16 |
+
"DAMO-NLP-SG/VideoLLaMA3-7B",
|
17 |
+
trust_remote_code=True,
|
18 |
+
torch_dtype=torch.bfloat16,
|
19 |
+
attn_implementation="flash_attention_2",
|
20 |
+
)
|
21 |
+
model.to(device)
|
22 |
+
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
23 |
+
|
24 |
+
|
25 |
class VideoLLaMA3GradioInterface(object):
|
26 |
|
27 |
+
def __init__(self, device="cpu", example_dir=None, **server_kwargs):
|
28 |
self.device = device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
self.server_kwargs = server_kwargs
|
30 |
|
31 |
self.image_formats = ("png", "jpg", "jpeg")
|
|
|
140 |
"max_new_tokens": max_new_tokens
|
141 |
}
|
142 |
|
143 |
+
inputs = processor(
|
144 |
conversation=new_messages,
|
145 |
add_system_prompt=True,
|
146 |
add_generation_prompt=True,
|
|
|
150 |
if "pixel_values" in inputs:
|
151 |
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
|
152 |
|
153 |
+
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
154 |
generation_kwargs = {
|
155 |
**inputs,
|
156 |
**generation_config,
|
157 |
"streamer": streamer,
|
158 |
}
|
159 |
|
160 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
161 |
thread.start()
|
162 |
|
163 |
messages.append({"role": "assistant", "content": ""})
|
|
|
171 |
|
172 |
if __name__ == "__main__":
|
173 |
interface = VideoLLaMA3GradioInterface(
|
174 |
+
device=device,
|
|
|
175 |
example_dir="./examples",
|
176 |
)
|
177 |
interface.launch()
|