from typing import List
import warnings

import torch
from torch import nn, Tensor
from torchvision import transforms

from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
from torchtune.modules import TransformerDecoder

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    from imagebind.models import imagebind_model
    from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
    from models.imagebind_wrapper import ImageBind

IMAGEBIND_DIM = 1024
CLIP_DIM = 768


class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected ImageBind (and optionally CLIP)
    embeddings into the token-embedding output at the positions supplied via
    ``set_context``."""

    def __init__(self, e, perception_tokens=1, use_clip=False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        self._context = []
        self._use_clip = use_clip
        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        dim_out = e.embedding_dim * perception_tokens
        # MLP that maps one multimodal embedding to `perception_tokens`
        # llama3-sized token embeddings.
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        r = super().forward(input)
        # self._context is first indexed by batch idx
        for b, context_dict in enumerate(self._context):
            # then by sequence idx
            for s, embed in context_dict.items():
                # and then must be transformed from imagebind dim -> llama3 dim
                if self._use_clip:
                    llama_embed = self.proj_to_llama(
                        torch.cat([embed["ib_embed"], embed["clip_embed"]])
                    )
                else:
                    llama_embed = self.proj_to_llama(torch.cat([embed["ib_embed"]]))
                r[b, s:s + self._perception_tokens] = llama_embed.view(
                    self._perception_tokens, -1
                )
        return r


class MMLinear(nn.Linear):
    """Output head that currently behaves like the wrapped ``nn.Linear``; the
    commented-out block sketches a projection from llama3 hidden states back to
    CLIP space for the positions recorded via ``set_context``."""

    def __init__(self, o):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            bias=(o.bias is not None),
        )
        self._context = []
        dim_out = CLIP_DIM
        dim_in = o.in_features
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        # self._context has the indexes of image llama tokens: process these with proj_from_llama
        self._clip_projections = []
        # # self._context is first indexed by batch idx
        # for b, context_dict in enumerate(self._context):
        #     # then by sequence idx
        #     for s, embed in context_dict.items():
        #         # and then must be transformed from llama3 dim -> clip dim
        #         self._clip_projections.append((
        #             self.proj_from_llama(input_bsd[b, s]),
        #             (embed["clip_embed"] if "clip_embed" in embed else None)  # terrible
        #         ))
        r = super().forward(input_bsd)
        return r


def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False,
) -> TransformerDecoder:
    llama3 = lora_llama3_8b(
        lora_attn_modules,
        apply_lora_to_mlp,
        apply_lora_to_output,
        lora_rank,
        lora_alpha,
        quantize_base,
    )
    # Swap in the multimodal embedding and output layers.
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3


def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False,
) -> TransformerDecoder:
    llama3 = llama3_8b()
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3


def imagebind_huge(use_v2: bool = True):
    if use_v2:
        imagebind = ImageBind(v2=True)
    else:
        imagebind = imagebind_model.imagebind_huge(pretrained=True)
    # CLIP-style preprocessing so callers can go straight from a PIL image to model input.
    imagebind.transform_from_pil = transforms.Compose([
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    return imagebind
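

# Illustrative usage sketch, not part of the original module: shows how a caller
# might wire precomputed ImageBind/CLIP embeddings into the model via set_context().
# The batch index, sequence index, and random tensors below are placeholder
# assumptions; a real pipeline would compute them from actual images and prompts.
if __name__ == "__main__":
    model = lora_mmllama3_8b(
        lora_attn_modules=["q_proj", "v_proj"],
        perception_tokens=2,
        use_clip=True,
    )
    # One dict per batch element, keyed by the sequence position where the
    # perception tokens start; values hold the precomputed embeddings.
    context = [
        {
            5: {
                "ib_embed": torch.randn(IMAGEBIND_DIM),
                "clip_embed": torch.randn(CLIP_DIM),
            }
        }
    ]
    model.tok_embeddings.set_context(context)
    model.output.set_context(context)
    tokens = torch.randint(0, 128_000, (1, 32))  # dummy token ids, batch of 1
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)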