import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
title_markdown = ("""
<div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;">
<div style="margin-right: 20px; display: flex; align-items: center;">
<a href="https://github.com/ShareGPT4Omni/ShareGPT4Video" style="text-decoration: none; display: flex; align-items: center;">
<img src="https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/share4video_tight.png" alt="ShareGPT4Video🚀" style="max-width: 120px; height: auto;">
</a>
</div>
<div>
<h1>ShareGPT4Video: Improving Video Understanding and Generation with Better Captions</h1>
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on GitHub for the latest updates.</h5>
<h5 style="margin: 0;"> <a href="https://sharegpt4video.github.io/">[Project Page]</a> <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video">[Code]</a> <a href="https://arxiv.org/abs/2406.04325v1">[Paper]</a></h5>
</div>
</div>
""")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")


def create_frame_grid(img_array, interval_width=50):
    """Tile N frames of shape (N, H, W, C) into one square grid image,
    separating frames with white bands `interval_width` pixels wide."""
    n, h, w, c = img_array.shape
    grid_size = int(np.ceil(np.sqrt(n)))

    horizontal_band = np.ones((h, interval_width, c),
                              dtype=img_array.dtype) * 255
    vertical_band = np.ones((interval_width, w + (grid_size - 1)
                             * (w + interval_width), c), dtype=img_array.dtype) * 255

    rows = []
    for i in range(grid_size):
        row_frames = []
        for j in range(grid_size):
            idx = i * grid_size + j
            if idx < n:
                frame = img_array[idx]
            else:
                # Pad the last row with blank (white) cells.
                frame = np.ones_like(img_array[0]) * 255
            if j > 0:
                row_frames.append(horizontal_band)
            row_frames.append(frame)
        combined_row = np.concatenate(row_frames, axis=1)
        if i > 0:
            rows.append(vertical_band)
        rows.append(combined_row)

    final_grid = np.concatenate(rows, axis=0)
    return final_grid
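
# Illustrative shape check (assumed values, not executed at import time):
# 8 frames of 240x320 RGB produce a 3x3 grid with one blank cell, so
#   create_frame_grid(np.zeros((8, 240, 320, 3), dtype=np.uint8)).shape
#   == (3 * 240 + 2 * 50, 3 * 320 + 2 * 50, 3) == (820, 1060, 3)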


def resize_image_grid(image, max_length=1920):
    """Downscale the grid image so its longer side does not exceed `max_length`."""
    width, height = image.size
    if max(width, height) > max_length:
        if width > height:
            scale = max_length / width
        else:
            scale = max_length / height
        new_width = int(width * scale)
        new_height = int(height * scale)
        img_resized = image.resize((new_width, new_height), Image.BILINEAR)
    else:
        img_resized = image
    return img_resized


def get_index(num_frames, num_segments):
    """Pick `num_segments` frame indices spread evenly across the video."""
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets
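
# For example (assumed values): sampling a 120-frame clip into 8 segments gives
#   get_index(120, 8)  ->  [7, 22, 37, 52, 67, 81, 96, 111]
# i.e. roughly the midpoint of each of the 8 equal spans.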


def load_video(video_path, num_segments=8, return_msg=False):
    """Sample `num_segments` frames from the video and tile them into one grid image."""
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frames = len(vr)
    frame_indices = get_index(total_frames, num_segments)

    img_array = vr.get_batch(frame_indices).asnumpy()
    img_grid = create_frame_grid(img_array, 50)
    img_grid = Image.fromarray(img_grid).convert("RGB")
    img_grid = resize_image_grid(img_grid)

    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return img_grid, msg
    else:
        return img_grid
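
# Example usage (assumed path; requires decord and a real video file):
#   grid, msg = load_video("examples/demo.mp4", num_segments=16, return_msg=True)
#   grid is a PIL.Image with the 16 keyframes laid out in a 4x4 grid, and
#   msg reads like "The video contains 16 frames sampled at 0.5, 1.6, ... seconds."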


def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
                 max_new_tokens=200, num_beams=1, top_p=0.9,
                 temperature=1.0, print_res=False, **kwargs):
    """Run the LLaVA-style model on a frame-grid image and a text prompt."""
    if not isinstance(img_grid, (list, tuple)):
        img_grid = [img_grid]
    image_size = img_grid[0].size
    image_tensor = process_images(img_grid, processor, model.config)[0]

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(
        device=model.device, non_blocking=True)
    # Fall back to the EOS token when the tokenizer has no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token is not None else tokenizer.eos_token_id

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.to(
                dtype=torch.float16, device=model.device, non_blocking=True),
            image_sizes=[image_size],
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
            use_cache=True,
            **kwargs)
    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True)[0].strip()

    if print_res:  # debug usage
        print('### PROMPTING LM WITH: ', prompt)
        print('### LM OUTPUT TEXT: ', outputs)
    return outputs


class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False,
                 load_4bit=False, device='cuda', num_frames=16):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.processor, context_len = load_pretrained_model(
            model_path, model_base, model_name,
            load_8bit, load_4bit,
            device=device)
        self.model.eval()
        self.conv_mode = conv_mode
        self.device = self.model.device
        self.num_frames = num_frames
        self.pre_query_prompt = (
            "The provided image arranges keyframes from a video in a grid view; "
            "keyframes are separated by white bands. Answer concisely with the "
            "overall content and context of the video, highlighting any significant "
            "events, characters, or objects that appear throughout the frames."
        )

    def get_prompt(self, qs, state):
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    @torch.inference_mode()
    def generate(self, vid_path: str, prompt: str, first_run: bool, state):
        if self.num_frames != 0:
            vid, msg = load_video(
                vid_path, num_segments=self.num_frames, return_msg=True)
        else:
            vid, msg = None, 'num_frames is 0, not inputting an image'
        img_grid = vid

        if self.pre_query_prompt is not None:
            # The extra space keeps the user question from running into the preamble.
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + self.pre_query_prompt + ' ' + prompt
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        state = self.get_prompt(prompt, state)
        prompt = state.get_prompt()

        llm_response = video_answer(prompt, model=self.model, processor=self.processor, tokenizer=self.tokenizer,
                                    do_sample=True, temperature=0.1, img_grid=img_grid, max_new_tokens=1024, print_res=True)
        return llm_response, state
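

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the original demo.
    # The checkpoint name, conversation template, and video path below are
    # assumptions; substitute your own before running (requires a GPU and
    # the model weights).
    chat = Chat(model_path="Lin-Chen/sharegpt4video-8b",
                conv_mode="llava_llama_3", num_frames=16)
    state = conv_templates[chat.conv_mode].copy()
    response, state = chat.generate(
        "examples/sample_video.mp4", "What happens in this video?",
        first_run=True, state=state)
    print(response)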