diff --git "a/llava_llama.py" "b/llava_llama.py" new file mode 100644--- /dev/null +++ "b/llava_llama.py" @@ -0,0 +1,6596 @@ +import base64 +import dataclasses +import logging +import math +import os +import os.path as osp +import re +import string +import tempfile +import warnings +from abc import ABC +from collections import OrderedDict +from dataclasses import dataclass +from enum import Enum, auto +from io import BytesIO +from shutil import copyfile +from threading import Thread +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from accelerate.hooks import add_hook_to_module +from huggingface_hub import repo_exists, snapshot_download +from huggingface_hub.utils import HFValidationError +from PIL import Image +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + PretrainedConfig, + PreTrainedModel, + StoppingCriteria, + TextIteratorStreamer, +) +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + get_channel_dimension_axis, + get_resize_output_image_size, + normalize, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithPast +from transformers.modeling_utils import ContextManagers, PreTrainedModel, no_init_weights +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import ( + AddedToken, + PaddingStrategy, + PreTokenizedInput, + TextInput, + TruncationStrategy, +) +from transformers.utils import ( + ModelOutput, + TensorType, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_tf_available, + is_torch_available, + is_torchvision_available, + is_vision_available, + logging, + replace_return_docstrings, +) + +# from ..configuration_llava import LlavaConfig +# from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel +# from ..mm_utils import get_model_name_from_path, tokenizer_image_token +# from .base_projector import MultimodalProjector, MultimodalProjectorConfig +# from .configuration_llava import LlavaConfig +# from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +# from .constants import ( +# DEFAULT_IM_END_TOKEN, +# DEFAULT_IM_START_TOKEN, +# DEFAULT_IMAGE_PATCH_TOKEN, +# IGNORE_INDEX, +# IMAGE_TOKEN_INDEX, +# ) +# from .context_provider import ContextProvider, ContextProviderConfig +# from .language_model.builder import build_llm_and_tokenizer +# from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel +# from .llava_arch import LlavaMetaForCausalLM, LlavaMetaModel +# from .model import get_model_name_from_path, load_pretrained_model +# from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX +# from .model.conversation import SeparatorStyle, conv_templates +# from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token +# from .multimodal_encoder.builder import build_context_provider, build_vision_tower +# from .multimodal_projector.builder import build_mm_projector +# from .siglip import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel +# from .siglip_encoder import SiglipVisionTower +# from .utils import get_model_config +# from .vision_encoder import VisionTower +# +# # Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +MASK_TOKEN_INDEX = -300 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +IMAGE_PLACEHOLDER = "" + + +class LlavaConfig(PretrainedConfig): + model_type = "llava" + + def __init__( + self, + llm_cfg=None, + vision_tower_cfg=None, + mm_projector_cfg=None, + mask_encoder_cfg=None, + context_provider_cfg=None, + architectures=None, + resume_path=None, + hidden_size=None, + mm_hidden_size=None, + image_aspect_ratio=None, + num_video_frames=None, + mm_vision_select_layer=None, + mm_vision_select_feature=None, + mm_use_im_start_end=False, + mm_use_im_patch_token=True, + mm_projector_lr=None, + vision_resolution=None, + interpolate_mode=None, + s2=None, + s2_scales=None, + s2_max_split_size=None, + **kwargs, + ): + super().__init__() + self.architectures = architectures + self.llm_cfg = llm_cfg + self.vision_tower_cfg = vision_tower_cfg + self.mm_projector_cfg = mm_projector_cfg + self.mask_encoder_cfg = mask_encoder_cfg + self.context_provider_cfg = context_provider_cfg + self.resume_path = resume_path + + self.hidden_size = hidden_size + self.mm_hidden_size = mm_hidden_size + self.image_aspect_ratio = image_aspect_ratio + self.num_video_frames = num_video_frames + self.mm_vision_select_layer = mm_vision_select_layer + self.mm_vision_select_feature = mm_vision_select_feature + self.mm_use_im_start_end = mm_use_im_start_end + self.mm_use_im_start_end = mm_use_im_start_end + self.mm_use_im_patch_token = mm_use_im_patch_token + self.mm_projector_lr = mm_projector_lr + self.vision_resolution = vision_resolution + self.interpolate_mode = interpolate_mode + self.s2 = s2 + self.s2_scales = s2_scales + self.s2_max_split_size = s2_max_split_size + + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + MISTRAL = auto() + LLAMA_3 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + if "mmtag" in self.version: + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + else: + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message = message[0] + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif ( + self.sep_style == SeparatorStyle.LLAMA_2 + or self.sep_style == SeparatorStyle.MISTRAL + ): + if self.sep_style == SeparatorStyle.LLAMA_2: + + def wrap_sys(msg): + return f"<>\n{msg}\n<>\n\n" + + else: + + def wrap_sys(msg): + return f"{msg}" + ("\n" if msg else "") + + def wrap_inst(msg): + return f"[INST] {msg} [/INST]" + + ret = "" + if self.sep_style == SeparatorStyle.MISTRAL: + ret += "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + if self.sep_style == SeparatorStyle.LLAMA_2: + ret += " " + message + " " + self.sep2 + else: + ret += message + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + + from PIL import Image + + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new( + pil_img.mode, (width, width), background_color + ) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new( + pil_img.mode, (height, height), background_color + ) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError( + f"Invalid image_process_mode: {image_process_mode}" + ) + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = img_str + msg.replace("", "").strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [ + [x, y[0] if type(y) is tuple else y] for x, y in self.messages + ], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "What are the key differences between renewable and non-renewable energy sources?", + ), + ( + "Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +# kentang-mit@: This conversation template is designed for SFT on VFLAN. +conv_vicuna_v1_nosys = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="v1_nosys", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_mistral = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="mistral", + messages=(), + offset=0, + sep_style=SeparatorStyle.MISTRAL, + sep="", + sep2="", +) + +conv_llava_llama_2 = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +conv_llava_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v0_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + version="v0_mmtag", +) + +conv_llava_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + + +conv_llava_v1_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", + version="v1_mmtag", +) + +hermes_2 = Conversation( + system="<|im_start|>system\nAnswer the questions.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", + messages=(), + offset=0, + version="hermes-2", +) + + +# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template. +llama_3_chat = Conversation( + system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=( + "<|start_header_id|>user<|end_header_id|>\n\n", + "<|start_header_id|>system<|end_header_id|>\n\n", + ), + version="llama_v3", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="<|end_of_text|>", +) + + +default_conversation = conv_vicuna_v1 +conv_templates = { + "default": conv_vicuna_v0, + "hermes-2": hermes_2, + "llama_3": llama_3_chat, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "vicuna_v1_nosys": conv_vicuna_v1_nosys, + "llama_2": conv_llama_2, + "mistral": conv_mistral, + "plain": conv_llava_plain, + "v0_plain": conv_llava_plain, + "llava_v0": conv_llava_v0, + "v0_mmtag": conv_llava_v0_mmtag, + "llava_v1": conv_llava_v1, + "v1_mmtag": conv_llava_v1_mmtag, + "llava_llama_2": conv_llava_llama_2, + "mpt": conv_mpt, +} + +# if __name__ == "__main__": +# print(default_conversation.get_prompt()) +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + +def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None): + import cv2 + + if fps is None or frame_count is None: + # if one of fps or frame_count is None, still recompute + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + if fps == 0 or frame_count == 0: + print("Video file not found. return empty images.") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames + + frame_count / fps + frame_interval = frame_count // num_frames + if frame_interval == 0 and frame_count <= 1: + print("frame_interval is equal to 0. return empty image.") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames + # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) + + images = [] + count = 0 + success = True + frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int) + + while success: + # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval) + if frame_count >= num_frames: + success, frame = vidcap.read() + if count in frame_indices: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + if len(images) >= num_frames: + return images + count += 1 + else: + # Left padding frames if the video is not long enough + success, frame = vidcap.read() + if success: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + count += 1 + elif count >= 1: + width, height = images[-1].size + images = [Image.new("RGB", (width, height))] * ( + num_frames - len(images) + ) + images + print("padding frames:", (num_frames - len(images))) + return images + else: + break + raise ValueError("Did not find enough frames in the video. return empty image.") + + +def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None): + """ + Extract frames from a video using OpenCV. + + Args: + vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video. + frames (int): Number of frames to extract from the video. + + Returns: + list: List of PIL Images extracted from the video. + + Raises: + NotImplementedError: If the type of `vpath_or_bytesio` is not supported. + """ + import cv2 + + if isinstance(vpath_or_bytesio, str): + vidcap = cv2.VideoCapture(vpath_or_bytesio) + return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) + elif isinstance(vpath_or_bytesio, (BytesIO,)): + # assuming mp4 + with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: + temp_video.write(vpath_or_bytesio.read()) + temp_video_name = temp_video.name + vidcap = cv2.VideoCapture(temp_video_name) + return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) + else: + raise NotImplementedError(type(vpath_or_bytesio)) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + """ + Expand the given PIL image to a square shape by adding padding. + + Parameters: + - pil_img: The PIL image to be expanded. + - background_color: The color of the padding to be added. + + Returns: + - The expanded PIL image. + + If the image is already square, it is returned as is. + If the image is wider than it is tall, padding is added to the top and bottom. + If the image is taller than it is wide, padding is added to the left and right. + """ + width, height = pil_img.size + if pil_img.mode == "L": + background_color = background_color[0] + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_image(image_file, data_args, image_folder, pil_preprocess_fn=None): + processor = data_args.image_processor + if isinstance(image_file, str): + if image_folder is not None: + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + else: + # image is stored in bytearray + image = image_file.convert("RGB") + + info = None + + if pil_preprocess_fn is not None: + image = pil_preprocess_fn(image) + if isinstance(image, tuple): + image, info = image + + if data_args.image_aspect_ratio == "resize": + if hasattr(data_args.image_processor, "crop_size"): + # CLIP vision tower + crop_size = data_args.image_processor.crop_size + else: + # SIGLIP vision tower + assert hasattr(data_args.image_processor, "size") + crop_size = data_args.image_processor.size + image = image.resize((crop_size["height"], crop_size["width"])) + if data_args.image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + # Using default behavior of the vision encoder + # For CLIP, default is central crop + # For Radio, default is central crop + # For Siglip, default is resize + # For InternVIT, default is resize + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + if info is not None: + return image, info + return image + + +def process_images(images, image_processor, model_cfg): + + model_cfg.image_processor = image_processor + new_images = [process_image(image, model_cfg, None) for image in images] + + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +# Note that newer VILA codebase adds an lstrip option that defaults to False, and the functionality is the same by default +def tokenizer_image_token( + prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None +): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if ( + len(prompt_chunks) > 0 + and len(prompt_chunks[0]) > 0 + and prompt_chunks[0][0] == tokenizer.bos_token_id + ): + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == "pt": + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f"Unsupported tensor type: {return_tensors}") + return input_ids + + +def is_gemma_tokenizer(tokenizer): + return "gemma" in tokenizer.__class__.__name__.lower() + + +def get_model_name_from_path(model_path): + if not model_path: + return "describe_anything_model" + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if ( + len(cur_keyword_ids) > 1 + and cur_keyword_ids[0] == tokenizer.bos_token_id + ): + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [ + keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids + ] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): + return True + outputs = self.tokenizer.batch_decode( + output_ids[:, -offset:], skip_special_tokens=True + )[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) + + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +def get_model_config(config): + # `mask_encoder_cfg` and `context_provider_cfg` are optional + default_keys = [ + "llm_cfg", + "vision_tower_cfg", + "mm_projector_cfg", + "mask_encoder_cfg", + "context_provider_cfg", + ] + + if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: + root_path = config._name_or_path + else: + root_path = config.resume_path + + # download from huggingface + if root_path is not None and not osp.exists(root_path): + try: + valid_hf_repo = repo_exists(root_path) + except HFValidationError: + valid_hf_repo = False + if valid_hf_repo: + root_path = snapshot_download(root_path) + + return_list = [] + for key in default_keys: + cfg = getattr(config, key, None) + if isinstance(cfg, dict): + try: + return_list.append(os.path.join(root_path, key[:-4])) + except: + raise ValueError(f"Cannot find resume path in config for {key}!") + elif isinstance(cfg, PretrainedConfig): + return_list.append(os.path.join(root_path, key[:-4])) + elif isinstance(cfg, str): + return_list.append(cfg) + elif cfg is None: + # We still return even if the cfg is None or does not exist + return_list.append(cfg) + + return return_list + + +def is_mm_model(model_path): + """ + Check if the model at the given path is a visual language model. + + Args: + model_path (str): The path to the model. + + Returns: + bool: True if the model is an MM model, False otherwise. + """ + config = AutoConfig.from_pretrained(model_path) + architectures = config.architectures + for architecture in architectures: + if "llava" in architecture.lower(): + return True + return False + + +def auto_upgrade(config): + cfg = AutoConfig.from_pretrained(config) + if "llava" in config and "llava" not in cfg.model_type: + assert cfg.model_type == "llama" + print( + "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." + ) + print( + "You must upgrade the checkpoint to the new code base (this can be done automatically)." + ) + confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") + if confirm.lower() in ["y", "yes"]: + print("Upgrading checkpoint...") + assert len(cfg.architectures) == 1 + setattr(cfg.__class__, "model_type", "llava") + cfg.architectures[0] = "LlavaLlamaForCausalLM" + cfg.save_pretrained(config) + print("Checkpoint upgraded.") + else: + print("Checkpoint upgrade aborted.") + exit(1) + + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# TODO decide whether should we use metaclass + + +class LlavaMetaModel(ABC): + def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs): + # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation. + if ( + hasattr(self, "llm") + or hasattr(self, "vision_tower") + or hasattr(self, "mm_projector") + ): + # already initialized, skipped + return + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn( + "model_dtype not found in config, defaulting to torch.float16." + ) + config.model_dtype = model_dtype + + # print("init_vlm(): config", config); input("DEBUG init_vlm") + cfgs = get_model_config(config) + # Only the first three are required. Others are optional. + ( + llm_cfg, + vision_tower_cfg, + mm_projector_cfg, + mask_encoder_cfg, + context_provider_cfg, + ) = cfgs + if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None: + raise ValueError( + "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config." + ) + # print("init_vlm():", cfgs); input("DEBUG init_vlm") + # print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG init_vlm") + self.llm, self.tokenizer = build_llm_and_tokenizer( + llm_cfg, config, *args, **kwargs + ) + self.vision_tower = build_vision_tower(vision_tower_cfg, config) + self.mm_projector = build_mm_projector(mm_projector_cfg, config) + self.context_provider = ( + build_context_provider(context_provider_cfg, config) + if context_provider_cfg is not None + else None + ) + + self.post_config() + self.is_loaded = True + + assert ( + self.llm is not None + or self.vision_tower is not None + or self.mm_projector is not None + ), "At least one of the components must be instantiated." + + @classmethod + def load_from_config(cls, model_path_or_config, *args, **kwargs): + pass + + # FIXME we will use this function to load model in the future + @classmethod + def load_pretrained(cls, model_path_or_config, *args, **kwargs): + kwargs.pop("config", None) + + if isinstance(model_path_or_config, str): + config = AutoConfig.from_pretrained(model_path_or_config) + elif isinstance(model_path_or_config, LlavaConfig): + config = model_path_or_config + else: + raise NotImplementedError( + f"wrong type, {type(model_path_or_config)} \ + {isinstance(model_path_or_config, LlavaConfig)}" + ) + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn( + "model_dtype not found in config, defaulting to torch.float16." + ) + config.model_dtype = model_dtype + + cfgs = get_model_config(config) + # Only the first three are required. Others are optional. + ( + llm_cfg, + vision_tower_cfg, + mm_projector_cfg, + mask_encoder_cfg, + context_provider_cfg, + ) = cfgs + if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None: + raise ValueError( + "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config." + ) + + # print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained") + with ContextManagers( + [ + no_init_weights(_enable=True), + ] + ): + vlm = cls(config, *args, **kwargs) + # print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained finish") + + if ( + hasattr(vlm, "llm") + or hasattr(vlm, "vision_tower") + or hasattr(vlm, "mm_projector") + ): + if vlm.is_loaded: + return vlm + + vlm.llm, vlm.tokenizer = build_llm_and_tokenizer( + llm_cfg, config, *args, **kwargs + ) + vlm.vision_tower = build_vision_tower(vision_tower_cfg, config) + vlm.mm_projector = build_mm_projector(mm_projector_cfg, config) + if mask_encoder_cfg is not None: + raise NotImplementedError("Mask encoder is not supported.") + vlm.context_provider = ( + build_context_provider(context_provider_cfg, config) + if context_provider_cfg is not None + else None + ) + + self.post_config() + self.is_loaded = True + + # FIXME(ligeng, yunhao): llm should never be none here. + assert ( + vlm.llm is not None + or vlm.vision_tower is not None + or vlm.mm_projector is not None + ), "At least one of the components must be instantiated." + return vlm + + # FIXME we will use this function to save the model in the future + def save_pretrained(self, output_dir, state_dict=None): + if state_dict is None: + # other wise fetch from deepspeed + # state_dict = accelerator.get_state_dict(is_deepspeed_enabled) + state_dict = self.state_dict() + + if getattr(self, "tokenizer", None): + self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) + + if self.get_llm(): + print(f"saving llm to {osp.join(output_dir, 'llm')}") + self.llm.config._name_or_path = osp.join(output_dir, "llm") + llm_state_dict = OrderedDict( + {k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k} + ) + self.llm.save_pretrained( + os.path.join(output_dir, "llm"), state_dict=llm_state_dict + ) + self.config.llm_cfg = self.llm.config + + if ( + self.get_vision_tower() + and "radio" not in self.get_vision_tower().__class__.__name__.lower() + ): + print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}") + self.vision_tower.config._name_or_path = osp.join( + output_dir, "vision_tower" + ) + vision_tower_state_dict = OrderedDict( + { + k.split("vision_tower.vision_tower.")[-1]: v + for k, v in state_dict.items() + if "vision_tower" in k + } + ) + self.vision_tower.vision_tower.save_pretrained( + os.path.join(output_dir, "vision_tower"), + state_dict=vision_tower_state_dict, + ) + self.vision_tower.image_processor.save_pretrained( + os.path.join(output_dir, "vision_tower") + ) + self.config.vision_tower_cfg = self.vision_tower.config + if hasattr(self.config.vision_tower_cfg, "auto_map"): + delattr(self.config.vision_tower_cfg, "auto_map") + + if self.get_mm_projector(): + print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}") + self.mm_projector.config._name_or_path = osp.join( + output_dir, "mm_projector" + ) + mm_projector_state_dict = OrderedDict( + { + k.split("mm_projector.")[-1]: v + for k, v in state_dict.items() + if "mm_projector" in k + } + ) + self.mm_projector.save_pretrained( + os.path.join(output_dir, "mm_projector"), + state_dict=mm_projector_state_dict, + ) + self.config.mm_projector_cfg = self.mm_projector.config + + if self.get_context_provider(): + print( + f"saving context_provider to {osp.join(output_dir, 'context_provider')}" + ) + self.context_provider.config._name_or_path = osp.join( + output_dir, "context_provider" + ) + context_provider_state_dict = OrderedDict( + { + k.split("context_provider.")[-1]: v + for k, v in state_dict.items() + if "context_provider" in k + } + ) + self.context_provider.save_pretrained( + os.path.join(output_dir, "context_provider"), + state_dict=context_provider_state_dict, + ) + self.config.context_provider_cfg = self.context_provider.config + + # update and save top-level config + self.config._name_or_path = output_dir + self.config.architectures = [self.__class__.__name__] + self.config.save_pretrained(output_dir) + + def get_llm(self): + llm = getattr(self, "llm", None) + if type(llm) is list: + llm = llm[0] + return llm + + def get_lm_head(self): + lm_head = getattr(self.get_llm(), "lm_head", None) + return lm_head + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def get_mm_projector(self): + mm_projector = getattr(self, "mm_projector", None) + if type(mm_projector) is list: + mm_projector = mm_projector[0] + return mm_projector + + def get_context_provider(self): + context_provider = getattr(self, "context_provider", None) + return context_provider + + def post_config(self): + self.training = self.get_llm().training + # configuration + if getattr(self.config, "llm_cfg", None) is None: + self.config.llm_cfg = self.llm.config + if getattr(self.config, "vision_tower_cfg", None) is None: + self.config.vision_tower_cfg = self.vision_tower.config + if getattr(self.config, "mm_projector_cfg", None) is None: + self.config.mm_projector_cfg = self.mm_projector.config + if ( + getattr(self.config, "context_provider_cfg", None) is None + and self.context_provider is not None + ): + self.config.context_provider_cfg = self.context_provider.config + + def freezed_module_patch(self): + """ + Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules. + """ + if self.training: + if self.get_llm() and not getattr( + self.config, "tune_language_model", False + ): + logging.warning( + "Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations." + ) + if self.get_vision_tower() and not getattr( + self.config, "tune_vision_tower", False + ): + self.get_vision_tower().eval() + if self.get_mm_projector() and not getattr( + self.config, "tune_mm_projector", False + ): + self.get_mm_projector().eval() + if self.get_context_provider() and not getattr( + self.config, "tune_context_provider", False + ): + self.get_context_provider().eval() + + def encode_images(self, images): + image_features = self.get_vision_tower()(images) + image_features = self.get_mm_projector()(image_features) + return image_features + + def encode_images_with_context(self, images): + context_provider = self.get_context_provider() + # If the channels completely match, they are cimage (image with context). + cimage_mask = torch.any( + (images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1 + ) + + if context_provider.treat_image_as_cimage: + # If the context provider treats the image as cimage, then all images are cimage. + cimage_mask[:] = True + + if context_provider.context_image_as_queries: + # Swap the crop image and full image since the model uses the full image as queries by default + images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1) + + # Process the first 4 channels for all images: for image it's the image, for cimage it's the full image + vision_tower = self.get_vision_tower() + # Encode context images (full images) + image_features = vision_tower(images[:, :4, ...]).to(self.device) + # Each cimage has 8 channels (full and crop concatenated) + cimage_concatenated = images[cimage_mask] + cimage_full_features = image_features[cimage_mask] + if context_provider.context_provider_type == "cross_attn_end_to_all": + cimage_features = self.context_provider( + cimage_full_features=cimage_full_features, + cimage_concatenated=cimage_concatenated, + vision_tower=vision_tower, + ).to(self.device) + elif context_provider.context_provider_type == "concat": + # Full features of cimages are computed but not used. + cimage_features = self.context_provider( + cimage_concatenated=cimage_concatenated, vision_tower=vision_tower + ).to(self.device) + else: + raise NotImplementedError( + f"Context provider type {context_provider.context_provider_type} not implemented." + ) + # Put cimage_features into image_features + image_features[cimage_mask] = cimage_features + + # Project to the llm space + image_features = self.get_mm_projector()(image_features) + + return image_features + + # @yunhao: is there a better way to handle function call and attributes for llm? + # support beam search + def _temporary_reorder_cache(self, past_key_values, sorted_idx): + return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx) + + def get_input_embeddings(self): + return self.get_llm().get_input_embeddings() + + def get_output_embeddings(self): + return self.get_llm().get_output_embeddings() + + def resize_token_embeddings(self, embed_size): + self.get_llm().resize_token_embeddings(embed_size) + + +class LlavaMetaForCausalLM(ABC): + """This class is originally implemented by the LLaVA team and + modified by Haotian Tang and Jason Lu based on Ji Lin's implementation + to support multiple images and input packing.""" + + # TODO move the forward function here if there is no need to override it + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, images + ): + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + if ( + past_key_values is not None + and vision_tower is not None + and images is not None + and input_ids.shape[1] == 1 + ): + target_shape = past_key_values[-1][-1].shape[-2] + 1 + attention_mask = torch.cat( + ( + attention_mask, + torch.ones( + ( + attention_mask.shape[0], + target_shape - attention_mask.shape[1], + ), + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + ), + dim=1, + ) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + return ( + input_ids, + position_ids, + attention_mask, + past_key_values, + None, + labels, + ) + # handle different image dtypes for packing + if type(images) is list: + images = torch.cat(images, dim=0) + elif images.ndim == 5: # batch_size x seq_len x image_channels + images = images.flatten(0, 1) + if getattr(self, "context_provider", None): + image_features = self.encode_images_with_context(images) + else: + # Since we slice it with index below, turning it into a list splits things by the first index which does not result in data copy or degrade performance. + # Example dimension: [1, 196, 2560] + assert ( + images.shape[1] <= 4 + ), "images have more than 4 channels, but context provider is not included" + image_features = self.encode_images(images).to(self.device) + # Note (kentang-mit@): image start / end is not implemented here to support pretraining. + if getattr(self.config, "turn_mm_projector", False) and getattr( + self.config, "mm_use_im_start_end", False + ): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange( + 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device + ) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask + input_ids_copy = input_ids.clone() + # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used. + input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0 + input_embeds = self.llm.model.embed_tokens(input_ids_copy) + + input_ids = [ + cur_input_ids[cur_attention_mask] + for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + input_embeds_1 = [ + cur_input_embeds[cur_attention_mask] + for cur_input_embeds, cur_attention_mask in zip( + input_embeds, attention_mask + ) + ] + labels = [ + cur_labels[cur_attention_mask] + for cur_labels, cur_attention_mask in zip(labels, attention_mask) + ] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + + # print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == IMAGE_TOKEN_INDEX).sum() for x in input_ids]) + + # kentang-mit@: If some part of the model is executed in the loop, the the loop length needs to be a constant. + for batch_idx, cur_input_ids in enumerate(input_ids): + cur_input_ids = input_ids[batch_idx] + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[0] + # cur_input_embeds_1 = self.get_llm().embed_tokens(cur_input_ids) + cur_input_embeds_1 = input_embeds_1[batch_idx] + cur_input_embeds = torch.cat( + [cur_input_embeds_1, cur_image_features[0:0]], dim=0 + ) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + # kenang-mit@: we do not have placeholdr image for text-only data now. + # cur_image_idx += 1 + continue + + cur_input_embeds = input_embeds_1[batch_idx] + image_token_indices = ( + [-1] + + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + + [cur_input_ids.shape[0]] + ) + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + cur_input_embeds_no_im = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append( + cur_input_ids[ + image_token_indices[i] + 1 : image_token_indices[i + 1] + ] + ) + cur_labels_noim.append( + cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]] + ) + cur_input_embeds_no_im.append( + cur_input_embeds[ + image_token_indices[i] + 1 : image_token_indices[i + 1] + ] + ) + [x.shape[0] for x in cur_labels_noim] + # cur_input_embeds = self.get_llm().embed_tokens(torch.cat(cur_input_ids_noim)) + # cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr( + self.llm.config, "tokenizer_model_max_length", None + ) + if tokenizer_model_max_length is not None: + if any(len(x) > tokenizer_model_max_length for x in new_input_embeds): + warnings.warn("Inputs truncated!") + new_input_embeds = [ + x[:tokenizer_model_max_length] for x in new_input_embeds + ] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), + IGNORE_INDEX, + dtype=new_labels[0].dtype, + device=new_labels[0].device, + ) + attention_mask = torch.zeros( + (batch_size, max_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + position_ids = torch.zeros( + (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device + ) + + for i, (cur_new_embed, cur_new_labels) in enumerate( + zip(new_input_embeds, new_labels) + ): + cur_len = cur_new_embed.shape[0] + if getattr(self.llm.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return ( + None, + position_ids, + attention_mask, + past_key_values, + new_input_embeds, + new_labels, + ) + + def repack_multimodal_data( + self, + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ): + # kentang-mit@: reorder and repack (reduce computation overhead) + # requires transformers replacement. + new_inputs_embeds = [] + new_position_ids = [] + new_labels = [] + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + sorted_seqlens_in_batch, sorted_idx = torch.sort( + seqlens_in_batch, descending=True + ) + # print(sorted_seqlens_in_batch) + max_seqlen = inputs_embeds.shape[1] + + cur_inputs_embeds = [] + cur_position_ids = [] + cur_labels = [] + cur_batch_len = 0 + # print(sorted_seqlens_in_batch.device, len(sorted_seqlens_in_batch), max_seqlen) + for i in range(len(sorted_seqlens_in_batch)): + cur_seqlen = sorted_seqlens_in_batch[i].item() + if cur_seqlen + cur_batch_len <= max_seqlen: + cur_batch_len += cur_seqlen + # each item: num_tokens x num_channels + # remove padding on-the-fly + cur_inputs_embeds.append( + inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]] + ) + # each item: num_tokens + cur_position_ids.append( + torch.arange( + cur_inputs_embeds[-1].shape[0], + device=cur_inputs_embeds[-1].device, + ) + ) + # each item: num_tokens + # remove padding on-the-fly + cur_labels.append(labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]) + else: + new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0)) + new_position_ids.append(torch.cat(cur_position_ids, 0)) + new_labels.append(torch.cat(cur_labels, 0)) + # The current batch is too long. We will start a new batch. + cur_batch_len = cur_seqlen + cur_inputs_embeds = [ + inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]] + ] + cur_position_ids = [ + torch.arange( + cur_inputs_embeds[-1].shape[0], + device=cur_inputs_embeds[-1].device, + ) + ] + cur_labels = [labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]] + + if len(cur_inputs_embeds): + new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0)) + new_position_ids.append(torch.cat(cur_position_ids, 0)) + new_labels.append(torch.cat(cur_labels, 0)) + + # print(new_position_ids[0].device, [x.shape for x in new_inputs_embeds], [x.shape for x in new_labels], [x.shape for x in new_position_ids]) + # assert 0 + new_inputs_embeds = torch.nn.utils.rnn.pad_sequence( + new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id + ) + + new_position_ids = torch.nn.utils.rnn.pad_sequence( + new_position_ids, batch_first=True, padding_value=-1 + ) + + new_labels = torch.nn.utils.rnn.pad_sequence( + new_labels, batch_first=True, padding_value=IGNORE_INDEX + ) + # yunhao: it's currently a workaround to avoid errors for seq_len < 100 + new_attention_mask = new_position_ids.ne(-1) + # sanity check + assert new_attention_mask.sum() == attention_mask.sum() + # print(new_inputs_embeds.shape, (new_attention_mask.sum(1))) + # print(sorted_seqlens_in_batch.device, sorted_seqlens_in_batch, new_attention_mask.sum(1)) + + # return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + return ( + None, + new_position_ids, + new_attention_mask, + past_key_values, + new_inputs_embeds, + new_labels, + sorted_seqlens_in_batch, + ) + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + # TODO yunhao: handle cases for + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load( + model_args.pretrain_mm_mlp_adapter, map_location="cpu" + ) + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[ + -num_new_tokens: + ] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." + ) + elif model_args.mm_use_im_patch_token: + if model_args.mm_projector: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +def build_mm_projector( + model_type_or_path: str, config: PretrainedConfig +) -> PreTrainedModel: + if model_type_or_path is None: + return None + + # load from pretrained model + if config.resume_path: + assert os.path.exists( + model_type_or_path + ), f"Resume mm projector path {model_type_or_path} does not exist!" + return MultimodalProjector.from_pretrained( + model_type_or_path, config, torch_dtype=eval(config.model_dtype) + ) + # build from scratch + else: + mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path) + mm_projector = MultimodalProjector(mm_projector_cfg, config).to( + eval(config.model_dtype) + ) + return mm_projector + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": "identity"} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels) + ) + + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +class DownSampleBlock(nn.Module): + + def forward(self, x): + vit_embeds = x + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x): + n, w, h, c = x.size() + if w % 2 == 1: + x = torch.concat( + [x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1 + ).contiguous() + n, w, h, c = x.size() + if h % 2 == 1: + x = torch.concat( + [x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2 + ).contiguous() + n, w, h, c = x.size() + x = x.view(n, w, int(h / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h / 2), int(w / 2), int(c * 4)) + return x + + +class MultimodalProjectorConfig(PretrainedConfig): + model_type = "v2l_projector" + + def __init__(self, mm_projector_type: str = None, **kwargs): + super().__init__() + self.mm_projector_type = mm_projector_type + + +class MultimodalProjector(PreTrainedModel): + config_class = MultimodalProjectorConfig + + def __init__( + self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig + ): + super().__init__(mm_projector_cfg) + mm_projector_type = mm_projector_cfg.mm_projector_type + if mm_projector_type == "identity": + self.layers = IdentityMap() + elif mm_projector_type == "linear": + self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size) + elif mm_projector_type == "mlp_downsample": + self.layers = nn.Sequential( + DownSampleBlock(), + nn.LayerNorm(config.mm_hidden_size * 4), + nn.Linear(config.mm_hidden_size * 4, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + else: + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + self.layers = nn.Sequential(*modules) + else: + raise ValueError(f"Unknown projector type: {mm_projector_type}") + + def forward(self, x, *args, **kwargs): + return self.layers(x) + + +AutoConfig.register("v2l_projector", MultimodalProjectorConfig) +# This file is modified from https://github.com/haotian-liu/LLaVA/ +AutoModel.register(MultimodalProjectorConfig, MultimodalProjector) + + +def build_vision_tower( + model_name_or_path: str, config: PretrainedConfig +) -> PreTrainedModel: + # skip vision tower instantiation + if model_name_or_path is None: + return None + + vision_tower_arch = None + if config.resume_path and "radio" not in model_name_or_path: + assert os.path.exists( + model_name_or_path + ), f"Resume vision tower path {model_name_or_path} does not exist!" + vision_tower_cfg = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + vision_tower_arch = vision_tower_cfg.architectures[0].lower() + vision_tower_name = ( + vision_tower_arch if vision_tower_arch is not None else model_name_or_path + ) + + if "siglip" in vision_tower_name: + vision_tower = SiglipVisionTower(model_name_or_path, config) + else: + raise ValueError(f"Unknown vision tower: {model_name_or_path}") + + config.mm_hidden_size = vision_tower.config.hidden_size + return vision_tower + + +def build_context_provider( + model_type_or_path: str, config: PretrainedConfig +) -> PreTrainedModel: + if model_type_or_path is None: + return None + + # load from pretrained model + if config.resume_path: + assert os.path.exists( + model_type_or_path + ), f"Resume context provider path {model_type_or_path} does not exist!" + return ContextProvider.from_pretrained( + model_type_or_path, config, torch_dtype=eval(config.model_dtype) + ) + # build from scratch + else: + mm_projector_cfg = ContextProviderConfig(model_type_or_path) + mm_projector = ContextProvider(mm_projector_cfg, config).to( + eval(config.model_dtype) + ) + return mm_projector + + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + +# import deepspeed +# from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + +def is_deepspeed_zero3_enabled(*args, **kwargs): + return False + + +class ContextProviderConfig(PretrainedConfig): + model_type = "context_provider" + + def __init__( + self, + context_provider_type: str = None, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_mask_channels=0, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + zero_init_output=True, + residual_dropout=0.0, + context_image_as_queries=False, + context_provider_layer_indices=None, + masked_cross_attn=False, + crop_position_single_embedding=False, + trainable_crop_position_embedding=True, + crop_embedding_mode="add", + treat_image_as_cimage=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.context_provider_type = context_provider_type + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.num_mask_channels = num_mask_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.zero_init_output = zero_init_output + self.residual_dropout = residual_dropout + self.context_image_as_queries = context_image_as_queries + + # cross_attn_end_to_all + # the `num_hidden_layers` should be the same as the one in the vision tower + self.num_hidden_layers = num_hidden_layers + self.context_provider_layer_indices = context_provider_layer_indices + + self.masked_cross_attn = masked_cross_attn + # If enabled, crop_position_embedding (delta to full pos) will be updated during training. + self.trainable_crop_position_embedding = trainable_crop_position_embedding + # If enabled, crop_position_embedding (delta to full pos) will be a single embedding for all positions. + self.crop_position_single_embedding = crop_position_single_embedding + # add: delta. replace: do not add the original positional embedding + self.crop_embedding_mode = crop_embedding_mode + + # If True, the input image will be treated as a cimage (with mask as full 1s) + self.treat_image_as_cimage = treat_image_as_cimage + + +# Context Provider + + +class ContextProviderCrossAttention(nn.Module): + """Multi-headed cross-attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + batch_size, kv_len, _ = encoder_hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, kv_len, self.num_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, kv_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # Visualizations (-inf are shown as white) + # import matplotlib.pyplot as plt + # plt.imshow(attention_mask[0, 0, 0].view(27, 27).detach().cpu().numpy()) + # plt.title("Attention mask") + # plt.colorbar() + # plt.show() + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + + # Visualizations: show the attention weights of the first head, with the first query + # import matplotlib.pyplot as plt + # plt.imshow(attn_weights[0, 0, 0].view(27, 27).detach().cpu().numpy()) + # plt.title("Attention weights") + # plt.colorbar() + # plt.show() + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class ContextProviderMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +def get_token_mask_bias(mask, patch_size): + # Note: mask should be (0, 1) + with torch.no_grad(): + # Add a channel dimension and perform conv + # mask_tokens_after_conv: (B, 1, H, W), example dimension: [1, 1, 27, 27] + mask_tokens_after_conv = F.conv2d( + input=mask[:, None], + weight=torch.ones( + (1, 1, patch_size, patch_size), device=mask.device, dtype=mask.dtype + ), + bias=None, + stride=(patch_size, patch_size), + padding="valid", + ) + + token_mask_bias = torch.zeros_like(mask_tokens_after_conv) + token_mask_bias.masked_fill_(mask_tokens_after_conv < 1e-5, float("-inf")) + token_mask_bias = token_mask_bias.flatten(1) + + # Flattened dimension: (1, 729) + return token_mask_bias + + +def attn_mask_from_cimage_concatenated(cimage_concatenated, patch_size): + # Use the mask from input image (4th channel) + mask_normalized = cimage_concatenated[:, 3] + mask_unnormalized = (mask_normalized + 1) / 2 + # (1, 729) + token_mask_bias = get_token_mask_bias(mask_unnormalized, patch_size=patch_size) + + # attn_mask: (B, 1, Q, KV) + # print("Token positions:", token_mask.nonzero()) + + # Obtain token mask in the bias format: in mask 0, out of mask -inf + q_kv = token_mask_bias.shape[-1] + attn_mask_bias = token_mask_bias[:, None, None, :].repeat(1, 1, q_kv, 1) + + # Visualizations + # print(f"token_mask_bias shape: {token_mask_bias.shape}, attn_mask_bias shape: {attn_mask_bias.shape}") + # import matplotlib.pyplot as plt + # plt.imshow(attn_mask_bias[0, 0, 0].view(27, 27).detach().cpu().numpy()) + # plt.title("Attention mask (outside)") + # plt.show() + + return attn_mask_bias + + +# From SiglipEncoderLayer. We would like to modify this to cross-attention. + + +class CrossAttnEncoderLayer(nn.Module): + def __init__(self, config: ContextProviderConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.cross_attn = ContextProviderCrossAttention(config) + self.residual_dropout = nn.Dropout(config.residual_dropout) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = ContextProviderMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + if config.zero_init_output: + # TODO: alternatively, we could parameterize with an MLP + # These factors are initialized with 0 (so only residual passes through) + if config.context_provider_type != "cross_attn_at_the_end": + self.register_parameter("attn_factor", nn.Parameter(torch.zeros((1,)))) + self.register_parameter("mlp_factor", nn.Parameter(torch.zeros((1,)))) + else: + # Use scalar tensor for compatibility + self.register_parameter( + "attn_factor", nn.Parameter(torch.zeros((1,)).view(())) + ) + self.register_parameter( + "mlp_factor", nn.Parameter(torch.zeros((1,)).view(())) + ) + else: + self.attn_factor = 1.0 + self.mlp_factor = 1.0 + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + # Dropping the residual: let the model leverage more on the context + hidden_states = ( + self.residual_dropout(residual) + self.attn_factor * hidden_states + ) + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.mlp_factor * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class CrossAttnContextProviderEndToAll(nn.Module): + def __init__(self, config: ContextProviderConfig): + super().__init__() + self.layers = nn.ModuleList( + [ + CrossAttnEncoderLayer(config) + for i in enumerate(range(config.num_hidden_layers)) + if config.context_provider_layer_indices is None + or i in config.context_provider_layer_indices + ] + ) + self.patch_size = config.patch_size + self.masked_cross_attn = config.masked_cross_attn + + def forward(self, context_image_features, cimage_concatenated, vision_tower): + # Use the mask from input image (4th channel) + if self.masked_cross_attn: + attn_mask = attn_mask_from_cimage_concatenated( + cimage_concatenated, patch_size=self.patch_size + ) + else: + attn_mask = None + + detail_raw_image = cimage_concatenated[:, 4:, ...] + # NOTE: when using context image as queries, the context image was swapped with the detail image before passing into the context provider + outputs = vision_tower( + detail_raw_image, + context_provider_layers=self.layers, + contexts=context_image_features, + cross_attention_mask=attn_mask, + ) + + return outputs + + +class ContextProvider(PreTrainedModel): + config_class = ContextProviderConfig + + def __init__( + self, context_provider_cfg: ContextProviderConfig, config: PretrainedConfig + ): + super().__init__(context_provider_cfg) + + self.context_image_as_queries = context_provider_cfg.context_image_as_queries + self.context_provider_type = context_provider_type = ( + context_provider_cfg.context_provider_type + ) + + self.treat_image_as_cimage = context_provider_cfg.treat_image_as_cimage + + if self.context_image_as_queries: + assert ( + not context_provider_cfg.masked_cross_attn + ), "Masked cross-attention not implemented when using context image as queries." + assert ( + "concat" not in context_provider_type + ), "Concat not implemented when using context image as queries." + + if context_provider_type == "cross_attn_end_to_all": + # Information flow: end of context features -> all detail features + self.context_provider_module = CrossAttnContextProviderEndToAll( + context_provider_cfg + ) + else: + raise ValueError(f"Unknown context provider type: {context_provider_type}") + + def forward( + self, + cimage_full_features=None, + cimage_crop_features=None, + cimage_concatenated=None, + vision_tower=None, + ): + if self.context_provider_type == "cross_attn_end_to_all": + assert ( + cimage_full_features.shape[0] == cimage_concatenated.shape[0] + ), f"shape mismatches: {cimage_full_features.shape[0]} != {cimage_concatenated.shape[0]}" + return self.context_provider_module( + context_image_features=cimage_full_features, + cimage_concatenated=cimage_concatenated, + vision_tower=vision_tower, + ) + else: + raise ValueError(f"Unknown context provider type: {context_provider_type}") + + +AutoConfig.register("context_provider", ContextProviderConfig) +AutoModel.register(ContextProviderConfig, ContextProvider) +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Image processor class for RADIO.""" + + +if is_torch_available(): + import torch + import torch.nn.functional as F + +if is_torchvision_available(): + pass + +if is_tf_available(): + + pass + +logger = logging.get_logger(__name__) + + +def rank_print(s): + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + print(f"[Rank {rank}] {s}") + + +class ImageProcessor(BaseImageProcessor): + r""" + Constructs an image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`): + Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match + `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image + to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the + `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the + `preprocess` method. + pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`): + Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess` + method. + pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`): + Value of padded pixels. + pad_multiple (`int`, *optional*, defaults to `None`): + Pad to a multiple of specified number. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = True, + pad_size: int = None, + pad_multiple: int = None, + pad_value: Optional[Union[float, List[float]]] = 0.0, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 1024} + size = ( + get_size_dict(max_size=size, default_to_square=False) + if not isinstance(size, dict) + else size + ) + + if pad_size is not None and pad_multiple is not None: + raise ValueError( + "pad_size and pad_multiple should not be set at the same time." + ) + + pad_size = ( + pad_size + if pad_size is not None + else {"height": 1024, "width": 1024} if pad_multiple is not None else None + ) + if do_pad: + pad_size = get_size_dict(pad_size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = ( + image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + ) + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_multiple = pad_multiple + self.pad_size = pad_size + self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "segmentation_maps", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_pad", + "pad_size", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def pad_image( + self, + image: np.ndarray, + pad_size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom. + + Args: + image (`np.ndarray`): + Image to pad. + pad_size (`Dict[str, int]`): + Size of the output image after padding. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the + `data_format` of the `image` will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = pad_size["height"], pad_size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + pad_width = output_width - input_width + pad_height = output_height - input_height + + padded_image = pad( + image, + ((0, pad_height), (0, pad_width)), + data_format=data_format, + input_data_format=input_data_format, + constant_values=self.pad_value, + **kwargs, + ) + return padded_image + + def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int): + """ + Compute the output size given input size and target long side length. + """ + oldh, oldw = old_shape + scale = longest_edge * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + newh = int(newh + 0.5) + neww = int(neww + 0.5) + return (newh, neww) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size + of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match + `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image + to that size, possibly changing the aspect ratio. + resample: + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "longest_edge" not in size: + if "width" not in size or "height" not in size: + raise ValueError( + f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}" + ) + input_size = get_image_size(image, channel_dim=input_data_format) + if "longest_edge" in size: + output_height, output_width = self._get_preprocess_shape( + input_size, size["longest_edge"] + ) + else: + output_height, output_width = size["height"], size["width"] + return resize( + image, + size=(output_height, output_width), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool, + do_rescale: bool, + do_normalize: bool, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_resize: + image = self.resize( + image=image, + size=size, + resample=resample, + input_data_format=input_data_format, + ) + reshaped_input_size = get_image_size(image, channel_dim=input_data_format) + + if do_rescale: + image = self.rescale( + image=image, scale=rescale_factor, input_data_format=input_data_format + ) + + if do_normalize: + image = self.normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + + if do_pad: + if self.pad_multiple: + h, w = get_image_size(image, channel_dim=input_data_format) + pad_size = { + "height": math.ceil(h / self.pad_multiple) * self.pad_multiple, + "width": math.ceil(w / self.pad_multiple) * self.pad_multiple, + } + + image = self.pad_image( + image=image, pad_size=pad_size, input_data_format=input_data_format + ) + + return image, reshaped_input_size + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: + # image = to_numpy_array(image) + + # import time + # if int(time.time()*1000) % 10 == 0: + # # create an PIL image of size 1x1 + # image = PIL.Image.new('RGB', (1, 1)) + + if isinstance(image, Image.Image): + # PIL always uses Channels Last. + input_data_format = ChannelDimension.LAST + + # PIL RGBA images are converted to RGB + # mode_before = image.mode + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + # if isinstance(image_, np.ndarray): + # rank_print(f"preprocess image type={type(image_)} shape={image_.shape} array shape={image.shape}") + # elif isinstance(image_, Image.Image): + # rank_print(f"preprocessimage type={type(image_)} size={image_.size} mode={image_.mode} array shape={image.shape}") + # else: + # rank_print(f"preprocess unknown image type={type(image_)} array shape={image.shape}") + + if len(image.shape) == 2: + h, w = image.shape + ret = np.empty((h, w, 3), dtype=np.uint8) + ret[:, :, 0] = image + ret[:, :, 1] = image + ret[:, :, 2] = image + image = ret + rank_print(f"preprocess new image shape={image.shape}") + elif len(image.shape) == 3 and image.shape[-1] == 1: + ret = np.empty((h, w, 3), dtype=np.uint8) + ret[:, :, 0] = image[:, :, 0] + ret[:, :, 1] = image[:, :, 0] + ret[:, :, 2] = image[:, :, 0] + image = ret + rank_print(f"preprocess new image shape={image.shape}") + + if is_scaled_image(image) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + original_size = get_image_size(image, channel_dim=input_data_format) + + image, reshaped_input_size = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + input_data_format=input_data_format, + ) + + if data_format is not None: + image = to_channel_dimension_format( + image, data_format, input_channel_dim=input_data_format + ) + + # rank_print(f"preprocess original_size={original_size} reshaped_input_size={reshaped_input_size} image shape={image.shape} type={type(image)}") + + # if image is a single channel convert to rgb + if do_convert_rgb and image.shape[0] == 1: + c, h, w = image.shape + ret = np.empty((3, h, w), dtype=np.uint8) + ret[0, :, :] = image[0, :, :] + ret[1, :, :] = image[0, :, :] + ret[2, :, :] = image[0, :, :] + image = ret + rank_print(f"preprocess final: {image.shape}") + + return image, original_size, reshaped_input_size + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The longest edge of the image is resized to + `size["longest_edge"]` whilst preserving the aspect ratio. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values by rescaling factor. + rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to apply to the image pixel values. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. + pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`): + Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and + `pad_size["width"]` if `do_pad` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = ( + get_size_dict(max_size=size, default_to_square=False) + if not isinstance(size, dict) + else size + ) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = ( + rescale_factor if rescale_factor is not None else self.rescale_factor + ) + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + pad_size = pad_size if pad_size is not None else self.pad_size + if do_pad: + pad_size = get_size_dict(pad_size, default_to_square=True) + do_convert_rgb = ( + do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + ) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + images, original_sizes, reshaped_input_sizes = zip( + *( + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ) + ) + + data = { + "pixel_values": images, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +class VisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = getattr(args, "mm_vision_select_layer", -2) + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + self.cfg_only = None + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs.hidden_states[self.select_layer] + if self.select_feature == "patch": + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def _maybe_resize_pos_embeds( + self, + model: PreTrainedModel, + image_processor: BaseImageProcessor, + resolution: int = -1, + interpolate_mode: str = "linear", + ): + if resolution in [model.config.image_size, -1]: + return + print( + f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..." + ) + embeddings = model.vision_model.embeddings + patch_size = embeddings.patch_size + num_new_tokens = int((resolution // patch_size) ** 2) + + old_embeddings = embeddings.position_embedding + match interpolate_mode: + case "linear": + # Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M + # Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)] + import torch + import torch.nn as nn + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters( + [old_embeddings.weight], modifier_rank=None + ): + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + else: + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + new_embeddings = nn.Embedding( + num_new_tokens, + old_embedding_dim, + dtype=old_embeddings.weight.dtype, + device=old_embeddings.weight.device, + ) + mapped_indices = ( + torch.arange(num_new_tokens).to(old_embeddings.weight.device) + / (num_new_tokens - 1) + * (old_num_tokens - 1) + ) + floor_indices = torch.clamp( + mapped_indices.floor().long(), min=0, max=old_num_tokens - 1 + ) + ceil_indices = torch.clamp( + mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1 + ) + if is_deepspeed_zero3_enabled(): + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + interpolated_embeds = (mapped_indices - floor_indices)[ + :, None + ] * old_embeddings.weight.data[ceil_indices, :] + ( + ceil_indices - mapped_indices + )[ + :, None + ] * old_embeddings.weight.data[ + floor_indices, : + ] + else: + interpolated_embeds = (mapped_indices - floor_indices)[ + :, None + ] * old_embeddings.weight.data[ceil_indices, :] + ( + ceil_indices - mapped_indices + )[ + :, None + ] * old_embeddings.weight.data[ + floor_indices, : + ] + new_embeddings.weight.data = interpolated_embeds + case _: + raise NotImplementedError + + if hasattr(old_embeddings, "_hf_hook"): + hook = old_embeddings._hf_hook + add_hook_to_module(new_embeddings, hook) + new_embeddings.requires_grad_(old_embeddings.weight.requires_grad) + # update vision encoder's configurations + model.config.image_size = resolution + if hasattr(image_processor, "crop_size"): + # CLIP vision tower + image_processor.crop_size = resolution + else: + # SIGLIP vision tower + assert hasattr(image_processor, "size") + image_processor.size = {"height": resolution, "width": resolution} + # TODO define a '_reinitialize' method for VisionTower + embeddings.position_embedding = new_embeddings + embeddings.image_size = resolution + embeddings.num_patches = embeddings.num_positions = num_new_tokens + embeddings.position_ids = ( + torch.arange(embeddings.num_positions) + .expand((1, -1)) + .to(old_embeddings.weight.device) + ) + + def forward(self, images, **kwargs): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower( + image.to(device=self.device, dtype=self.dtype).unsqueeze(0), + output_hidden_states=True, + **kwargs, + ) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower( + images.to(device=self.device, dtype=self.dtype), + output_hidden_states=True, + **kwargs, + ) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Siglip model configuration""" + + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + + Example: + + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + # cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_mask_channels (`int`, *optional*, defaults to 0): + Number of mask channels in the input images. + + Example: + + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + num_mask_channels=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_mask_channels = num_mask_channels + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + # cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import SiglipConfig, SiglipModel + + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info( + "`text_config` is `None`. Initializing the `SiglipTextConfig` with default values." + ) + + if vision_config is None: + vision_config = {} + logger.info( + "`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values." + ) + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs( + cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs + ): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + **kwargs, + ) + + +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SigLIP.""" + + +logger = logging.get_logger(__name__) + + +def is_scaled_image(image: np.ndarray) -> bool: + """ + Checks to see whether the pixel values have already been rescaled to [0, 1]. + """ + if image.dtype == np.uint8: + return False + + # It's possible the image has pixel values in [0, 255] but is of floating type + return np.min(image) >= 0 and np.max(image) <= 1 + + +if is_vision_available(): + import PIL + + +class SiglipImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 384} + size = get_size_dict(size, default_to_square=False) + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + """ + # size = get_size_dict(size, default_to_square=False) + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain either 'shortest_edge' or 'height' and 'width'." + ) + output_size = get_resize_output_image_size( + image, size=size, default_to_square=default_to_square + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + # do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + # crop_size = crop_size if crop_size is not None else self.crop_size + # crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = ( + rescale_factor if rescale_factor is not None else self.rescale_factor + ) + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = ( + do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + ) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError( + "Image mean and std must be specified if do_normalize is True." + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + # if input_data_format is None: + # # We assume that all images have the same channel dimension format. + # input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample) + for image in images + ] + + if do_rescale: + images = [rescale(image=image, scale=rescale_factor) for image in images] + + if do_normalize: + output_images = [] + for image in images: + if get_channel_dimension_axis(image) == 0: + image = image.transpose((1, 2, 0)) + if image.shape[-1] == 1: + image = np.dstack((image, image, image)) + output_images.append(image) + images = output_images + # for image in images: + # # print("image shape", image.shape) + # channel_axis = get_channel_dimension_axis(image) + # num_channels = image.shape[channel_axis] + # if num_channels != len(image_mean): + # print("image_mean", image_mean) + # print("channel_axis", channel_axis) + # print("num_channels", num_channels) + # print("image.shape", image.shape) + # raise ValueError( + # f"Number of channels in the image ({num_channels}) does not match the length of image mean " + # f"({len(image_mean)})." + # ) + + images = [ + normalize(image=image, mean=image_mean, std=image_std) + for image in images + ] + + images = [to_channel_dimension_format(image, data_format) for image in images] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Siglip model.""" + + +# from ...modeling_attn_mask_utils import _prepare_4d_attention_mask + + +logger = logging.get_logger(__name__) + +# _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +# SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ +# "google/siglip-base-patch16-224", +# # See all SigLIP models at https://huggingface.co/models?filter=siglip +# ] + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + ( + self[k] + if k not in ["text_model_output", "vision_model_output"] + else getattr(self, k).to_tuple() + ) + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + print(f"Number of mask channels: {config.num_mask_channels}") + if config.num_mask_channels: + # Mask should have the same output shape to be added. + # Currently we have bias in this embedding (so that mask vs no mask are different). + self.mask_patch_embedding = nn.Conv2d( + in_channels=config.num_mask_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.mask_patch_embedding.use_zero_init = True + else: + self.mask_patch_embedding = None + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + pixel_values: torch.FloatTensor, + additional_position_embedding: Optional[torch.Tensor] = None, + additional_embedding_mode: Optional[str] = None, + ) -> torch.Tensor: + if self.mask_patch_embedding is None: + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, width, grid, grid] + else: + # Comment this out if you want to encode both images without mask channel and with mask channel. + # However, if different samples in the batch have different number of channels, this is not applicable. + # assert pixel_values.size(1) == 4, f"Input does not have a mask channel, shape: {pixel_values.shape}" + patch_embeds = self.patch_embedding( + pixel_values[:, :3, ...] + ) # shape = [*, width, grid, grid] + if pixel_values.size(1) == 4: + patch_embeds = patch_embeds + self.mask_patch_embedding( + pixel_values[:, 3:4, ...] + ) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if additional_position_embedding is not None: + if additional_embedding_mode == "add": + embeddings = embeddings + self.position_embedding(self.position_ids) + embeddings = embeddings + additional_position_embedding + elif additional_embedding_mode == "replace": + # The original positional embedding is not used (multiplied by zero to ensure all parameters are used to be safe) + embeddings = ( + embeddings + self.position_embedding(self.position_ids) * 0.0 + ) + embeddings = embeddings + additional_position_embedding + else: + raise ValueError( + f"additional_embedding_mode should be either 'add' or 'replace', got {additional_embedding_mode}" + ) + else: + # Without additional position embedding + embeddings = embeddings + self.position_embedding(self.position_ids) + # print("No additional position embedding") + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SiglipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, nn.Conv2d) and getattr(module, "use_zero_init", False): + import deepspeed + + param_list = [module.weight] + if module.bias is not None: + param_list += [module.bias] + # This is used in mask patch embedding + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(param_list, modifier_rank=0): + for param in param_list: + nn.init.zeros_(param) + else: + for param in param_list: + nn.init.zeros_(param) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + context_provider_layers: Optional[nn.ModuleList] = None, + contexts: Optional[List[torch.Tensor]] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + context_provider_layers (nn.ModuleList): ModuleList of context provider layers. + contexts: List of torch.Tensor for context (for KV in cross-attention). + cross_attention_mask (`torch.Tensor` of shape `(batch_size, q_sequence_length, kv_sequence_length)`, *optional*): mask for cross-attention. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for layer_index, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if context_provider_layers: + # Right now contexts is passed as the encoder_hidden_states (the output hidden_states of the context ViT). + context_provider_layer = context_provider_layers[layer_index] + if context_provider_layer is not None: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + context_provider_layer.__call__, + hidden_states, + contexts, + cross_attention_mask, + output_attentions, + ) + else: + layer_outputs = context_provider_layer( + hidden_states, + contexts, + cross_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + # if attention_mask is not None: + # # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + # attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig + ) + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig + ) + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> inputs = processor(text=texts, images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = ( + torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + + self.logit_bias + ) + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = ( + logits_per_image, + logits_per_text, + text_embeds, + image_embeds, + text_outputs, + vision_outputs, + ) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP. +""" + + +class SiglipProcessor(ProcessorMixin): + r""" + Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. + + [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the + [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`]): + The image processor is a required input. + tokenizer ([`SiglipTokenizer`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = "SiglipTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[ + TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] + ] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = "max_length", + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError( + "You have to specify either text or images. Both cannot be none." + ) + + if text is not None: + encoding = self.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + if images is not None: + image_features = self.image_processor(images, return_tensors=return_tensors) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for SigLIP model.""" + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/spiece.model", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "google/siglip-base-patch16-224": 256, +} + +SPIECE_UNDERLINE = "▁" + + +class SiglipTokenizer(PreTrainedTokenizer): + """ + Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + model_max_length (`int`, *optional*, defaults to 64): + The maximum length (in number of tokens) for model inputs. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + model_max_length=64, + do_lower_case=True, + **kwargs, + ) -> None: + requires_backends(self, "protobuf") + + pad_token = ( + AddedToken( + pad_token, rstrip=True, lstrip=True, normalized=False, special=True + ) + if isinstance(pad_token, str) + else pad_token + ) + unk_token = ( + AddedToken( + unk_token, rstrip=True, lstrip=True, normalized=False, special=True + ) + if isinstance(unk_token, str) + else unk_token + ) + eos_token = ( + AddedToken( + eos_token, rstrip=True, lstrip=True, normalized=False, special=True + ) + if isinstance(eos_token, str) + else eos_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor() + self.vocab_file = vocab_file + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + model_max_length=model_max_length, + do_lower_case=do_lower_case, + **kwargs, + ) + + def get_spm_processor(self): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf() + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size + def vocab_size(self): + return self.sp_model.get_piece_size() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def remove_punctuation(self, text: str) -> str: + return text.translate(str.maketrans("", "", string.punctuation)) + + # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + def canonicalize_text(self, text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (puncuation removed). + + Args: + text (`str`): + String to be canonicalized. + keep_punctuation_exact_string (`str`, *optional*): + If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}' + (but will still remove '{' and '}' that appear separately). + """ + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + self.remove_punctuation(part) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = self.remove_punctuation(text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + + return text + + def tokenize( + self, text: "TextInput", add_special_tokens=False, **kwargs + ) -> List[str]: + """ + Converts a string to a list of tokens. + """ + tokens = super().tokenize( + SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs + ) + + if ( + len(tokens) > 1 + and tokens[0] == SPIECE_UNDERLINE + and tokens[1] in self.all_special_tokens + ): + tokens = tokens[1:] + return tokens + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. + + For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`. + + Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + text = self.canonicalize_text(text, keep_punctuation_exact_string=None) + tokens = self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return ( + tokens[self.unk_token_length :] + if len(tokens) >= self.unk_token_length + else tokens + ) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + # since we manually add the prefix space, we have to remove it + tokens[0] = tokens[0].lstrip(SPIECE_UNDERLINE) + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary + def save_vocabulary( + self, save_directory: str, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +class SiglipVisionTower(VisionTower): + def __init__( + self, model_name_or_path: str, config: PretrainedConfig, state_dict=None + ): + super().__init__(model_name_or_path, config) + self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path) + self.vision_tower = SiglipVisionModel.from_pretrained( + # TODO(ligeng): why pass config here leading to errors? + model_name_or_path, + torch_dtype=eval(config.model_dtype), + state_dict=state_dict, + ) + self.is_loaded = True + + +AutoConfig.register("siglip_vision_model", SiglipVisionConfig, exist_ok=True) +AutoModel.register(SiglipVisionConfig, SiglipVisionModel, exist_ok=True) + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +class LlavaLlamaConfig(LlavaConfig): + model_type = "llava_llama" + + +# FIXME we will follow the convention to add a new class for CausalLM in the future + + +class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel): + config_class = LlavaLlamaConfig + main_input_name = "input_embeds" + supports_gradient_checkpointing = True + tokenizer_image_token = staticmethod(tokenizer_image_token) + + def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None: + super().__init__(config) + self.dam_model = None + self.pretrained_model_name_or_path = None + self.init_vlm(config=config, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + init_dam: bool = False, + # conv_mode and prompt_mode are only used by `init_dam` in `from_pretrained` if `init_dam` is set to True + conv_mode: str = "v1", + prompt_mode: str = "full+focal_crop", + **kwargs, + ): + if hasattr(cls, "load_pretrained"): + obj = cls.load_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + else: + obj = super(LlavaLlamaModel).from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + obj.pretrained_model_name_or_path = pretrained_model_name_or_path + + # `init_dam` is used to initialize a `DescribeAnythingModel` object in a `LlavaLlamaModel` in DAM. If you initialize `DescribeAnythingModel` on your own outside, then you don't have to use this option. + # This is very useful if you use `from_pretrained` with remote code execution and don't want to put implementation for `DescribeAnythingModel` class in your codebase. + if init_dam: + obj.init_dam(conv_mode, prompt_mode) + + return obj + + def forward( + self, + input_ids: torch.LongTensor = None, + images: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + self.freezed_module_patch() + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, position_ids, attention_mask, past_key_values, labels, images + ) + # Note (kentang-mit@): we have a unit test for this function. + if self.training: + ( + _, + new_position_ids, + new_attention_mask, + _, + new_inputs_embeds, + new_labels, + sorted_seqlens_in_batch, + ) = self.repack_multimodal_data( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) + new_input_ids = None + past_key_values = None + else: + new_attention_mask = attention_mask + new_position_ids = position_ids + new_inputs_embeds = inputs_embeds + new_labels = labels + sorted_seqlens_in_batch = attention_mask.sum(-1).int() + new_input_ids = input_ids + + outputs = self.llm.forward( + input_ids=new_input_ids, + attention_mask=new_attention_mask, + position_ids=new_position_ids, + past_key_values=past_key_values, + inputs_embeds=new_inputs_embeds, + labels=new_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + seqlens_in_batch=sorted_seqlens_in_batch, + ) + return outputs + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.FloatTensor] = None, + images: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **generation_kwargs, + ): + if images is not None: + ( + _, + _, + attention_mask, + _, + inputs_embeds, + _, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, None, attention_mask, None, None, images + ) + else: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = inputs_embeds.to(self.dtype) + + outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **generation_kwargs, + ) + return outputs + + def init_dam(self, conv_mode, prompt_mode): + + + model_name = get_model_name_from_path(self.pretrained_model_name_or_path) + self.dam_model = DescribeAnythingModel( + model_path=dict( + model=self, tokenizer=self.tokenizer, model_name=model_name + ), + conv_mode=conv_mode, + prompt_mode=prompt_mode, + ) + + return self.dam_model + + @property + def dam(self): + if self.dam_model is None: + self.init_dam() + return self.dam_model + + +AutoConfig.register("llava_llama", LlavaLlamaConfig) +AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel) + + +def has_tokenizer(path): + if ( + osp.exists(osp.join(path, "special_tokens_map.json")) + and osp.exists(osp.join(path, "tokenizer_config.json")) + and ( + osp.exists(osp.join(path, "tokenizer.model")) + or osp.exists(osp.join(path, "tokenizer.json")) + ) + ): + # print("[has_tokenizer]", path, True) + return True + from huggingface_hub import HfApi, file_exists + from huggingface_hub.utils import HFValidationError + + api = HfApi() + try: + valid_hf_repo = api.repo_exists(path) + except HFValidationError: + valid_hf_repo = False + if ( + valid_hf_repo + and file_exists(path, "special_tokens_map.json") + and file_exists(path, "tokenizer_config.json") + and ( + file_exists(path, "tokenizer.model") or file_exists(path, "tokenizer.json") + ) + ): + # print("[has_tokenizer]", path, True) + return True + # print("[has_tokenizer]", path, False) + return False + + +def context_length_extension(config): + orig_ctx_len = getattr(config, "max_position_embeddings", None) + model_max_length = getattr(config, "model_max_length", None) + if orig_ctx_len and model_max_length > orig_ctx_len: + print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}") + scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + return config + + +def build_llm_and_tokenizer( + model_name_or_path: str, + config: PretrainedConfig, + # config_cls: PretrainedConfig = None, + # llm_cls: PreTrainedModel = None, + attn_implementation=None, + model_max_length=None, + *args, + **kwargs, +) -> PreTrainedModel: + # if config_cls is None: + # config_cls = AutoConfig + # if llm_cls is None: + # llm_cls = AutoModelForCausalLM + # config_cls = AutoConfig + # llm_cls = AutoModelForCausalLM + # extra configuration for llm + # print("build_llm_and_tokenizer():", model_name_or_path); input("DEBUG") + llm_cfg = AutoConfig.from_pretrained(model_name_or_path) + llm_cfg._attn_implementation = attn_implementation + llm_cfg.model_max_length = model_max_length + if model_max_length is not None: + context_length_extension(llm_cfg) + + llm = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + config=llm_cfg, + torch_dtype=eval(config.model_dtype), + *args, + **kwargs, + ) + + llm_path = model_name_or_path + if not has_tokenizer(llm_path): + warnings.warn( + "tokenizer found in VLM root folder. Move to ./{VILA}/llm in the future." + ) + llm_path = osp.join(llm_path, "llm") + + # TODO(ligeng): use LLM class to judge to better compability. + if "mpt" in model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + llm_path, + model_max_length=llm_cfg.model_max_length, + padding_side="right", + ) + elif "yi" in model_name_or_path.lower(): + tokenizer = AutoTokenizer.from_pretrained( + llm_path, + model_max_length=llm_cfg.model_max_length, + padding_side="right", + use_fast=False, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + llm_path, + model_max_length=llm_cfg.model_max_length, + padding_side="right", + use_fast=False, + legacy=False, + ) + + # TODO(ligeng): is this necessary for llava? + config.hidden_size = llm.config.hidden_size + return llm, tokenizer + + +# This file is modified from https://github.com/haotian-liu/LLaVA/ and https://github.com/NVlabs/VILA/ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# TODO: we may move LlavaConfig to configuration_llava.py +# from model.configuration_llava import LlavaConfig + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_pretrained_model( + model_path, + model_name, + model_base=None, + load_8bit=False, + load_4bit=False, + device_map="auto", + device="cuda", + **kwargs, +): + kwargs = {"device_map": device_map, **kwargs} + + if device != "cuda": + kwargs["device_map"] = {"": device} + + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + else: + kwargs["torch_dtype"] = torch.float16 + + config = AutoConfig.from_pretrained(model_path) + config.resume_path = model_path + prepare_config_for_eval(config, kwargs) + + model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs) + tokenizer = model.tokenizer + + model.eval() + + # mm_use_im_start_end = getattr( + # model.config, "mm_use_im_start_end", False) + # mm_use_im_patch_token = getattr( + # model.config, "mm_use_im_patch_token", True) + # if mm_use_im_patch_token: + # tokenizer.add_tokens( + # [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + # if mm_use_im_start_end: + # tokenizer.add_tokens( + # [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + # ) + + model.resize_token_embeddings(len(tokenizer)) + vision_tower = model.get_vision_tower() + vision_tower.to(device=device, dtype=torch.float16) + mm_projector = model.get_mm_projector() + mm_projector.to(device=device, dtype=torch.float16) + context_provider = model.get_context_provider() + if context_provider is not None: + context_provider.to(device=device, dtype=torch.float16) + image_processor = vision_tower.image_processor + + if hasattr(model.llm.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len + + +def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"): + target_model = f"{model_name}{suffix}" + target_cfg = getattr(config, target_model, None) + + if isinstance(target_cfg, str): + return target_cfg + elif isinstance(target_cfg, dict): + return target_cfg["architectures"][0] + else: + raise ValueError(f"Invalid {target_model} configuration!") + + +def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict): + try: + # compatible with deprecated config convention + if getattr(config, "vision_tower_cfg", None) is None: + config.vision_tower_cfg = config.mm_vision_tower + except AttributeError: + raise ValueError( + f"Invalid configuration! Cannot find vision_tower in config:\n{config}" + ) + + config.model_dtype = kwargs.pop("torch_dtype").__str__() + # siglip does not support device_map = "auto" + vision_tower_name = parse_model_name_or_path(config, "vision_tower") + if "siglip" in vision_tower_name.lower(): + kwargs["device_map"] = "cuda" + + +class DescribeAnythingModel(nn.Module): + def __init__(self, model_path, conv_mode, prompt_mode, **kwargs): + super().__init__() + + self.model_path = model_path + self.conv_mode = conv_mode + self.prompt_mode = prompt_mode + + if isinstance(model_path, str): + self.tokenizer, self.model, _, _ = load_pretrained_model( + model_path, None, None, **kwargs + ) + self.model_name = get_model_name_from_path(model_path) + else: + # model_path is actually a dict with model, tokenizer, and (optionally) model_name + self.model = model_path["model"] + self.tokenizer = model_path["tokenizer"] + self.model_name = model_path.get("model_name", None) + + image_processor = self.model.vision_tower.image_processor + self.model.config.image_processor = image_processor + + def get_prompt(self, qs): + if DEFAULT_IMAGE_TOKEN not in qs: + raise ValueError("no tag found in input.") + + conv = conv_templates[self.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + return prompt, conv + + @staticmethod + def mask_to_box(mask_np): + mask_coords = np.argwhere(mask_np) + y0, x0 = mask_coords.min(axis=0) + y1, x1 = mask_coords.max(axis=0) + 1 + + h = y1 - y0 + w = x1 - x0 + + return x0, y0, w, h + + @classmethod + def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48): + if crop_mode == "full": + # no crop + info = dict(mask_np=mask_np) + return pil_img, info + + if crop_mode == "crop": + # crop image and mask + x0, y0, w, h = cls.mask_to_box(mask_np) + img_np = np.asarray(pil_img) + assert ( + img_np.shape[:2] == mask_np.shape + ), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}" + cropped_mask_np = mask_np[y0 : y0 + h, x0 : x0 + w] + cropped_img_np = img_np[y0 : y0 + h, x0 : x0 + w] + cropped_pil_img = Image.fromarray(cropped_img_np) + elif crop_mode == "context_crop": + # crop image and mask + x0, y0, w, h = cls.mask_to_box(mask_np) + img_np = np.asarray(pil_img) + assert ( + img_np.shape[:2] == mask_np.shape + ), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}" + img_h, img_w = img_np.shape[:2] + cropped_mask_np = mask_np[ + max(y0 - h, 0) : min(y0 + 2 * h, img_h), + max(x0 - w, 0) : min(x0 + 2 * w, img_w), + ] + cropped_img_np = img_np[ + max(y0 - h, 0) : min(y0 + 2 * h, img_h), + max(x0 - w, 0) : min(x0 + 2 * w, img_w), + ] + cropped_pil_img = Image.fromarray(cropped_img_np) + elif crop_mode == "focal_crop": + # crop image and mask + x0, y0, w, h = cls.mask_to_box(mask_np) + img_np = np.asarray(pil_img) + assert ( + img_np.shape[:2] == mask_np.shape + ), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}" + img_h, img_w = img_np.shape[:2] + + xc, yc = x0 + w / 2, y0 + h / 2 + # focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD + w, h = max(w, min_box_w), max(h, min_box_h) + x0, y0 = int(xc - w / 2), int(yc - h / 2) + + cropped_mask_np = mask_np[ + max(y0 - h, 0) : min(y0 + 2 * h, img_h), + max(x0 - w, 0) : min(x0 + 2 * w, img_w), + ] + cropped_img_np = img_np[ + max(y0 - h, 0) : min(y0 + 2 * h, img_h), + max(x0 - w, 0) : min(x0 + 2 * w, img_w), + ] + cropped_pil_img = Image.fromarray(cropped_img_np) + elif crop_mode == "crop_mask": + # crop image and mask + x0, y0, w, h = cls.mask_to_box(mask_np) + img_np = np.asarray(pil_img) + assert ( + img_np.shape[:2] == mask_np.shape + ), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}" + cropped_mask_np = mask_np[y0 : y0 + h, x0 : x0 + w] + cropped_img_np = img_np[y0 : y0 + h, x0 : x0 + w] + # Mask the image + cropped_img_np = cropped_img_np * cropped_mask_np[..., None] + cropped_pil_img = Image.fromarray(cropped_img_np) + else: + raise ValueError(f"Unsupported crop_mode: {crop_mode}") + + info = dict(mask_np=cropped_mask_np) + return cropped_pil_img, info + + def get_description( + self, + image_pil, + mask_pil, + query, + streaming=False, + temperature=0.2, + top_p=0.5, + num_beams=1, + max_new_tokens=512, + **kwargs, + ): + # kwargs is passed to generation_kwargs: https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig + + prompt, conv = self.get_prompt(query) + if not isinstance(image_pil, (list, tuple)): + assert not isinstance( + mask_pil, (list, tuple) + ), "image_pil and mask_pil must be both list or tuple or not list or tuple." + image_pils = [image_pil] + mask_pils = [mask_pil] + else: + image_pils = image_pil + mask_pils = mask_pil + description = self.get_description_from_prompt( + image_pils, + mask_pils, + prompt, + conv, + streaming=streaming, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + return description + + def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2): + # the pil has True/False (if the value is non-zero, then we treat it as True) + mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8) + images_tensor, image_info = process_image( + image_pil, + self.model.config, + None, + pil_preprocess_fn=lambda pil_img: self.crop_image( + image_pil, mask_np=mask_np, crop_mode=crop_mode + ), + ) + images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16) + + mask_np = image_info["mask_np"] + mask_pil = Image.fromarray(mask_np * 255) + + masks_tensor = process_image(mask_pil, self.model.config, None) + masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16) + + images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1) + + if crop_mode2 is not None: + images_tensor2, image_info2 = process_image( + image_pil, + self.model.config, + None, + pil_preprocess_fn=lambda pil_img: self.crop_image( + pil_img, mask_np=mask_np, crop_mode=crop_mode2 + ), + ) + images_tensor2 = images_tensor2[None].to( + self.model.device, dtype=torch.float16 + ) + + mask_np2 = image_info2["mask_np"] + mask_pil2 = Image.fromarray(mask_np2 * 255) + + masks_tensor2 = process_image(mask_pil2, self.model.config, None) + masks_tensor2 = masks_tensor2[None].to( + self.model.device, dtype=torch.float16 + ) + + images_tensor2 = torch.cat( + (images_tensor2, masks_tensor2[:, :1, ...]), dim=1 + ) + else: + images_tensor2 = None + + return ( + torch.cat((images_tensor, images_tensor2), dim=1) + if images_tensor2 is not None + else images_tensor + ) + + def get_description_from_prompt( + self, + image_pils, + mask_pils, + prompt, + conv, + streaming=False, + temperature=0.2, + top_p=0.5, + num_beams=1, + max_new_tokens=512, + **kwargs, + ): + if streaming: + return self.get_description_from_prompt_iterator( + image_pils, + mask_pils, + prompt, + conv, + streaming=True, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + **kwargs, + ) + else: + # If streaming is False, there will be only one output + output = self.get_description_from_prompt_iterator( + image_pils, + mask_pils, + prompt, + conv, + streaming=False, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + **kwargs, + ) + return next(output) + + def get_description_from_prompt_iterator( + self, + image_pils, + mask_pils, + prompt, + conv, + streaming=False, + temperature=0.2, + top_p=0.5, + num_beams=1, + max_new_tokens=512, + **kwargs, + ): + crop_mode, crop_mode2 = self.prompt_mode.split("+") + assert ( + crop_mode == "full" + ), "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt." + + assert len(image_pils) == len( + mask_pils + ), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}." + image_tensors = [ + self.get_image_tensor( + image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2 + ) + for image_pil, mask_pil in zip(image_pils, mask_pils) + ] + + input_ids = ( + tokenizer_image_token( + prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" + ) + .unsqueeze(0) + .cuda() + ) + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria( + keywords, self.tokenizer, input_ids + ) + + streamer = ( + TextIteratorStreamer( + self.tokenizer, skip_prompt=True, skip_special_tokens=True + ) + if streaming + else None + ) + generation_kwargs = dict( + input_ids=input_ids, + images=image_tensors, + do_sample=True if temperature > 0 else False, + use_cache=True, + stopping_criteria=[stopping_criteria], + streamer=streamer, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + if streaming: + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + generated_text = "" + for new_text in streamer: + generated_text += new_text + if stop_str in generated_text: + generated_text = generated_text[: generated_text.find(stop_str)] + break + yield new_text + + thread.join() + else: + with torch.inference_mode(): + output_ids = self.model.generate(**generation_kwargs) + + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[ + 0 + ] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + outputs = outputs.strip() + + yield outputs