import itertools
import os
import sys
import time
from typing import Any, Dict, List

import torch
from torch import nn
from omegaconf import DictConfig
from PIL import Image

from torchtune import config, utils
from torchtune.utils._generation import sample
from torchtune.models import convert_weights
from torchtune.data import Message

from models.tokenizer import START_IMAGE, END_IMAGE, START_AUDIO, END_AUDIO, START_VIDEO, END_VIDEO
from imagebind.models.imagebind_model import ModalityType
from diffusers import DiffusionPipeline

from models import add_proj_convert_weights, _BASE_TRAINABLE

log = utils.get_logger("DEBUG")
|
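# Register this repo's extra multimodal projection weights with torchtune's
# checkpoint conversion so they are mapped correctly on load (assumed from the
# helper's name; defined in models/).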
add_proj_convert_weights() |
|
|
class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.

    Currently this recipe supports single-GPU generation only. Speculative
    decoding is not supported.

    For more details on how to use this recipe for generation, please see our
    tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation

    For using this recipe with a quantized model, please see the following
    section of the above tutorial:
    https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
    """
|
    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantizer = config.instantiate(cfg.inference.quantizer)
        self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
        self.prompt_template = cfg.inference.prompt_template
        perception_tokens = cfg.model.perception_tokens
        # Placeholder text ("0 0 ...") standing in for the perception-token
        # positions; extract_mm_context later points these positions at the
        # multimodal embeddings.
        self._perception_tokens = ("0 " * perception_tokens)[:perception_tokens]
        utils.set_seed(seed=cfg.seed)
|
    def setup(self, cfg: DictConfig) -> None:
        checkpointer = config.instantiate(cfg.checkpointer)
        if self._quantization_mode is None:
            ckpt_dict = checkpointer.load_checkpoint()
        else:
            # weights_only needs to be False when loading a quantized model checkpoint.
            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        with self._device:
            self._model.setup_caches(max_batch_size=cfg.batch_size, dtype=self._dtype)

        self._tokenizer = config.instantiate(cfg.tokenizer)
        # Token ids for the modality start/end tags, used to locate the spans
        # that should receive multimodal embeddings.
        self._mm_ids_start = self._tokenizer.encode(
            START_IMAGE + START_AUDIO + START_VIDEO, add_eos=False, add_bos=False
        )
        self._mm_ids_end = self._tokenizer.encode(
            END_IMAGE + END_AUDIO + END_VIDEO, add_eos=False, add_bos=False
        )
        self.use_clip = cfg.model.use_clip
        if self.use_clip:
            self._clip_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=self._dtype
            ).to(self._device)
|
    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        if self._quantization_mode is not None:
            model = self._quantizer.quantize(model)
            model = model.to(device=self._device, dtype=self._dtype)

        model.load_state_dict(model_state_dict)

        # Validate that the model was loaded with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        log.debug(f"Model is initialized with precision {self._dtype}.")

        return model
|
    def mm_process_prompt(self, prompt):
        """Replace {image}/{audio}/{video} placeholders with tagged perception-token spans."""
        return (
            prompt
            .replace("{image}", f"{START_IMAGE}{self._perception_tokens}{END_IMAGE}")
            .replace("{audio}", f"{START_AUDIO}{self._perception_tokens}{END_AUDIO}")
            .replace("{video}", f"{START_VIDEO}{self._perception_tokens}{END_VIDEO}")
        )
|
    def extract_mm_context(self, video_ib_embed, tokens):
        """Map token positions inside start/end tag pairs to their ImageBind embedding."""
        context = {}
        in_mm_embed = False
        for idx, tok in enumerate(tokens):
            # Close the span when an end tag is reached.
            in_mm_embed = in_mm_embed and tok not in self._mm_ids_end
            if in_mm_embed:
                context[idx] = {
                    "ib_embed": video_ib_embed.to(dtype=self._dtype, device=self._device),
                }
            # Open the span when a start tag is reached (the tag itself is excluded).
            in_mm_embed = in_mm_embed or tok in self._mm_ids_start
        return context
|
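    # Single-example generation: attach the multimodal context for the prompt,
    # then delegate decoding to torchtune's utils.generate.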
    @torch.no_grad()
    def generate(self, cfg: DictConfig, video_ib_embed: List[float]):
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(
                role="assistant",
                content="",
            ),
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the trailing EOS/EOT so the assistant turn stays open for generation.
        tokens = tokens[:-2]
        # set_context expects one context per batch element, hence the list.
        mm_context = [self.extract_mm_context(video_ib_embed, tokens)]
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # Llama 3 reserves a block of 256 special-token ids starting at
        # <|begin_of_text|>; allow only <|eot_id|> and the modality tags.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(
            f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}",
            allowed_special="all",
        )
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
|
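        # Decoding step passed to utils.generate: clears the multimodal context
        # (it only applies to the prompt) and forces EOS if a disallowed
        # special token is sampled.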
        def custom_generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            model.tok_embeddings.set_context([])
            model.output.set_context([])
            # x: [1, s]; logits: [1, s, v] where v is the vocab size
            logits = model(x, input_pos=input_pos)
            # Sample from the logits at the final position.
            logits = logits[0, -1]
            token = sample(logits, temperature, top_k)
            if token in disallowed_tokens:
                return torch.tensor([self._tokenizer.eos_id]).to(x)
            return token
|
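        # For quantized models, compile the decoding step and run a short warmup
        # generation so compilation cost is not paid during the real run.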
        if self._quantization_mode is not None:
            log.info("Starting compilation to improve generation performance ...")
            custom_generate_next_token = torch.compile(
                custom_generate_next_token, mode="max-autotune", fullgraph=True
            )
            t0 = time.perf_counter()
            _ = utils.generate(
                model=self._model,
                prompt=prompt,
                max_generated_tokens=2,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                eos_id=self._tokenizer.eos_id,
                custom_generate_next_token=custom_generate_next_token,
            )
            t = time.perf_counter() - t0
            log.info(f"Warmup run for quantized model takes: {t:.02f} sec")
|
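        # Timed full generation pass.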
        t0 = time.perf_counter()
        generated_tokens = utils.generate(
            model=self._model,
            prompt=prompt,
            max_generated_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            eos_id=self._tokenizer.eos_id,
            custom_generate_next_token=custom_generate_next_token,
        )
        t = time.perf_counter() - t0
        log.debug(f"Generation took {t:.02f} sec")

        # Strip the prompt and any remaining special tokens before decoding.
        cleaned_tokens = [
            tok for tok in generated_tokens[len(prompt):]
            if tok not in disallowed_tokens + allowed_id
        ]
        caption = self._tokenizer.decode(cleaned_tokens)

        return caption
|
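    # Batched variant of generate(): prefill runs the full prompts with the
    # multimodal context attached, then an explicit token-by-token decode loop
    # continues with KV caches.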
    @torch.no_grad()
    def generate_batch(self, cfg: DictConfig, video_ib_embed: torch.Tensor):
        log.info(f"inside generate_batch, video_ib_embed shape: {video_ib_embed.shape}")
        batch_dim = video_ib_embed.size(0)
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(role="assistant", content=""),
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the trailing EOS/EOT so the assistant turn stays open for generation.
        tokens = tokens[:-2]
        # One multimodal context per batch element.
        mm_context = [self.extract_mm_context(e, tokens) for e in video_ib_embed]
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device).expand(batch_dim, -1).clone()
        prompt_length = prompt.size(1)

        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # Llama 3 reserves a block of 256 special-token ids starting at
        # <|begin_of_text|>; allow only <|eot_id|> and the modality tags.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(
            f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}",
            allowed_special="all",
        )
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
|
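        # Batched decoding step: sample the last position of every sequence and
        # map any disallowed special token to EOS.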
        def generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            # x: [b, s]; logits: [b, s, v] -> keep the final position per sequence
            logits = model(x, input_pos=input_pos)[:, -1]
            tokens = sample(logits, temperature, top_k)
            return torch.tensor([
                [self._tokenizer.eos_id if t in disallowed_tokens else t for t in toks]
                for toks in tokens
            ]).to(x.device)
|
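        # Start the output buffer from the prompt; generated tokens are appended.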
        generated_tokens = prompt.clone()

        # Prefill over the whole prompt; this yields the first new token per
        # sequence and populates the KV caches.
        tokens = generate_next_token(
            self._model,
            input_pos=torch.arange(0, prompt_length, device=prompt.device),
            x=prompt,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
        )
        eot_reached_b = tokens == self._tokenizer.eot_id
        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)

        # The multimodal context applies only to the prompt; clear it before
        # incremental decoding.
        self._model.tok_embeddings.set_context([])
        self._model.output.set_context([])

        input_pos = torch.tensor([prompt_length], device=prompt.device)
        for _ in range(cfg.max_new_tokens - 1):
            tokens = generate_next_token(
                self._model, input_pos=input_pos, x=tokens, temperature=cfg.temperature, top_k=cfg.top_k
            )
            eot_reached_b |= tokens == self._tokenizer.eot_id
            # Zero out tokens for sequences that have already emitted <|eot_id|>.
            tokens *= ~eot_reached_b
            generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
            if eot_reached_b.all():
                log.debug("EOT reached for all sequences in the batch")
                break
            input_pos += 1
|
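        # Decode only the generated portion (everything after the prompt).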
        captions = []
        for caption_tokens in generated_tokens.tolist():
            captions.append(self._tokenizer.decode(caption_tokens[prompt_length:]))
        return captions
|
|
@config.parse
def main(cfg: DictConfig) -> None:
    config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
    # Hardcoded overrides for a local smoke test; adjust these for your setup.
    cfg.model = DictConfig({
        "_component_": "models.mmllama3_8b",
        "use_clip": False,
        "perception_tokens": cfg.model.perception_tokens,
    })
    cfg.batch_size = 4
    cfg.checkpointer.checkpoint_dir = os.path.dirname(
        "/home/salman/tezuesh/omegalabs-anytoany-bittensor/sandboxing/cache/xzistance_omega-a2a-hotkey/meta_model_0.pth"
    )
    cfg.checkpointer.checkpoint_files = ["models/meta_model_0.pt"]
    cfg.inference.max_new_tokens = 300
    cfg.tokenizer.path = "./models/tokenizer.model"
    inference_recipe = InferenceRecipe(cfg)
    inference_recipe.setup(cfg=cfg)
    # Smoke test with random ImageBind-sized embeddings (batch of 4, dim 1024).
    captions = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=torch.randn(4, 1024))
    print(captions)
|
|
if __name__ == "__main__":
    sys.exit(main())