ShareGPT4Video-8B / llava /video_utils.py
chenlin
init
d9dadf3
raw
history blame
5.72 kB
import argparse
import io
import json
import os
import random
import tempfile
from multiprocessing import Manager, Pool, cpu_count
import cv2
import imageio
import numpy as np
from decord import VideoReader
from PIL import Image
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
if sample in ["rand", "middle"]: # uniform sampling
acc_samples = min(num_frames, vlen)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals = np.linspace(
start=0, stop=vlen, num=acc_samples + 1).astype(int)
ranges = []
for idx, interv in enumerate(intervals[:-1]):
ranges.append((interv, intervals[idx + 1] - 1))
if sample == 'rand':
try:
frame_indices = [random.choice(
range(x[0], x[1])) for x in ranges]
except Exception:
frame_indices = np.random.permutation(vlen)[:acc_samples]
frame_indices.sort()
frame_indices = list(frame_indices)
elif fix_start is not None:
frame_indices = [x[0] + fix_start for x in ranges]
elif sample == 'middle':
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
else:
raise NotImplementedError
if len(frame_indices) < num_frames: # padded with last frame
padded_frame_indices = [frame_indices[-1]] * num_frames
padded_frame_indices[:len(frame_indices)] = frame_indices
frame_indices = padded_frame_indices
elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
output_fps = float(sample[3:])
duration = float(vlen) / input_fps
# gap between frames, this is also the clip length each frame represents
delta = 1 / output_fps
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
frame_indices = np.around(frame_seconds * input_fps).astype(int)
frame_indices = [e for e in frame_indices if e < vlen]
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
frame_indices = frame_indices[:max_num_frames]
else:
raise ValueError
return frame_indices
def get_index(num_frames, bound, fps, max_frame, first_idx=0):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / num_frames
frame_indices = np.array([
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
for idx in range(num_frames)
])
return frame_indices
def read_frames_gif(
video_path, num_frames, sample='rand', fix_start=None,
max_num_frames=-1, client=None, clip=None,
):
if video_path.startswith('s3') or video_path.startswith('p2'):
video_bytes = client.get(video_path)
gif = imageio.get_reader(io.BytesIO(video_bytes))
else:
gif = imageio.get_reader(video_path)
vlen = len(gif)
frame_indices = get_frame_indices(
num_frames, vlen, sample=sample, fix_start=fix_start,
max_num_frames=max_num_frames
)
frames = []
reference_size = None
for index, frame in enumerate(gif):
# for index in frame_idxs:
if index in frame_indices:
if frame.ndim == 2:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
elif frame.shape[2] == 4:
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
if reference_size is None:
reference_size = (frame.shape[1], frame.shape[0])
frame = cv2.resize(frame, reference_size,
interpolation=cv2.INTER_LINEAR)
frames.append(frame)
frames = np.stack(frames, axis=0) # .float() / 255
return frames
def read_frames_decord(
video_path, num_frames, sample='rand', fix_start=None,
max_num_frames=-1, client=None, clip=None
):
if video_path.startswith('s3') or video_path.startswith('p2') or video_path.startswith('p_hdd') or video_path.startswith('cluster1'):
video_bytes = client.get(video_path)
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
else:
video_reader = VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
duration = vlen / float(fps)
if clip:
vlen = int(duration * fps)
frame_indices = get_index(num_frames, clip, fps, vlen)
else:
frame_indices = get_frame_indices(
num_frames, vlen, sample=sample, fix_start=fix_start,
input_fps=fps, max_num_frames=max_num_frames
)
# if clip:
# frame_indices = [f + start_index for f in frame_indices]
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C)
return frames
def read_diff_frames_decord(
video_path, clip, client=None
):
if video_path.startswith('s3') or video_path.startswith('p2') or video_path.startswith('p_hdd') or video_path.startswith('cluster1') or video_path.startswith('s_hdd'):
video_bytes = client.get(video_path)
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
else:
video_reader = VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
start_idx = round(clip[0]*fps)
end_idx = min(round(clip[1]*fps), vlen)
frame_indices = [start_idx, end_idx]
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C)
return frames