# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import List, Optional

import numpy as np
import torch
from einops import rearrange
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
    GenerationMixin,
)
from transformers.configuration_utils import PretrainedConfig

from .clip_encoder import CLIPVisionTower
from .siglip_vit import create_siglip_vit
from .projector import MlpProjector
from .configuration_vlm import (
    AttrDict,
    MultiModalityConfig,
    VisionConfig,
    AlignerConfig,
    GenVisionConfig,
    GenHeadConfig,
    GenAlignerConfig,
)
from .vq_model import VQ_models


class vision_head(torch.nn.Module):
    """Head mapping language-model hidden states to image-token logits.

    Architecture: Linear(n_embed -> image_token_embed) -> GELU ->
    Linear(image_token_embed -> image_token_size).

    Args:
        params: object exposing ``n_embed``, ``image_token_embed`` and
            ``image_token_size`` as attributes.
    """

    def __init__(self, params):
        super().__init__()
        # Attribute names are kept as-is: they determine state-dict keys.
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x):
        """Project ``x`` (last dim ``n_embed``) to logits of size ``image_token_size``."""
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        return self.vision_head(x)


def model_name_to_cls(cls_name):
    """Resolve a config ``cls`` string to the class (or factory) it names.

    Matching is by substring and checked in this order: "MlpProjector",
    "CLIPVisionTower", "VQ" (looked up in ``VQ_models`` by the full name),
    "vision_head".

    Raises:
        KeyError: if the name contains "VQ" but is not a key of ``VQ_models``.
        ValueError: if no known name matches.
    """
    if "MlpProjector" in cls_name:
        return MlpProjector
    if "CLIPVisionTower" in cls_name:
        return CLIPVisionTower
    if "VQ" in cls_name:
        # VQ_models is already imported at module level.
        return VQ_models[cls_name]
    if "vision_head" in cls_name:
        return vision_head
    raise ValueError(f"class_name {cls_name} is invalid.")


class MultiModalityPreTrainedModel(PreTrainedModel):
    """HF ``PreTrainedModel`` base binding ``MultiModalityConfig`` to the model."""

    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    """Multimodal model: vision encoder + aligner for understanding, a VQ
    model + aligner + head for image generation, and a Llama language model.
    """

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # Understanding path: vision encoder followed by an aligner.
        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        # Generation path: VQ model, aligner over token embeddings, logit head.
        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size,
            gen_vision_config.params.n_embed,
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: Optional[torch.FloatTensor],
        images_seq_mask: Optional[torch.BoolTensor],
        images_emb_mask: Optional[torch.BoolTensor],
        **kwargs,
    ):
        """Build language-model input embeddings with image embeddings spliced in.

        Args:
            input_ids (torch.LongTensor): [b, T]. Not modified.
            pixel_values (torch.FloatTensor or None): [b, n_images, 3, h, w].
                When None, only the token embeddings are returned.
            images_seq_mask (torch.BoolTensor or None): [b, T], True where an
                image embedding replaces the token embedding.
            images_emb_mask (torch.BoolTensor or None): [b, n_images, n_image_tokens],
                True for image embeddings to use.
                assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]

        Raises:
            ValueError: if ``pixel_values`` is None but ``images_seq_mask``
                marks image positions.
        """
        # Negative ids cannot be looked up in the embedding table, so map them
        # to 0. Those positions are expected to be overwritten by image
        # embeddings below. clamp returns a new tensor, so the caller's
        # input_ids is left untouched.
        text_ids = input_ids.clamp(min=0)
        inputs_embeds = self.language_model.get_input_embeddings()(text_ids)

        if pixel_values is None:
            if images_seq_mask is not None and bool(images_seq_mask.any()):
                raise ValueError(
                    "images_seq_mask marks image positions but pixel_values is None."
                )
            return inputs_embeds

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Replace the masked token positions with the selected image embeddings.
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        """Embed generated image-token ids and map them through ``gen_aligner``."""
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        """Run the language model on multimodal inputs.

        If ``inputs_embeds`` is not given it is built with
        ``prepare_inputs_embeds``. Returns ``LlamaForCausalLM.forward``'s
        output unchanged.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.forward(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        """Text generation: build ``inputs_embeds`` if needed and delegate to
        ``self.language_model.generate``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    def _generation_device(self) -> torch.device:
        """Device on which to build generation inputs (the input embedding weights').

        Falls back to CPU (the previous placement) when the weights are on the
        ``meta`` device, presumably offloaded and moved by dispatch hooks.
        """
        device = self.language_model.get_input_embeddings().weight.device
        return torch.device("cpu") if device.type == "meta" else device

    @torch.no_grad()
    def generate_image(
        self,
        processor,
        prompt: str,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        generator=None,
    ) -> List:
        """Sample ``parallel_size`` images for ``prompt`` with classifier-free guidance.

        The batch holds ``2 * parallel_size`` rows. Even rows carry the full
        prompt (conditional). Odd rows keep only the first and last prompt
        tokens, with every token in between replaced by ``processor.pad_id``
        (unconditional). At each step the logits are combined as
        ``uncond + cfg_weight * (cond - uncond)``.

        Args:
            processor: provides ``apply_sft_template_for_multi_turn_prompts``,
                ``sft_format``, ``image_start_tag``, ``tokenizer`` and ``pad_id``.
            prompt: user text describing the image.
            temperature: softmax temperature; must be > 0.
            parallel_size: number of images sampled in parallel.
            cfg_weight: classifier-free guidance scale.
            image_token_num_per_image: number of image tokens to sample.
            img_size: output image side length in pixels.
            patch_size: pixels per token side; the token grid is
                ``img_size // patch_size`` per side.
            generator: optional ``torch.Generator`` for sampling. It must live
                on the same device as the model's input embeddings.

        Returns:
            List of ``parallel_size`` PIL images, each ``img_size x img_size`` RGB.
        """
        from PIL import Image

        conversation = [
            {"role": "User", "content": prompt},
            {"role": "Assistant", "content": ""},
        ]
        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=processor.sft_format,
            system_prompt="",
        )
        full_prompt = sft_format + processor.image_start_tag

        device = self._generation_device()
        prompt_ids = torch.tensor(
            processor.tokenizer.encode(full_prompt), dtype=torch.int, device=device
        )

        # Interleave conditional (even) and unconditional (odd) rows.
        tokens = prompt_ids.unsqueeze(0).repeat(parallel_size * 2, 1)
        tokens[1::2, 1:-1] = processor.pad_id

        inputs_embeds = self.language_model.get_input_embeddings()(tokens)
        generated_tokens = torch.zeros(
            (parallel_size, image_token_num_per_image), dtype=torch.int, device=device
        )

        past_key_values = None
        for step in range(image_token_num_per_image):
            outputs = self.language_model.model.forward(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values,
            )
            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            logits = self.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            # generator=None is torch.multinomial's default.
            next_token = torch.multinomial(probs, num_samples=1, generator=generator)
            next_token = next_token.squeeze(dim=-1)
            generated_tokens[:, step] = next_token

            # Feed the sampled token to both the cond and uncond row of each pair.
            img_embeds = self.prepare_gen_img_embeds(next_token.repeat_interleave(2))
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        grid = img_size // patch_size
        # NOTE(review): 8 is presumably the VQ code embedding dim expected by
        # decode_code; TODO confirm against the VQ model config.
        dec = self.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            [parallel_size, 8, grid, grid],
        )
        # [N, C, H, W] in [-1, 1] -> [N, H, W, C] in [0, 255].
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        # Assigning into a fixed-shape buffer also checks the decoded size.
        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        return [Image.fromarray(visual_img[idx]) for idx in range(parallel_size)]


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)