import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
                            process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

title_markdown = ("""
ShareGPT4Video🚀

ShareGPT4Video: Improving Video Understanding and Generation with Better Captions

If you like our project, please give us a star ✨ on Github for the latest update.
[Project Page] [Code] [Paper]
""")

block_css = """ #buttons button { min-width: min(120px,100%); } """

# "### License" must sit on its own line; otherwise Markdown renders the whole
# license paragraph as a heading.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")


def create_frame_grid(img_array, interval_width=50):
    """Tile a stack of frames into one square grid separated by white bands.

    Args:
        img_array: Array of shape ``(n, h, w, c)`` holding ``n`` frames.
        interval_width: Width in pixels of the white bands between cells.

    Returns:
        An array of shape ``(g*h + (g-1)*interval_width,
        g*w + (g-1)*interval_width, c)`` with ``g = ceil(sqrt(n))``, filled
        row by row; cells past the last frame are white (255).

    Raises:
        ValueError: If ``img_array`` contains no frames.
    """
    n, h, w, c = img_array.shape
    if n == 0:
        raise ValueError("create_frame_grid needs at least one frame")
    grid_size = int(np.ceil(np.sqrt(n)))
    dtype = img_array.dtype

    # White separators: one between neighbouring cells of a row, and one
    # spanning the full row width between consecutive rows.
    horizontal_band = np.full((h, interval_width, c), 255, dtype=dtype)
    row_width = grid_size * w + (grid_size - 1) * interval_width
    vertical_band = np.full((interval_width, row_width, c), 255, dtype=dtype)
    # Filler for grid cells that have no corresponding frame.
    blank_frame = np.full_like(img_array[0], 255)

    rows = []
    for i in range(grid_size):
        row_frames = []
        for j in range(grid_size):
            idx = i * grid_size + j
            if j > 0:
                row_frames.append(horizontal_band)
            row_frames.append(img_array[idx] if idx < n else blank_frame)
        if i > 0:
            rows.append(vertical_band)
        rows.append(np.concatenate(row_frames, axis=1))
    return np.concatenate(rows, axis=0)


def resize_image_grid(image, max_length=1920):
    """Downscale ``image`` so its longer side is at most ``max_length`` pixels.

    The aspect ratio is kept (new sizes are truncated to int) and bilinear
    resampling is used. Images already within the limit are returned as-is.
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= max_length:
        return image
    scale = max_length / longest
    new_size = (int(width * scale), int(height * scale))
    return image.resize(new_size, Image.BILINEAR)


def get_index(num_frames, num_segments):
    """Return ``num_segments`` frame indices spread evenly over a video.

    ``[0, num_frames - 1]`` is split into ``num_segments`` equal segments and
    each index is shifted by half a segment, so samples land near segment
    centres rather than on segment starts. Indices repeat when the video has
    fewer frames than segments.
    """
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets


def load_video(video_path, num_segments=8, return_msg=False, num_frames=4):
    """Sample frames from a video and tile them into a single RGB grid image.

    Args:
        video_path: Path handed to ``decord.VideoReader``.
        num_segments: Number of frames to sample (see ``get_index``).
        return_msg: If true, also return a sentence listing sample times.
        num_frames: Unused; the real frame count is read from the video.
            Kept so existing callers passing it keep working.

    Returns:
        The grid as a ``PIL.Image``, or ``(grid, msg)`` when ``return_msg``.

    Raises:
        ValueError: If the video reports zero frames.
    """
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"video has no frames: {video_path}")
    frame_indices = get_index(total_frames, num_segments)
    img_array = vr.get_batch(frame_indices).asnumpy()
    img_grid = create_frame_grid(img_array, 50)
    img_grid = Image.fromarray(img_grid).convert("RGB")
    img_grid = resize_image_grid(img_grid)
    if not return_msg:
        return img_grid
    fps = float(vr.get_avg_fps())
    sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
    # NOTE(review): an earlier comment said '" " should be added in the start
    # and end' of this message; intent unclear - confirm with callers.
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
    return img_grid, msg


def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
                 max_new_tokens=200, num_beams=1, top_p=0.9,
                 temperature=1.0, print_res=False, **kwargs):
    """Generate a model answer for ``prompt`` and an optional image.

    Args:
        prompt: Fully formatted conversation prompt. It should contain the
            image placeholder token exactly when ``img_grid`` is given.
        model, processor, tokenizer: As returned by ``load_pretrained_model``.
        img_grid: A PIL image (or list/tuple of them; only the first processed
            tensor is used), or ``None`` for a text-only query.
        do_sample, max_new_tokens, num_beams, top_p, temperature, **kwargs:
            Forwarded to ``model.generate``.
        print_res: If true, print the prompt and decoded output (debugging).

    Returns:
        Decoded, whitespace-stripped text of the first generated sequence.
    """
    images = None
    image_sizes = None
    if img_grid is not None:
        if not isinstance(img_grid, (list, tuple)):
            img_grid = [img_grid]
        image_sizes = [img_grid[0].size]
        # NOTE(review): only the first processed tensor is forwarded; confirm
        # this matches the loaded model's image_aspect_ratio handling.
        image_tensor = process_images(img_grid, processor, model.config)[0]
        images = image_tensor.to(
            dtype=torch.float16, device=model.device, non_blocking=True)

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(
        device=model.device, non_blocking=True)
    # Fall back to EOS when the tokenizer defines no pad token.
    pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token is not None
                    else tokenizer.eos_token_id)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images,
            image_sizes=image_sizes,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
            use_cache=True,
            **kwargs)

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True)[0].strip()
    if print_res:  # debug usage
        print('### PROMPTING LM WITH: ', prompt)
        print('### LM OUTPUT TEXT: ', outputs)
    return outputs


class Chat:
    """Answer questions about videos with a LLaVA-style model.

    Each video is sampled into ``num_frames`` keyframes, tiled into one grid
    image by ``load_video`` and sent to the model along with the
    conversation carried in ``state``.
    """

    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False,
                 load_4bit=False, device='cuda', num_frames=16):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.processor, _context_len = load_pretrained_model(
            model_path, model_base, model_name, load_8bit, load_4bit, device=device)
        self.model.eval()
        self.conv_mode = conv_mode
        self.device = self.model.device
        # Number of keyframes sampled per video; 0 means text-only queries.
        self.num_frames = num_frames
        self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view, keyframes are separated with white bands. Answer concisely with overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the frames."

    def get_prompt(self, qs, state):
        """Append user turn ``qs`` and an empty assistant turn to ``state``.

        ``state`` is mutated in place and also returned.
        """
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    @torch.inference_mode()
    def generate(self, vid_path: str, prompt: str, first_run: bool, state):
        """Answer ``prompt`` about the video at ``vid_path``.

        Args:
            vid_path: Path of the video file to sample.
            prompt: The user's question.
            first_run: Unused by this method.
            state: Conversation object (``roles``, ``append_message``,
                ``get_prompt``); mutated in place with the new turn.

        Returns:
            ``(llm_response, state)``.
        """
        if self.num_frames != 0:
            img_grid, _msg = load_video(
                vid_path, num_segments=self.num_frames, return_msg=True)
        else:
            # Text-only query: no image is sent.
            img_grid = None

        if self.pre_query_prompt is not None:
            # NOTE(review): no separator is inserted between the pre-query
            # prompt and the question; kept to preserve model inputs.
            prompt = self.pre_query_prompt + prompt
        if img_grid is not None:
            # Only add the image placeholder when an image accompanies it;
            # a placeholder without an image would break generation.
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt

        state = self.get_prompt(prompt, state)
        prompt = state.get_prompt()
        llm_response = video_answer(prompt, model=self.model, processor=self.processor,
                                    tokenizer=self.tokenizer, do_sample=True,
                                    temperature=0.1, img_grid=img_grid,
                                    max_new_tokens=1024, print_res=True)
        return llm_response, state