import requests
import re, ujson, os, sys, fire, glob, random, time, json
import ast
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client
import tos
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling
from PIL import Image, ImageOps
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import math


def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    # if height < factor or width < factor:
    #     raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor if height > factor else factor
    w_bar = round(width / factor) * factor if width > factor else factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
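
# Illustrative sanity check (not part of the original pipeline): with the default factor of 28,
# a 1000x500 image maps to 1008x504 -- both sides are multiples of 28, the pixel count stays
# inside [56*56, 14*14*4*1280], and the aspect ratio is nearly preserved. Callable via
# `python processor_ocean.py example_smart_resize` thanks to fire.Fire() at the bottom.
def example_smart_resize():
    h_bar, w_bar = smart_resize(1000, 500)
    assert h_bar % 28 == 0 and w_bar % 28 == 0
    assert 56 * 56 <= h_bar * w_bar <= 14 * 14 * 4 * 1280
    print(h_bar, w_bar)  # expected: 1008 504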


def select_best_resolution(image_size, candidate_resolutions):
    '''Find the best target resolution for rescaling the original image.

    image_size: usually the original size, e.g. (8*336, 16*336)
    candidate_resolutions: the candidate resolutions, e.g. (1*336, 4*336)
    '''
    try:
        original_width, original_height = image_size
    except:
        pass
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")
    # iterate over the candidate widths and heights
    for width, height in candidate_resolutions:
        # take the smaller of width / original_width and height / original_height as the scale
        scale = min(width / original_width, height / original_height)  # e.g. scale = min(1/8, 1/4) = 1/8
        # rescale original_width and original_height
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)  # e.g. 1*336, 2*336
        # effective_resolution is the resolution after rescaling, s^2 * w * h
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)  # e.g. min(1*336 * 2*336, 8*336 * 16*336)
        # wasted_resolution is the difference between the candidate and the effective resolution
        wasted_resolution = (width * height) - effective_resolution
        # update the best fit if (1) the effective resolution beats the current max_effective_resolution,
        # or (2) it ties while wasting less resolution
        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution  # update max_effective_resolution
            min_wasted_resolution = wasted_resolution  # update min_wasted_resolution
            best_fit = (width, height)
    return best_fit


def read_video(image_path, max_frame_number, decode_way):
    if decode_way == '1fps':
        try:
            vr = VideoReader(image_path, ctx=cpu(0))
            total_frame_num = len(vr)
            fps = round(vr.get_avg_fps())
            frame_idx = [i for i in range(0, len(vr), fps)]
            frames = vr.get_batch(frame_idx).asnumpy()
            frames = [i for i in frames]
            cnt = len(frames)
        except Exception as e:
            print(image_path)
            print('error is', e)
            return None
    elif decode_way == 'key':
        try:
            with av.open(image_path) as container:
                stream = container.streams.video[0]
                stream.codec_context.skip_frame = 'NONKEY'
                frames = []
                fps = int(stream.average_rate)
                cnt = 0
                for frame in container.decode(stream):
                    # store each key frame as an image
                    image = frame.to_image()
                    frames.append(image)
                    cnt += 1
        except Exception as e:
            print('error is', e)
            return None
    if frames is None or len(frames) == 0:
        return None
    if len(frames) > max_frame_number and max_frame_number > 0:
        # generate evenly spaced indices
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        # gather the frames at those indices
        sampled_elements = [frames[idx] for idx in indices]
        frames = sampled_elements
    return frames
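
# Illustrative sketch (not in the original file): for an 800x600 image and 336-based candidate
# grids, the candidate that preserves the most downscaled pixels wins; ties are broken by the
# smallest wasted area. The numbers below were worked out by hand for this hypothetical input.
def example_select_best_resolution():
    candidates = [(336, 336), (672, 336), (336, 672), (672, 672)]
    best = select_best_resolution((800, 600), candidates)
    print(best)  # expected: (672, 672) -- it keeps 672*504 = 338688 effective pixels, more than any other candidate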


class OceanImageProcessor:
    def __init__(self, config, **kwargs):
        self.config = config  # visual_config
        self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
        self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
        self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
        self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2

    def image_transform(self, strseq, return_mm_data=True):
        image = None
        if isinstance(strseq, str):
            if return_mm_data:
                image = Image.open(strseq).convert("RGB")
        else:
            image = Image.open(BytesIO(strseq)).convert("RGB")

        # Convert to RGB (three channels: R, G, B) and then to a NumPy array for the steps below.
        image = np.array(image.convert("RGB"))
        # Keep the original size; image.shape is (height, width, channels), so image.shape[:2]
        # is (height, width). This is used later for comparison and filtering.
        image_org_size = image.shape[:2]

        # resize, crop, scale, normalize
        # smart_resize takes the original (height, width) and returns a target size whose sides are
        # divisible by patch_size * spatial_merge_size while keeping the aspect ratio as well as the
        # pixel budget [min_pixels, max_pixels].
        resized_height, resized_width = smart_resize(
            image_org_size[0],
            image_org_size[1],
            factor=self.patch_size * self.spatial_merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        output_size = (resized_height, resized_width)
        # output_size = get_resize_output_image_size(image, self.config.crop_size, False)  # resize the short side to 336

        # Resize the image to output_size; PILImageResampling.BICUBIC selects bicubic interpolation,
        # a smooth resampling method that usually gives good quality for rescaling.
        image = resize(image, output_size, PILImageResampling.BICUBIC)
        # Optionally center-crop a self.config.crop_size x self.config.crop_size square
        # (return_numpy=True would return a NumPy array).
        # image = center_crop(image, (self.config.crop_size, self.config.crop_size), return_numpy=True)

        img = image.transpose(2, 0, 1)
        # normalize and standardize the image
        image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:, np.newaxis, np.newaxis]

        # split into patches
        patches = image[np.newaxis, :]
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
            grid_w // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )

        return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
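
# Rough shape sketch (assumed defaults: patch_size=14, spatial_merge_size=2, temporal_patch_size=2,
# RGB input; these are the fallback values above, not read from a real config): a 1000x500 image is
# smart-resized to 1008x504, giving grid_t=1, grid_h=72, grid_w=36, so image_transform returns
# flatten_patches of shape (1*72*36, 3*2*14*14) = (2592, 1176); after 2x2 spatial merging the LLM
# sees 2592 // 4 = 648 image pad tokens.
def example_image_patch_grid(height=1000, width=500, patch_size=14, spatial_merge_size=2):
    resized_h, resized_w = smart_resize(height, width, factor=patch_size * spatial_merge_size)
    grid_h, grid_w = resized_h // patch_size, resized_w // patch_size
    patch_nums = (grid_h * grid_w) // (spatial_merge_size ** 2)
    print((1, grid_h, grid_w), patch_nums)  # expected: (1, 72, 36) 648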


class OceanAudioProcessor:
    # Basic audio feature extraction + input parsing + COS request/cache helpers.
    def __init__(
        self,
        config,  # audio processor config
        **kwargs
    ):
        # make sure you have installed 'conda install -c conda-forge "ffmpeg<7"' for torchaudio
        assert(len(torchaudio.list_audio_backends()) > 0)
        self.config = config
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.n_fft // 2,
            num_mel_filters=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=self.config.sampling_rate / 2.0,
            sampling_rate=self.config.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    @staticmethod
    def zero_mean_unit_var_norm(x):
        return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)

    def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
        metadata = torchaudio.info(uri)  # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
        assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels)  # whisper only accepts mono channel audio
        waveform_tensor, _ = torchaudio.load(uri, normalize=True)
        if self.config.sampling_rate != metadata.sample_rate:
            waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate)

        # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
        if metadata.num_channels > 1:
            waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)

        # normalize to zero mean (Qwen Audio does not do this, but the official Whisper implementation does)
        if do_normalize:
            waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)

        if return_tensors:  # (channels, samples)
            return waveform_tensor
        else:
            return waveform_tensor.numpy()

    def split_with_overlap(self, waveform):
        # If the waveform exceeds the maximum length, split it into overlapping chunks.
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]  # within the max length (or truncation mode); always return a list
        split_waveform, start = [], 0
        while start < wave_samples:
            # 20240724 change: align the overlap in seconds so the sampled data stays identical
            # across different sampling_rate / n_fft / hop_length configurations
            if start > int(self.config.sampling_rate * self.config.split_overlap):
                start -= int(self.config.sampling_rate * self.config.split_overlap)  # 0 means no overlap, >0 is the overlap in seconds
            end = min(start + max_audio_samples, wave_samples)
            split_waveform.append(waveform[:, start:end])  # note: this may produce very short chunks, which must be detected and dropped in preprocessing
            start = end
        return split_waveform

    @classmethod
    def inference_output_length(cls, config, input_length):
        # for whisper + bridge
        kernel_size = config.kernel_size
        stride_size = config.stride_size
        avg_pooler = config.avg_pooler
        encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1  # conv layer1 with pad=1
        encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1  # conv layer2 with pad=1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        else:
            bridge_length = encoder_length
        return encoder_length, bridge_length

    def extract_fbank_features(self, waveform):
        # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
        channels, wave_samples = waveform.shape
        assert(wave_samples >= self.config.n_fft)
        valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
        if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
            waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
        else:
            waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]

        window = torch.hann_window(self.config.n_fft)
        stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=window, return_complex=True)  # fft, len(wave) // n_fft // 2 + 1
        magnitudes = stft[..., :-1].abs() ** 2
        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if waveform.dim() == 2:
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        log_spec = log_spec[0].numpy()  # (channel, filters, samples) -> (filters, samples)
        log_spec[:, valid_frame_nums:] = 0.0  # pad with zeros; the collate step keeps the max length within the batch
        return log_spec, valid_frame_nums

    def data_augment(self, feature: np.array, input_length, training=True):
        # reference https://arxiv.org/pdf/1904.08779
        # run only on cpu
        def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            # compute the total number of spans to mask, then randomly pick the span start indices
            num_masked_span = int(mask_prob * input_length / mask_length + random.random())
            num_masked_span = max(num_masked_span, min_masks)
            start_indices = list(range(input_length - mask_length))
            random.shuffle(start_indices)
            start_indices = start_indices[:num_masked_span]
            return start_indices

        if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
            return feature
        if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
            return feature
        if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
            return feature

        if self.config.mask_time_prob > 0:
            start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
            for start_idx in start_indices:
                feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
        if self.config.mask_feature_prob > 0:
            start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
            for start_idx in start_indices:
                feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
        return feature
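
# Back-of-the-envelope sketch (assumed Whisper-like settings: 16 kHz sampling, hop_length=160,
# max_audio_seconds=30, kernel_size=3, stride_size=2, avg_pooler=2 -- none of these are read from
# the real config here): a full 30 s clip yields 3000 mel frames, the two conv layers keep 3000 and
# then halve to 1500 encoder states, and the bridge pools that down to 750 audio tokens.
def example_audio_token_budget(input_length=3000, kernel_size=3, stride_size=2, avg_pooler=2):
    encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1
    encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1
    bridge_length = encoder_length // avg_pooler if avg_pooler > 1 else encoder_length
    print(encoder_length, bridge_length)  # expected: 1500 750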


class CosClient():
    def __init__(self, bucket_name='crawl-pic-1317568651', max_retries=2):
        self.config = CosConfig(
            Endpoint="cos.ap-guangzhou.myqcloud.com",
            # Region='ap-guangzhou',
            SecretId='AKIDnRpxoOghgVs0tkU3Mfv20jAMI0SRDj02',
            SecretKey='td9tRlqiPvEJ8i27wXwBIDiy5ye6JGyS',
            Token=None,
            Scheme='https',
            Timeout=300)
        self.client = CosS3Client(self.config)
        self.max_retries = max_retries
        self.bucket_name = bucket_name

    def __call__(self, relative_path, bucket_name=None):
        if bucket_name is None or len(bucket_name) <= 0:
            bucket_name = self.bucket_name
        multimodal_bytes = None
        for _ in range(self.max_retries):
            try:
                response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
                fp = response['Body'].get_raw_stream()
                multimodal_bytes = fp.read()
                break
            except Exception as e:
                time.sleep(0.01)
                continue
        return multimodal_bytes


class TosClient(object):
    def __init__(self):
        ak = "AKLTYTM3MWY5MTFhNDgyNDk4YjhmYTE0ZTE3YTk5ZmU1MjU"
        sk = "TVRRM1pUZGtaVEJqWTJJd05HSTNPR0ppWVdKa1lqYzVORFUwTlRobU1UVQ=="
        endpoint = "tos-cn-beijing.ivolces.com"  # "tos-cn-beijing.ivolces.com"
        region = "cn-beijing"
        self.bucket_name = "audio-dataset"
        self.client = tos.TosClientV2(ak, sk, endpoint, region)

    def __call__(self, path, bucket_name=None):
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(2):
            try:
                object_stream = self.client.get_object(bucket_name, path)
                return object_stream.read()
            except Exception as e:
                time.sleep(0.01)
                continue
        return None


@dataclass
class OceanProcessorOutput(ModelOutput):
    input_ids: Optional["List|torch.Tensor"] = None
    labels: Optional["List|torch.Tensor"] = None
    attention_mask: Optional["List|torch.Tensor"] = None
    position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None  # must be used together with the Ocean modeling code
    # audio fields
    audios: Optional["List|torch.Tensor"] = None
    encoder_length: Optional["List|torch.Tensor"] = None
    bridge_length: Optional["List|torch.Tensor"] = None
    # image fields
    images: Optional["List|torch.Tensor"] = None
    patch_nums: Optional["List|torch.Tensor"] = None
    images_size: Optional["List|torch.Tensor"] = None
    crop_size: Optional["List|torch.Tensor"] = None
    images_grid: Optional["List|torch.Tensor"] = None
    # video fields
    videos: Optional["List|torch.Tensor"] = None
    videos_patch_nums: Optional["List|torch.Tensor"] = None
    videos_size: Optional["List|torch.Tensor"] = None
    videos_crop_size: Optional["List|torch.Tensor"] = None
    videos_grid: Optional["List|torch.Tensor"] = None
    # processor fields
    raw_text: Optional[str] = None
    index: Optional[int] = None

    def concatenate(self, other):
        # only valid when the fields are lists
        def concat_one(a, b):
            if a is None and b is None:
                return None
            elif a is None and b is not None:
                return b
            elif a is not None and b is None:
                return a
            else:
                return a + b
        return OceanProcessorOutput(
            input_ids=concat_one(self.input_ids, other.input_ids),
            labels=concat_one(self.labels, other.labels),
            audios=concat_one(self.audios, other.audios),
            encoder_length=concat_one(self.encoder_length, other.encoder_length),
            bridge_length=concat_one(self.bridge_length, other.bridge_length),
            images=concat_one(self.images, other.images),
            images_grid=concat_one(self.images_grid, other.images_grid),
            patch_nums=concat_one(self.patch_nums, other.patch_nums),
            videos=concat_one(self.videos, other.videos),
            videos_grid=concat_one(self.videos_grid, other.videos_grid),
            videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
            position_ids=concat_one(self.position_ids, other.position_ids),
            seqlens=concat_one(self.seqlens, other.seqlens),
            images_size=concat_one(self.images_size, other.images_size)
        )
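
# Minimal usage sketch (illustrative only, with made-up values): concatenate() merges list-valued
# fields of two outputs and leaves None fields alone, which is how per-chunk audio features and
# per-image patches are accumulated into a single sample.
def example_concatenate_outputs():
    a = OceanProcessorOutput(input_ids=[1, 2], bridge_length=[750])
    b = OceanProcessorOutput(input_ids=[3], bridge_length=[375])
    merged = a.concatenate(b)
    print(merged.input_ids, merged.bridge_length)  # expected: [1, 2, 3] [750, 375]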


class OceanMMProcessor(object):
    def __init__(self,
                 tokenizer: transformers.PreTrainedTokenizer,
                 config,
                 training,
                 relative_path=None,
                 **kwargs,
                 ):
        self.tokenizer = tokenizer
        self.config = config
        self.audio_processor = None
        if hasattr(config, "audio_config"):
            self.audio_processor = OceanAudioProcessor(config.audio_config)
        self.visual_processor = None
        if hasattr(config, "visual_config"):
            self.visual_processor = OceanImageProcessor(config.visual_config)
        self.video_processor = None
        if hasattr(config, "video_config"):
            self.video_processor = OceanImageProcessor(config.video_config)
        self.training = training
        self.relative_path = relative_path
        self.cos_client = CosClient()
        self.tos_client = TosClient()

        # audio tags
        self.audio_start_tag = None
        self.audio_end_tag = None
        self.audio_pad_tag = None
        self.audio_delim_tag = None
        if hasattr(self.config, "audio_config"):
            self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
            self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
            self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
            self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)

        # image tags
        self.image_start_tag = None
        self.image_end_tag = None
        self.image_pad_tag = None
        self.video_start_tag = None
        self.video_end_tag = None
        if hasattr(self.config, "visual_config"):
            # special token for start_tag
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
            # special token for end_tag
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
            # special token for pad_tag
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
            self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
            self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
        if hasattr(self.config, "video_config"):
            self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
            self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
            self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)

    # @lru_cache(maxsize=1024)
    def _get_audio(self, audio_info, return_mm_data=True):
        try:
            audio_info = ujson.loads(audio_info)
            audio_uri = None
            if 'path' in audio_info.keys():
                if self.relative_path is not None:
                    # prefer a local path if it exists
                    audio_uri = os.path.join(self.relative_path, audio_info['path'])
                    if not os.path.exists(audio_uri):
                        audio_uri = None
                if audio_uri is None:
                    # not available locally, fall back to cos/tos
                    if audio_info.get('server', 'cos') == 'tos':
                        audio_uri = self.tos_client(audio_info['path'], 'audio-dataset')
                    else:
                        audio_uri = self.cos_client(audio_info['path'], 'audio-data-1317568651')
            elif 'local' in audio_info.keys():
                audio_uri = audio_info['local']
                if not os.path.exists(audio_uri):
                    audio_uri = None
                    return OceanProcessorOutput()
            else:
                raise ValueError("can not find path or local in audio_info")

            waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
            waveforms = self.audio_processor.split_with_overlap(waveforms)  # chunking logic
            ret = OceanProcessorOutput()  # default init; the audios field stays None
            for waveform in waveforms:
                audio, input_length = self.audio_processor.extract_fbank_features(waveform)
                audio = self.audio_processor.data_augment(audio, input_length, self.training)
                encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
                if bridge_length <= 0:
                    # filter extremely short chunks: 1. if len(waveforms)==1 the result stays empty;
                    # 2. if len(waveforms)>1 the last chunk was too short and is dropped
                    continue
                current_ret = OceanProcessorOutput(
                    audios=[audio],
                    encoder_length=[encoder_length],
                    bridge_length=[bridge_length])
                if ret.audios is None:
                    ret = current_ret
                else:
                    ret = ret.concatenate(current_ret)  # concatenate multiple chunks
            if not return_mm_data:
                ret.audios = [None]
            return ret
        except Exception as e:
            print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
            return OceanProcessorOutput()

    # @lru_cache(maxsize=1024)
    def _get_image(self, image_info, return_mm_data=True):
        try:
            try:  # chensong
                image_info = ujson.loads(image_info)
            except:
                # image_info = image_info.replace("'", '"')
                image_info = re.sub(r"(?<!\\)'", '"', image_info)
                image_info = ujson.loads(image_info)
            if 'path' in image_info.keys():
                image_uri = image_info['path']
                if self.relative_path is not None and os.path.exists(os.path.join(self.relative_path, image_info['path'])):
                    image_uri = os.path.join(self.relative_path, image_info['path'])
                else:
                    image_uri = self.cos_client(image_info['path'])  # fetch the raw bytes from cos
            elif 'local' in image_info.keys():
                image_uri = image_info['local']
            else:
                raise ValueError("can not find any path in image_info")
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_uri, return_mm_data=return_mm_data)
            merge_length = self.visual_processor.merge_size**2
            patch_nums = np.array(image_list).prod() // merge_length
            if org_size[0] * org_size[1] > 16**2:  # filter out extremely small images
                return OceanProcessorOutput(
                    images=[image_feat],
                    patch_nums=[patch_nums],
                    crop_size=[image_list],
                    images_size=[org_size],
                    images_grid=[image_list]
                )
            else:
                print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
                return OceanProcessorOutput()
        except Exception as e:
            print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
            return OceanProcessorOutput()

    # @lru_cache(maxsize=1024)
    def _get_video_frame(self, video_frame_info, return_mm_data=True):
        try:
            pattern = r'\{.*?\}'
            matches = re.findall(pattern, video_frame_info)
            ret = OceanProcessorOutput()
            # parse the frame infos one by one
            for match in matches:
                video_frame_info = ujson.loads(match)
                if 'local' in video_frame_info.keys():
                    image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'], return_mm_data=return_mm_data)
                else:
                    raise ValueError("can not find any path in image_info")
                merge_length = self.video_processor.merge_size**2
                patch_nums = np.array(image_list).prod() // merge_length
                if org_size[0] * org_size[1] > 16**2:  # filter out extremely small frames
                    ret = ret.concatenate(
                        OceanProcessorOutput(
                            videos=[image_feat],
                            videos_patch_nums=[patch_nums],
                            videos_crop_size=[image_list],
                            videos_size=[org_size],
                            videos_grid=[image_list]
                        )
                    )
                else:
                    print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
            return ret
        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
            return OceanProcessorOutput()

    # read the raw video object
    def _get_video_obj_byte(self, source, path, video_obj_json):
        video_obj_byte = None
        if source == "cos":
            start_time = time.time()
            video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
            if (time.time() - start_time) > 1.0:
                self.reflash_cos_client()
        if source == "local":
            if os.path.exists(path):
                video_obj_byte = open(path, "rb").read()
            else:
                video_obj_byte = None
        if source == "base64":
            video_obj_byte = base64.b64decode(path)
        if source == "url":
            video_obj_byte = requests.get(url=path).content
        return video_obj_byte

    # split a video into frames and save them into a sub-directory
    def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
        video_path = video_info['local']
        # local directory for the extracted frames
        frame_path = video_path.split('.')[0] + '_frames'
        if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
            # save the frames
            os.makedirs(frame_path, exist_ok=True)
            mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
            if mm_obj_byte is None:  # the video file could not be read
                return ""
            frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way)  # read all frames
            for frame_idx, frame in enumerate(frames):
                output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                cv2.imwrite(output_filename, frame)
        # select frames
        frame_number = len([filename for filename in os.listdir(frame_path) if filename.endswith('.jpg')])
        if frame_number > max_frame_number:
            indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
        else:
            indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
        # concatenation mode
        replace_str = ""
        for idx in indices:
            frame_str = f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
            replace_str += frame_str
        return replace_str

    def _get_video_frame_str(self, video_info, return_mm_data=True):
        try:
            video_info = ujson.loads(video_info)
            if 'local' in video_info.keys():
                # build the string with the per-frame image paths, keeping at most max_frame_number frames
                frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
                if frames_str != "":
                    parts = frames_str.split(self.image_end_tag)
                    result = []
                    for part in parts:
                        if self.image_start_tag in part:
                            before_path, path = part.split(self.image_start_tag)
                            new_path = f'{self.image_start_tag}{{"local": "{path}"}}{self.image_end_tag}'
                            result.append(before_path + new_path)
                        else:
                            result.append(part)
                    return ''.join(result)
            else:
                raise ValueError('can not find localpath in video_info')
        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
        return ""

    # def _replace_audio(self, audio_text, return_mm_data = True):
    #     audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
    #     ret = self._get_audio(audio_info, return_mm_data)  # fetch the cached result again
    def _replace_audio(self, audio_text, mminfo_ret_dict):
        audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
        # ret = self._get_audio(audio_info)  # fetch the cached result again
        ret = mminfo_ret_dict.get(audio_info, OceanProcessorOutput())  # read directly from the dict
        if ret.bridge_length is not None:
            # TODO if there are many pad tokens the tokenizer becomes very slow
            replaced_text = [self.audio_pad_tag * l for l in ret.bridge_length]
            replaced_text = self.audio_delim_tag.join(replaced_text)
            return self.audio_start_tag + replaced_text + self.audio_end_tag
        return ''

    # def _replace_image(self, image_text, return_mm_data = True):
    #     image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
    #     ret = self._get_image(image_info, return_mm_data)  # fetch the cached result again
    def _replace_image(self, image_text, mminfo_ret_dict):
        image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
        # ret = self._get_image(image_info)  # fetch the cached result again
        ret = mminfo_ret_dict.get(image_info, OceanProcessorOutput())  # read directly from the dict
        if ret.patch_nums is None:
            return ''
        return self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag

    # def _replace_video_frame(self, video_frame_text, return_mm_data = True):
    #     video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
    #     ret = self._get_video_frame(video_frame_info, return_mm_data)  # fetch the cached result again
    def _replace_video_frame(self, video_frame_text, mminfo_ret_dict):
        video_frame_info = re.sub(re.compile(self.video_start_tag + '|' + self.video_end_tag), '', video_frame_text)
        video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_info)
        # ret = self._get_video_frame(video_frame_info)  # fetch the cached result again
        ret = mminfo_ret_dict.get(video_frame_info, OceanProcessorOutput())
        if ret.videos_patch_nums is None:
            return ''
        video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag
                           for i in range(len(ret.videos_patch_nums))]
        return ''.join(video_frame_str)
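
    # Worked example with made-up numbers (tags shown symbolically, not the actual special tokens):
    # an image whose grid is (1, 72, 36) has 72*36 // merge_size**2 = 648 merged patches, so
    # _replace_image rewrites its tagged JSON span into image_start_tag + 648 copies of
    # image_pad_tag + image_end_tag; _replace_video_frame does the same per frame with
    # video_place_tag, and _replace_audio joins one audio_pad_tag run per 30 s chunk with
    # audio_delim_tag.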

    def extract_replace_multimodal(self, text, mtype='audio', return_mm_data=True):
        # Extract the JSON-formatted audio/image info embedded in the text, load it and turn it into
        # features, estimate the number of encoder tokens, and fill in that many pad tokens.
        if (self.audio_start_tag != None) and (mtype == 'audio'):
            match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag)
            drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag)
            extract_func = self._get_audio
            replace_func = self._replace_audio
        elif (self.image_start_tag != None) and (mtype == 'image'):
            match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag)
            drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
            extract_func = self._get_image
            replace_func = self._replace_image
        elif (self.video_start_tag != None) and (mtype == 'video'):
            video_match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag)
            video_drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag)
            # handle videos: convert the video path into a sequence of frame image paths
            mm_info_list = re.findall(video_match_regex, text)
            for mm_info in mm_info_list:
                frame_str = self._get_video_frame_str(re.sub(video_drop_regex, '', mm_info))
                # replace the path; if the video does not exist, the path becomes an empty string
                text = re.sub(mm_info, self.video_start_tag + frame_str + self.video_end_tag, text)
            # from here on, use the multi-image processing path
            match_regex = re.compile(self.video_start_tag + r'(.*?)' + self.video_end_tag)
            drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
            extract_func = self._get_video_frame
            replace_func = self._replace_video_frame
        else:
            raise ValueError("mtype not supported!")

        mm_info_list = re.findall(match_regex, text)
        mm_info_list = [re.sub(drop_regex, '', mm_info) for mm_info in mm_info_list]
        mminfo_ret_dict = {}
        ret = OceanProcessorOutput()
        for mm_info in mm_info_list:
            # if nothing of this modality is matched, raw_text=text is returned directly, so the result is never None
            mm_ret = extract_func(mm_info, return_mm_data=return_mm_data)
            mminfo_ret_dict[mm_info] = mm_ret
            if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None:
                # the sample contains audio/image/video but extraction failed,
                # so the whole sample is invalid (ret.raw_text stays None)
                return ret
            ret = ret.concatenate(mm_ret)  # there may be several results; collect them first
        # ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group()), text)
        ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group(), mminfo_ret_dict), text)
        return ret

    def process_one(self, text, index=0, raw_only=False, return_mm_data=True):
        ret = OceanProcessorOutput(index=index)
        for mtype in self.config.multimodal:
            # loop over the modalities (audio, image, and now video) and keep updating the raw_text field
            mret = self.extract_replace_multimodal(text, mtype, return_mm_data=return_mm_data)
            if mret.raw_text is None:
                # the sample contains this modality but loading it failed
                return OceanProcessorOutput(index=index)
            ret = ret.concatenate(mret)
            text = mret.raw_text
        ret.raw_text = text
        if raw_only:
            return ret
        # kept compatible with SFT and other custom tokenizer logic
        # handle the trainable spans in pretraining data
        input_ids, labels = [], []
        trainable_sep = re.findall(r'|', ret.raw_text.replace('\n', ''))
        if len(trainable_sep) <= 0:
            input_ids = self.tokenizer(ret.raw_text, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
            labels = [True for _ in input_ids]
        else:
            split_content = re.split(r'|', ret.raw_text)
            for i, sc in enumerate(split_content):
                if len(sc.strip()) == 0:
                    continue  # drop redundant whitespace
                sc_ids = self.tokenizer(sc, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
                input_ids.extend(sc_ids)
                if i == 0 or trainable_sep[i - 1] == '':  # stop gradient
                    labels.extend([False] * len(sc_ids))
                else:
                    labels.extend([True] * len(sc_ids))
        # input_ids += [self.tokenizer.eos_token_id]
        # labels += [True]
        ret.labels = [input_ids[j] if (l and input_ids[j] not in self.config.multimodal_special_token_no_loss_list) else -100
                      for j, l in enumerate(labels)]
        ret.input_ids = input_ids
        ret.index = index
        return ret
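
    # Rough trace with hypothetical numbers: for a sample containing one 30 s audio clip,
    # extract_replace_multimodal swaps the {"path": ...} span for roughly 750 audio_pad_tag tokens,
    # process_one then tokenizes the rewritten text, and every multimodal special token listed in
    # config.multimodal_special_token_no_loss_list gets label -100 so it never contributes to the
    # language-model loss.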

    @torch.no_grad()
    def __call__(self, example, parallel=8):
        # Main entry point. Supports three input forms: a pretraining string, an SFT message dict,
        # and a list of strings for batch inference.
        if isinstance(example, Dict):
            pass
        elif isinstance(example, str):
            return self.process_one(example)
        elif isinstance(example, List):
            # batch inference with asynchronous multi-threading
            with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
                future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
                batch_data = [key.result() for key in cf.as_completed(future_list)]
            valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
            assert(valid_num == len(batch_data))  # inference data must stay strictly aligned
            batch_data = sorted(batch_data, key=lambda x: x.index)  # keep the original order
            ret = OceanProcessorOutput()
            for i in range(len(batch_data)):
                ret = ret.concatenate(batch_data[i])
            self.tokenizer.padding_side = "left"
            padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
            ret.input_ids = padding_result["input_ids"]
            ret.attention_mask = padding_result["attention_mask"]  # batch inference is not packed, so seqlens is not needed
            padding_result = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
            ret.labels = padding_result["input_ids"]
            if ret.audios is not None:
                ret.audios = default_collate(ret.audios)
                ret.encoder_length = default_collate(ret.encoder_length)
                ret.bridge_length = default_collate(ret.bridge_length)
            if ret.images is not None:
                ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
                # else: ret.images = default_collate(ret.images)
                # ret.patch_nums = default_collate(ret.patch_nums)
            if ret.videos is not None:
                ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
            return ret
        else:
            raise ValueError("example format not supported yet")

    @torch.no_grad()
    def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=8):
        if max_sequence_length is None:
            max_sequence_length = self.tokenizer.model_max_length
        # Pack N samples into M samples of length max_sequence_length; each packed sample keeps
        # its own multimodal inputs.
        assert isinstance(raw_batch, List)
        start_ts = time.time()
        if parallel > 1:
            with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
                future_list = []
                for idx, json_text in enumerate(raw_batch):
                    try:
                        # parse the json line
                        json_obj = ujson.loads(json_text.strip())
                    except:
                        try:
                            json_obj = ast.literal_eval(json_text.strip())
                        except:
                            print("parse json obj failed: {}....".format(json_text[:300]))
                            continue
                    try:  # chensong
                        if isinstance(json_obj, list):
                            content = json_obj[1]
                        elif 'raw' in json_obj.keys():
                            content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["raw"]
                        else:
                            content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["content"]
                    except:
                        print("parse json raw/content error: {}....".format(json_text[:300]))
                        continue
                    future_list.append(executor.submit(self.process_one, content, idx))
                # collect the results (order is not preserved)
                batch_data = [key.result() for key in cf.as_completed(future_list)]
        else:  # debug only
            batch_data = []
            for json_text in raw_batch:
                data = ujson.loads(json_text.strip())
                if 'raw' in data.keys():
                    batch_data.append(self.process_one(data['raw'], 0))
                else:
                    batch_data.append(self.process_one(data['content'], 0))
        if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
            print('[WARNING] processing each data cost more than 1.0s')

        # pack the text inputs without any truncation
        current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), []
        empty_data = OceanProcessorOutput(input_ids=[], labels=[])
        for idx, bd in enumerate(batch_data + [empty_data]):  # the trailing empty sample makes sure the last pack is appended to output
            if bd.input_ids is None and idx < len(batch_data):
                continue  # the sample failed to load and it is not the trailing sentinel
            if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and idx < len(batch_data):
                continue  # drop samples that are empty or too long (and are not the trailing sentinel)
            if current_length + len(bd.input_ids) + 1 > max_sequence_length or idx == len(batch_data):
                pad_nums = max_sequence_length - current_length  # right padding
                if packed_output.input_ids is None or packed_output.labels is None:
                    packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
                    packed_output.labels = [-100] * pad_nums
                    packed_output.position_ids += [0] * (pad_nums + 1)
                else:
                    packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
                    packed_output.labels += [-100] * pad_nums
                    packed_output.position_ids += [0] * pad_nums
                packed_output.attention_mask = [1] * current_length + [0] * pad_nums
                packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
                output.append(packed_output)
                packed_output = OceanProcessorOutput(position_ids=[], seqlens=[])  # reset to empty
            packed_output = packed_output.concatenate(bd)
            packed_output.input_ids.append(self.tokenizer.eos_token_id)  # the eos has to be added separately
            packed_output.labels.append(self.tokenizer.eos_token_id)
            packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1)))
            packed_output.seqlens.append(len(bd.input_ids) + 1)
            current_length = len(packed_output.input_ids)
        return output
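
    # Packing arithmetic with made-up lengths: with max_sequence_length=16 and three samples of
    # 5, 7 and 9 tokens, each sample costs len+1 slots for its eos, so the first pack holds
    # 6 + 8 = 14 real tokens plus 2 pad tokens (attention_mask is 14 ones then 2 zeros; seqlens
    # starts with [6, 8] and is zero-padded to max_sequence_length), while the 9-token sample
    # starts the next pack.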
print("parse json raw/content error: {}....".format(json_text[:300])) continue future_list.append(executor.submit(self.process_one, content, idx)) # 获取结果 乱序 batch_data = [key.result() for key in cf.as_completed(future_list)] else: # debug only batch_data = [] for json_text in raw_batch: data = ujson.loads(json_text.strip()) if 'raw' in data.keys(): batch_data.append(self.process_one(data['raw'], 0)) else: batch_data.append(self.process_one(data['content'], 0)) if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0: print('[WARNING] processing each data cost more than 1.0s') # packing 文本部分的输入,不做任何截断 current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), [] empty_data = OceanProcessorOutput(input_ids=[], labels=[]) for idx, bd in enumerate(batch_data + [empty_data]): # 加空数据方便appedn最后一个数据到output,防止遗漏 if bd.input_ids is None and idx < len(batch_data): continue # 数据没取到 并且不是最后一个 if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and idx < len(batch_data): continue # 太长的直接不要 并且不是最后一个 if current_length + len(bd.input_ids) + 1 > max_sequence_length or idx == len(batch_data): pad_nums = max_sequence_length - current_length # right padding if packed_output.input_ids is None or packed_output.labels is None: packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums packed_output.labels = [-100] * pad_nums packed_output.position_ids += [0] * (pad_nums+1) else: packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums packed_output.labels += [-100] * pad_nums packed_output.position_ids += [0] * pad_nums packed_output.attention_mask = [1] * current_length + [0] * pad_nums packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens)) output.append(packed_output) packed_output = OceanProcessorOutput(position_ids=[], seqlens=[]) # reset empty packed_output = packed_output.concatenate(bd) packed_output.input_ids.append(self.tokenizer.eos_token_id) # 需要单独加 packed_output.labels.append(self.tokenizer.eos_token_id) packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1))) packed_output.seqlens.append(len(bd.input_ids) + 1) current_length = len(packed_output.input_ids) return output @torch.no_grad() def collect_batch_pretrain(self, batch_data): ret = OceanProcessorOutput() for i in range(len(batch_data)): ret = ret.concatenate(batch_data[i]) ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True) ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True) ret.attention_mask = default_collate([np.asarray(x.attention_mask, dtype=np.float32) for x in batch_data]).cuda(non_blocking=True) ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True) ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True) ret.raw_text = None if ret.audios is not None: ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32)).cuda(non_blocking=True) ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32)).cuda(non_blocking=True) ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32)).cuda(non_blocking=True) if ret.images is not None: ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)).cuda(non_blocking=True) for image in ret.images]#default_collate(np.asarray(ret.images, 

    @torch.no_grad()
    def collect_batch_sft(self, batch_data):
        # list of dicts to dataclass
        batch_data = [OceanProcessorOutput(**bd) for bd in batch_data]
        ret = OceanProcessorOutput()
        for i in range(len(batch_data)):
            ret = ret.concatenate(batch_data[i])
        ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data])
        ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data])
        ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data])
        ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data])
        ret.raw_text = None
        if ret.audios is not None:
            ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32))
            ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32))
            ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32))
        if ret.images is not None:
            # convert each image to a torch tensor
            ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
            # default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
        if ret.videos is not None:
            ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]
            # default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
        # ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
        ret = ret.__dict__
        del ret['patch_nums']
        del ret['images_size']
        del ret['crop_size']
        del ret['raw_text']
        del ret['index']
        del ret['attention_mask']
        del ret['videos_patch_nums']
        del ret['videos_size']
        del ret['videos_crop_size']
        return ret


#######################################################
## Unit Test Functions, usage
##     python processor_ocean.py test
#######################################################
def test_img_processor():
    from transformers import AutoConfig
    from transformers.models.clip import CLIPImageProcessor
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    processor = OceanImageProcessor(config.visual_config)
    offical_processor = CLIPImageProcessor(size=config.visual_config.crop_size,
                                           crop_size=config.visual_config.crop_size,
                                           image_mean=config.visual_config.image_mean,
                                           image_std=config.visual_config.image_std,
                                           do_convert_rgb=True)
    img_files = ['sogou/7a2c8ffc1bc61146b32805c3390f42e2',
                 'wukong/77c1db1c0e4200d12b478c33ba3a412d',
                 'wukong/62e9a5c8eb8b0ea8858a34ba3f1a999f',
                 'wukong/fb9ab4d7c3fe9f54289948fd6a57fc30']
    cos_client = CosClient()
    for img_file in img_files:
        img_bytes = cos_client(img_file)
        img_rbg = Image.open(io.BytesIO(img_bytes))
        image, org_size, image_grid = processor.image_transform(img_bytes)
        offical_image = offical_processor.preprocess([img_rbg], do_resize=True, do_center_crop=True, do_rescale=True, do_normalize=True, return_tensors='np').data['pixel_values'][0]
        print('-' * 60)
        print(np.array(img_rbg).shape)
        print(image.shape)
        print(offical_image.shape)
        print(image - offical_image)


def test_audio_processor():
    from transformers.models.whisper import WhisperFeatureExtractor
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    offical_processor = WhisperFeatureExtractor(feature_size=128)
    processor = OceanAudioProcessor(config.audio_config)
    # wave_files = glob.glob('/home/nfs_bc_alignment/sunhaoze/audio-data/openaqa/openaqa-as/audio/*')
    wave_files = ['/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac',
                  '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac']
    for wave_file in wave_files:
        wave = processor.load_audio_waveform(wave_file, True, False)
        offical_features = offical_processor(wave[0].numpy(), do_normalize=False)
        feat = offical_features['input_features'][0]
        wave, frame_nums = processor.extract_fbank_features(wave)
        print("=" * 60)
        print(feat.shape)
        print(wave.shape, frame_nums)
        print('the difference between official extractor and our implementation: {}'.format(wave_file))
        print(wave[:, :frame_nums] - feat[:, :frame_nums])
        print(wave)
        # print(wave[120:-1, :])
        # print(feat[120:-1, :wave.shape[1]])
        zeros_before = np.sum(wave == 0)
        aug = processor.data_augment(wave, frame_nums)
        zeros_after = np.sum(aug == 0)
        print(zeros_before, zeros_after)


def test_audio_long():
    # test the chunking strategy for audio longer than 30 seconds
    from transformers import AutoConfig, AutoTokenizer
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    config.audio_config.split_overlap = 1
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = OceanMMProcessor(tokenizer, config, True)
    examples = ["{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}What is the level of noise from the speech?\nThe speech energy\n is medium.",
                "what's the sound's energy? \n sound1 {\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"} \n sound2 {\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}The speech energy is medium.",
                ]
    ret = processor(examples)
    print(ret)
    print(torch.sum(ret.input_ids == 151659))
    print(torch.sum(ret.input_ids == 151674))
\n {\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}The speech energy is medium.", "sound1: {\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}\n sound2: \n{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}How is the speech speed related to the estimated speaker age?\nThe slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.", "{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评", "这两个图片有什么关系?图片1{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}图片2\n{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}", "根据图片和语音给出描述\n图片{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}语音{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}这是一只猫", "这些图片和音频不存在{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}语音{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}这是一只猫" ] ret = processor(examples[4:-1]) print(ret) print(torch.sum(ret.input_ids == 151659)) print(torch.sum(ret.input_ids == 151662)) try: print(ret.bridge_length) print(ret.patch_nums) except: pass print(torch.sum(ret.attention_mask, dim=1)) def test_grounding(): from transformers import AutoConfig, AutoTokenizer config = AutoConfig.from_pretrained("./", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096) processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds') examples = ["{\"path\": \"grit\/663423bf2f0884c034bf75279bce9694\"}\nWhere is \"A woman\" ? Answer: The bounding box is (0.58,0.8),(0.71,1.0)", "hello, ocean 你好 百川智能。", "{\"path\": \"grit\/0e6e3952c584cbac7235940a22514656\"} Generate the caption with grounding: Photo pour Portrait of young Asian muslim woman wearing hijab(0.09,0.01),(0.77,1.0) shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit", "Recognize the object in the outlined section {\"path\": \"grit\/045823cf6f819670f27aee20af7ae0e6\"} of the picture.(0.07,0.2),(0.91,0.96)\nInflatable water trampolines" ] ret = processor(examples) print(ret) for i, input_ids in enumerate(ret.input_ids): print("="*60) print(ret.labels[i]) def test_pack(): from transformers import AutoConfig, AutoTokenizer config = AutoConfig.from_pretrained("./", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048) processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds') examples = open('/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000').readlines()[:5] examples += open('/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000').readlines()[:5] random.shuffle(examples) batch_output = processor.pack_batch_pretrain(examples) for i, b in enumerate(batch_output): print('='*60) try: print(b.input_ids, len(b.input_ids)) print(b.labels, len(b.labels)) print(b.attention_mask, len(b.attention_mask)) print(b.position_ids, len(b.position_ids)) print(b.seqlens, len(b.seqlens)) print(b.audios) print(b.bridge_length) except: continue batch_for_model = processor.collect_batch_pretrain(batch_output) print(batch_for_model.input_ids.shape) print(batch_for_model.labels.shape) print(batch_for_model.audios.shape) print(batch_for_model["bridge_length"]) print(batch_for_model.images.shape) print(batch_for_model["patch_nums"]) print(batch_for_model["position_ids"]) print(batch_for_model["seqlens"]) def test_cos_audio(): cos_client = CosClient() audio_bytes = 


def test_cos_audio():
    cos_client = CosClient()
    audio_bytes = cos_client('panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3', 'audio-data-1317568651')
    wave, sr = torchaudio.load(io.BytesIO(audio_bytes), normalize=False)
    print(wave.shape, sr)
    # torchaudio.save('tmp.flac', wave, sr)


if __name__ == '__main__':
    fire.Fire()