import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
                            process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

title_markdown = ("""
<div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;">
  <div style="margin-right: 20px; display: flex; align-items: center;">
    <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video" style="text-decoration: none; display: flex; align-items: center;">
      <img src="https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/share4video_tight.png" alt="ShareGPT4Video🚀" style="max-width: 120px; height: auto;">
    </a>
  </div>
  <div>
    <h1>ShareGPT4Video: Improving Video Understanding and Generation with Better Captions</h1>
    <h5 style="margin: 0;">If you like our project, please give us a star ✨ on GitHub for the latest updates.</h5>
    <h5 style="margin: 0;"> <a href="https://sharegpt4video.github.io/">[Project Page]</a> <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video">[Code]</a> <a href="https://arxiv.org/abs/2406.04325v1">[Paper]</a></h5>
  </div>
</div>
""")

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""

learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")


def create_frame_grid(img_array, interval_width=50):
    """Tile n frames of shape (n, h, w, c) into one square grid image separated by white bands."""
    n, h, w, c = img_array.shape
    grid_size = int(np.ceil(np.sqrt(n)))

    # White separator bands: one between frames within a row, one between rows.
    horizontal_band = np.ones((h, interval_width, c),
                              dtype=img_array.dtype) * 255
    vertical_band = np.ones((interval_width, w + (grid_size - 1)
                             * (w + interval_width), c), dtype=img_array.dtype) * 255

    rows = []
    for i in range(grid_size):
        row_frames = []
        for j in range(grid_size):
            idx = i * grid_size + j
            if idx < n:
                frame = img_array[idx]
            else:
                # Pad trailing cells with white when n is not a perfect square.
                frame = np.ones_like(img_array[0]) * 255
            if j > 0:
                row_frames.append(horizontal_band)
            row_frames.append(frame)
        combined_row = np.concatenate(row_frames, axis=1)
        if i > 0:
            rows.append(vertical_band)
        rows.append(combined_row)

    final_grid = np.concatenate(rows, axis=0)
    return final_grid
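

# A rough shape check for the grid layout above (hypothetical input, not taken
# from this script): 8 frames of 224x224 RGB land in a 3x3 grid with one white
# padded cell, so the result is (3*224 + 2*50) pixels on each side:
#   frames = np.random.randint(0, 256, (8, 224, 224, 3), dtype=np.uint8)
#   create_frame_grid(frames, interval_width=50).shape  # (772, 772, 3)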


def resize_image_grid(image, max_length=1920):
    """Downscale the grid image so its longer side is at most max_length pixels."""
    width, height = image.size
    if max(width, height) > max_length:
        if width > height:
            scale = max_length / width
        else:
            scale = max_length / height
        new_width = int(width * scale)
        new_height = int(height * scale)
        img_resized = image.resize((new_width, new_height), Image.BILINEAR)
    else:
        img_resized = image
    return img_resized


def get_index(num_frames, num_segments):
    """Return the indices of num_segments frames sampled at the midpoints of equal segments."""
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets
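

# Illustration of the sampling pattern (values computed for this example, not
# taken from the script): for a 100-frame clip split into 8 segments,
# seg_size = 99 / 8 = 12.375, and the midpoint-based offsets are:
#   get_index(100, 8)  # -> array([ 6, 18, 31, 43, 56, 68, 80, 93])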


def load_video(video_path, num_segments=8, return_msg=False):
    """Sample num_segments frames from a video and lay them out as a single grid image."""
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_num_frames = len(vr)
    frame_indices = get_index(total_num_frames, num_segments)
    img_array = vr.get_batch(frame_indices).asnumpy()
    img_grid = create_frame_grid(img_array, 50)
    img_grid = Image.fromarray(img_grid).convert("RGB")
    img_grid = resize_image_grid(img_grid)
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return img_grid, msg
    else:
        return img_grid
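

# Typical call (the file name is a placeholder): sample 16 frames, build the
# grid image, and get a human-readable note of the sampled timestamps:
#   grid, msg = load_video('example.mp4', num_segments=16, return_msg=True)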


def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
                 max_new_tokens=200, num_beams=1, top_p=0.9,
                 temperature=1.0, print_res=False, **kwargs):
    """Run one round of generation on the frame-grid image with the given prompt."""
    if not isinstance(img_grid, (list, tuple)):
        img_grid = [img_grid]
    image_size = img_grid[0].size
    image_tensor = process_images(img_grid, processor, model.config)[0]
    # Tokenize the prompt, replacing the image placeholder with IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(
        device=model.device, non_blocking=True)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token is not None else tokenizer.eos_token_id
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.to(
                dtype=torch.float16, device=model.device, non_blocking=True),
            image_sizes=[image_size],
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
            use_cache=True,
            **kwargs)
    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True)[0].strip()
    if print_res:  # debug usage
        print('### PROMPTING LM WITH: ', prompt)
        print('### LM OUTPUT TEXT: ', outputs)
    return outputs


class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', num_frames=16):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.processor, context_len = load_pretrained_model(
            model_path, model_base, model_name,
            load_8bit, load_4bit,
            device=device)
        self.model.eval()
        self.conv_mode = conv_mode
        self.device = self.model.device
        self.num_frames = num_frames
        self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view; keyframes are separated with white bands. Answer concisely with the overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the frames."

    def get_prompt(self, qs, state):
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    def generate(self, vid_path: str, prompt: str, first_run: bool, state):
        if self.num_frames != 0:
            # Sample frames and build the keyframe grid that stands in for the video.
            vid, msg = load_video(
                vid_path, num_segments=self.num_frames, return_msg=True)
        else:
            vid, msg = None, 'num_frames is 0, not inputting an image'
        img_grid = vid
        if self.pre_query_prompt is not None:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + self.pre_query_prompt + ' ' + prompt
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        state = self.get_prompt(prompt, state)
        prompt = state.get_prompt()
        llm_response = video_answer(prompt, model=self.model, processor=self.processor, tokenizer=self.tokenizer,
                                    do_sample=True, temperature=0.1, img_grid=img_grid, max_new_tokens=1024, print_res=True)
        return llm_response, state
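

# Minimal usage sketch. The checkpoint path and conversation-template name below
# are placeholders, not values taken from this script:
#   chat = Chat(model_path='path/to/sharegpt4video-8b', conv_mode='llava_llama_3')
#   state = conv_templates[chat.conv_mode].copy()
#   response, state = chat.generate('example.mp4', 'What happens in this video?', True, state)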