import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
title_markdown = ("""
ShareGPT4Video: Improving Video Understanding and Generation with Better Captions
If you like our project, please give us a star ✨ on GitHub for the latest updates.
""")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
def create_frame_grid(img_array, interval_width=50):
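    """Tile N frames of shape (N, H, W, C) into a roughly square grid image,
    separating neighbouring frames with white bands of `interval_width` pixels.
    Unused grid cells are filled with white (value 255)."""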
n, h, w, c = img_array.shape
grid_size = int(np.ceil(np.sqrt(n)))
horizontal_band = np.ones((h, interval_width, c),
dtype=img_array.dtype) * 255
vertical_band = np.ones((interval_width, w + (grid_size - 1)
* (w + interval_width), c), dtype=img_array.dtype) * 255
rows = []
for i in range(grid_size):
row_frames = []
for j in range(grid_size):
idx = i * grid_size + j
if idx < n:
frame = img_array[idx]
else:
frame = np.ones_like(img_array[0]) * 255
if j > 0:
row_frames.append(horizontal_band)
row_frames.append(frame)
combined_row = np.concatenate(row_frames, axis=1)
if i > 0:
rows.append(vertical_band)
rows.append(combined_row)
final_grid = np.concatenate(rows, axis=0)
return final_grid
def resize_image_grid(image, max_length=1920):
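    """Downscale the grid so its longer side is at most `max_length` pixels,
    preserving the aspect ratio; smaller images are returned unchanged."""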
width, height = image.size
if max(width, height) > max_length:
if width > height:
scale = max_length / width
else:
scale = max_length / height
new_width = int(width * scale)
new_height = int(height * scale)
img_resized = image.resize((new_width, new_height), Image.BILINEAR)
else:
img_resized = image
return img_resized
def get_index(num_frames, num_segments):
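    """Return `num_segments` frame indices sampled uniformly over the video,
    taking the (rounded) midpoint of each equal-length segment."""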
seg_size = float(num_frames - 1) / num_segments
start = int(seg_size / 2)
offsets = np.array([
start + int(np.round(seg_size * idx)) for idx in range(num_segments)
])
return offsets
def load_video(video_path, num_segments=8, return_msg=False):
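    """Sample `num_segments` frames from the video at `video_path`, compose them
    into a single keyframe grid image, and optionally return a message listing
    the sampled timestamps."""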
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frames = len(vr)
    frame_indices = get_index(total_frames, num_segments)
img_array = vr.get_batch(frame_indices).asnumpy()
img_grid = create_frame_grid(img_array, 50)
img_grid = Image.fromarray(img_grid).convert("RGB")
img_grid = resize_image_grid(img_grid)
if return_msg:
fps = float(vr.get_avg_fps())
sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
# " " should be added in the start and end
msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
return img_grid, msg
else:
return img_grid
def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
max_new_tokens=200, num_beams=1, top_p=0.9,
temperature=1.0, print_res=False, **kwargs):
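    """Run one round of generation: tokenize the image-token prompt, preprocess
    the grid image(s), call `model.generate`, and decode the answer string."""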
if not isinstance(img_grid, (list, tuple)):
img_grid = [img_grid]
image_size = img_grid[0].size
image_tensor = process_images(img_grid, processor, model.config)[0]
input_ids = tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0).to(
device=model.device, non_blocking=True)
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token is not None else tokenizer.eos_token_id
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.to(
dtype=torch.float16, device=model.device, non_blocking=True),
image_sizes=[image_size],
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
pad_token_id=pad_token_id,
use_cache=True,
**kwargs)
outputs = tokenizer.batch_decode(
output_ids, skip_special_tokens=True)[0].strip()
if print_res: # debug usage
print('### PROMPTING LM WITH: ', prompt)
print('### LM OUTPUT TEXT: ', outputs)
return outputs
class Chat:
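    """Thin wrapper that loads a LLaVA-style checkpoint once and answers
    questions about a video by prompting the model with a keyframe grid."""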
def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', num_frames=16):
disable_torch_init()
model_name = get_model_name_from_path(model_path)
self.tokenizer, self.model, self.processor, context_len = load_pretrained_model(
model_path, model_base, model_name,
load_8bit, load_4bit,
device=device)
self.model.eval()
self.conv_mode = conv_mode
self.device = self.model.device
self.num_frames = num_frames
        self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view; keyframes are separated by white bands. Answer concisely with the overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the frames."
def get_prompt(self, qs, state):
state.append_message(state.roles[0], qs)
state.append_message(state.roles[1], None)
return state
@torch.inference_mode()
    def generate(self, vid_path: str, prompt: str, first_run: bool, state):
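        """Build the keyframe grid for `vid_path`, prepend the image token and
        grid description to `prompt`, and return the model response together
        with the updated conversation state."""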
if self.num_frames != 0:
vid, msg = load_video(
vid_path, num_segments=self.num_frames, return_msg=True)
else:
            vid, msg = None, 'num_frames is 0, not inputting an image'
img_grid = vid
if self.pre_query_prompt is not None:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + self.pre_query_prompt + ' ' + prompt
else:
prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
state = self.get_prompt(prompt, state)
prompt = state.get_prompt()
llm_response = video_answer(prompt, model=self.model, processor=self.processor, tokenizer=self.tokenizer,
do_sample=True, temperature=0.1, img_grid=img_grid, max_new_tokens=1024, print_res=True)
return llm_response, state
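

# Minimal usage sketch (not part of the original demo). The checkpoint id and
# conv_mode below are placeholders: substitute the model and conversation
# template you actually serve.
if __name__ == "__main__":
    chat = Chat(model_path="Lin-Chen/sharegpt4video-8b",   # placeholder checkpoint id
                conv_mode="llava_llama_3",                  # placeholder template name
                num_frames=16)
    # Start a fresh conversation state from the chosen LLaVA template.
    state = conv_templates[chat.conv_mode].copy()
    response, state = chat.generate(
        "example.mp4", "Describe what happens in the video.", first_run=True, state=state)
    print(response)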