tiny-random-janus / modeling_vlm.py
katuni4ka's picture
Upload 18 files
5c955cb verified
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
from einops import rearrange
from transformers import (
AutoConfig,
AutoModelForCausalLM,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
GenerationMixin
)
import numpy as np
from transformers.configuration_utils import PretrainedConfig
from .clip_encoder import CLIPVisionTower
from .siglip_vit import create_siglip_vit
from .projector import MlpProjector
from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig
from .vq_model import VQ_models
class vision_head(torch.nn.Module):
    """Two-layer MLP head: Linear -> GELU -> Linear.

    ``params`` must expose three attributes:

    * ``n_embed``           -- width of the incoming features,
    * ``image_token_embed`` -- width of the hidden layer,
    * ``image_token_size``  -- width of the output layer.
    """

    def __init__(self, params):
        super().__init__()
        hidden_width = params.image_token_embed

        # Submodule names and creation order are kept as-is: they determine
        # the state-dict keys and the order in which weights are initialised.
        self.output_mlp_projector = torch.nn.Linear(params.n_embed, hidden_width)
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(hidden_width, params.image_token_size)

    def forward(self, x):
        """Apply projector, activation and output layer to ``x`` in that order."""
        projected = self.output_mlp_projector(x)
        return self.vision_head(self.vision_activation(projected))
def model_name_to_cls(cls_name):
    """Map a configured class name to the class object it denotes.

    Matching is by substring and the checks run in a fixed order:
    ``"MlpProjector"``, ``"CLIPVisionTower"``, ``"VQ"`` (looked up by the full
    name in ``VQ_models``), then ``"vision_head"``.

    Raises:
        ValueError: if ``cls_name`` contains none of the known markers.
        KeyError: if ``cls_name`` contains ``"VQ"`` but is not a key of
            ``VQ_models``.
    """
    if "MlpProjector" in cls_name:
        return MlpProjector

    if "CLIPVisionTower" in cls_name:
        return CLIPVisionTower

    if "VQ" in cls_name:
        # Imported lazily, only when a VQ model is actually requested.
        from .vq_model import VQ_models

        return VQ_models[cls_name]

    if "vision_head" in cls_name:
        return vision_head

    raise ValueError(f"class_name {cls_name} is invalid.")
class MultiModalityPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the multi-modality models in this file.

    It defines no behaviour of its own; it only sets the class-level
    attributes that the ``transformers`` ``PreTrainedModel`` machinery reads.
    """

    # Configuration class associated with this model family.
    config_class = MultiModalityConfig
    # Key that is not moved during automatic device placement.
    _skip_keys_device_placement = "past_key_values"
    # No submodule types are marked as "must not be split across devices".
    _no_split_modules = []
    # Attribute-name prefix used for the base model.
    base_model_prefix = "multi_modality"
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    """Multi-modal causal language model built around a ``LlamaForCausalLM``.

    Submodules created from the configuration:

    * ``vision_model`` + ``aligner``: encode input images and project them
      into the language model's embedding space (image understanding).
    * ``gen_vision_model``, ``gen_embed``, ``gen_aligner`` and ``gen_head``:
      embed, predict and decode discrete image tokens (image generation).
    * ``language_model``: the Llama causal LM that consumes the embeddings.
    """

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # NOTE: the creation order below is kept stable so that seeded weight
        # initialisation and the state-dict layout do not change.

        # Image-understanding tower; its params are expanded as keyword args.
        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        # Projects vision features to the language-model embedding space.
        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        # Image-generation (VQ) model; constructed with its own defaults.
        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        # Projects generated-image token embeddings to the LM embedding space.
        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        # Maps LM hidden states to logits over image tokens.
        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        # Embedding table for discrete image tokens.
        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size,
            gen_vision_config.params.n_embed,
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
    ):
        """Build language-model input embeddings, splicing in image features.

        Args:
            input_ids (torch.LongTensor): [b, T]. Negative ids are treated as
                placeholders and embedded as id 0. The caller's tensor is
                NOT modified.
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w], or
                ``None`` for text-only input.
            images_seq_mask (torch.BoolTensor): [b, T], positions in the
                sequence that receive image embeddings.
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens],
                image embeddings that are copied into the sequence.
            **kwargs: accepted for call-site convenience and ignored.

            Both masks are used as index masks, so they must select the same
            number of elements:
            ``torch.sum(images_seq_mask) == torch.sum(images_emb_mask)``.

        Returns:
            input_embeds (torch.Tensor): [b, T, D]

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError(
                "input_ids must be provided when inputs_embeds is not given."
            )

        # Out-of-place replacement of negative placeholder ids, so the
        # caller's tensor is left untouched.
        text_ids = input_ids.clamp(min=0)
        embed_tokens = self.language_model.get_input_embeddings()

        if pixel_values is None:
            # Text-only input: there are no image features to splice in.
            return embed_tokens(text_ids)

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")

        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(
            images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n
        )

        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # [b, T, D]
        inputs_embeds = embed_tokens(text_ids)

        # Replace the placeholder positions with the image embeddings.
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        """Embed image token ids and project them with ``gen_aligner``."""
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        """Run the language model on text and optional image input.

        If ``inputs_embeds`` is not supplied it is computed by
        ``prepare_inputs_embeds`` from ``input_ids`` and, when present,
        ``pixel_values`` with the two image masks. ``input_ids`` may be
        omitted when ``inputs_embeds`` is given. Extra ``kwargs`` are
        forwarded to both ``prepare_inputs_embeds`` and the language model.

        Returns:
            Whatever ``self.language_model.forward`` returns.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        # The language model is always driven by embeddings, never raw ids.
        return self.language_model.forward(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs
    ):
        """Generate text by delegating to ``self.language_model.generate``.

        ``inputs_embeds`` is computed with ``prepare_inputs_embeds`` when it
        is not supplied. Extra ``kwargs`` are forwarded to both
        ``prepare_inputs_embeds`` and ``language_model.generate``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    @torch.no_grad()
    def generate_image(
        self,
        processor,
        prompt: str,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        generator=None
    ):
        """Sample ``parallel_size`` images for a text prompt.

        The prompt is wrapped in a User/Assistant conversation via the
        processor's SFT template, followed by the image start tag. Every
        sample is run as a pair of rows: the full prompt (conditional) and
        the same prompt with all tokens except the first and last replaced
        by ``processor.pad_id`` (unconditional). At each step the two sets of
        logits are combined as ``uncond + cfg_weight * (cond - uncond)``
        (classifier-free-guidance form) before sampling.

        Args:
            processor: object providing
                ``apply_sft_template_for_multi_turn_prompts``, ``sft_format``,
                ``image_start_tag``, ``tokenizer`` and ``pad_id``.
            prompt: text description of the image.
            temperature: softmax temperature applied to the combined logits.
            parallel_size: number of images sampled in parallel.
            cfg_weight: guidance weight in the logit combination above.
            image_token_num_per_image: number of image tokens sampled per image.
            img_size: side length of the square output images, in pixels.
            patch_size: pixels per token along each side; the decoder is
                asked for an ``img_size // patch_size`` square token grid.
            generator: optional ``torch.Generator`` passed to
                ``torch.multinomial`` for reproducible sampling.

        Returns:
            list of ``parallel_size`` ``PIL.Image.Image`` objects.
        """
        from PIL import Image

        conversation = [
            {
                "role": "User",
                "content": prompt,
            },
            {"role": "Assistant", "content": ""},
        ]
        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=processor.sft_format,
            system_prompt="",
        )
        prompt = sft_format + processor.image_start_tag

        embed_tokens = self.language_model.get_input_embeddings()
        # Create all working tensors on the model's device so the method
        # also works when the model does not live on the CPU.
        device = embed_tokens.weight.device

        input_ids = torch.LongTensor(processor.tokenizer.encode(prompt)).to(device)

        # Rows alternate conditional (even) / unconditional (odd).
        tokens = torch.zeros(
            (parallel_size * 2, len(input_ids)), dtype=torch.int, device=device
        )
        tokens[:] = input_ids
        tokens[1::2, 1:-1] = processor.pad_id

        inputs_embeds = embed_tokens(tokens)
        generated_tokens = torch.zeros(
            (parallel_size, image_token_num_per_image), dtype=torch.int, device=device
        )

        past_key_values = None
        for i in range(image_token_num_per_image):
            outputs = self.language_model.model.forward(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = outputs.past_key_values

            # Only the last position is needed to predict the next token.
            logits = self.gen_head(outputs.last_hidden_state[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            probs = torch.softmax(logits / temperature, dim=-1)
            # ``generator=None`` selects the default RNG.
            next_token = torch.multinomial(probs, num_samples=1, generator=generator)
            sampled = next_token.squeeze(dim=-1)
            generated_tokens[:, i] = sampled

            # Feed the same sampled token to both rows of each pair:
            # [t0, t0, t1, t1, ...].
            img_embeds = self.prepare_gen_img_embeds(sampled.repeat_interleave(2))
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        # NOTE(review): the literal 8 is presumably the channel dimension of
        # the VQ code embedding expected by ``decode_code`` -- confirm against
        # the VQ model configuration.
        dec = self.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            [parallel_size, 8, img_size // patch_size, img_size // patch_size],
        )
        # Channels-first -> channels-last, then map [-1, 1] to [0, 255].
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        # Assigning into a pre-shaped uint8 buffer both casts the values and
        # verifies that the decoder returned the expected image shape.
        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        return [Image.fromarray(frame) for frame in visual_img]
# Register every configuration class with AutoConfig under its model-type
# string, then register the model class with AutoModelForCausalLM.
# The registration order is the same as before.
for _model_type, _config_cls in (
    ("vision", VisionConfig),
    ("aligner", AlignerConfig),
    ("gen_vision", GenVisionConfig),
    ("gen_aligner", GenAlignerConfig),
    ("gen_head", GenHeadConfig),
    ("multi_modality", MultiModalityConfig),
):
    AutoConfig.register(_model_type, _config_cls)
# Do not leave the loop helpers behind in the module namespace.
del _model_type, _config_cls

AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)