""" Copy from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/generation/utils.py General: 1. Enable output "past_key_values" from `model.generate(inputs)` Model-Specific: 1. Enable Llama, CodeLlama and Mistral to reuse cache when there is more than one new token. i.e. suppose input_ids.shape == [bsz, n_seq], we allow cache.shape == [bsz, n_cache] where n_cache != n_seq - 1 2. Add a function for Llama and CodeLlama to continuously encode text: given a text "a b c", the function allows encoding the substring 'b c', provided that there exists text before it, rather than assuming 'b' is the first word """ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload from dataclasses import dataclass import torch from torch import nn import torch.distributed as dist import transformers import copy import os import inspect import importlib import warnings from transformers.generation.utils import GenerationConfig, logging, is_deepspeed_zero3_enabled, _ranking_fast, _crop_past_key_values, _split_model_outputs from transformers.generation.utils import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation.utils import DisjunctiveConstraint, PhrasalConstraint from transformers.generation.utils import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria from transformers.generation.utils import ( EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, ForceTokensLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, ) from transformers.generation.utils import ( ModelOutput, CausalLMOutputWithPast, Seq2SeqLMOutput, ) logger = logging.get_logger(__name__) import transformers.models.auto.auto_factory import transformers.models.auto.tokenization_auto from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, TOKENIZER_MAPPING_NAMES, model_type_to_module_name def _get_model_class(config, model_mapping): model_type = model_mapping._reverse_config_mapping[type(config).__name__] if model_type == 'llama': return LlamaForCausalLM if model_type == 'mistral': return MistralForCausalLM supported_models = model_mapping[type(config)] if not isinstance(supported_models, (list, tuple)): return supported_models name_to_model = {model.__name__: model for model in supported_models} architectures = getattr(config, "architectures", []) for arch in architectures: if arch in name_to_model: return name_to_model[arch] elif f"TF{arch}" in name_to_model: return name_to_model[f"TF{arch}"] elif f"Flax{arch}" in name_to_model: return name_to_model[f"Flax{arch}"] # If not architecture is set in the config or match the supported models, the first element of the tuple is the # defaults. return supported_models[0] def tokenizer_class_from_name(class_name: str): if class_name == "LlamaTokenizer": return LlamaTokenizer if class_name == "CodeLlamaTokenizer": return CodeLlamaTokenizer if class_name == "GemmaTokenizer": return GemmaTokenizer if class_name == "PreTrainedTokenizerFast": return transformers.PreTrainedTokenizerFast for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): if class_name in tokenizers: module_name = model_type_to_module_name(module_name) module = importlib.import_module(f".{module_name}", "transformers.models") try: return getattr(module, class_name) except AttributeError: continue for config, tokenizers in TOKENIZER_MAPPING._extra_content.items(): for tokenizer in tokenizers: if getattr(tokenizer, "__name__", None) == class_name: return tokenizer # We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main # init and we return the proper dummy to get an appropriate error message. main_module = importlib.import_module("transformers") if hasattr(main_module, class_name): return getattr(main_module, class_name) return None @dataclass class GreedySearchDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class ContrastiveSearchEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class ContrastiveSearchDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class GreedySearchEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class SampleDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class SampleEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class BeamSearchDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class BeamSearchEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class BeamSampleDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None @dataclass class BeamSampleEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Tuple[Tuple[torch.FloatTensor]] = None GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] class GenerationMixin_Cache(transformers.generation.GenerationMixin): @torch.no_grad() def contrastive_search( self, input_ids: torch.LongTensor, top_k: Optional[int] = 1, penalty_alpha: Optional[float] = 0, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only batch_size = input_ids.shape[0] while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None: # prepare inputs model_kwargs["use_cache"] = True model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save # the `encoder_outputs` outputs = self( **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with # previous tokens) if self.config.is_encoder_decoder: last_hidden_states = outputs.decoder_hidden_states[-1] else: last_hidden_states = outputs.hidden_states[-1] # next logit for contrastive search to select top-k candidate tokens logit_for_next_step = outputs.logits[:, -1, :] model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, ) # Expands model inputs top_k times, for batched forward passes (akin to beam search). _, model_kwargs = self._expand_inputs_for_generation( expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) past_key_values = model_kwargs.get("past_key_values") if past_key_values is None: raise ValueError( f"{self.__class__.__name__} does not support caching and therefore **can't** be used " "for contrastive search." ) elif ( not isinstance(past_key_values[0], (tuple, torch.Tensor)) or past_key_values[0][0].shape[0] != batch_size ): raise ValueError( f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " "used for contrastive search without further modifications." ) # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty logit_for_next_step = logits_processor(input_ids, logit_for_next_step) logit_for_next_step = logits_warper(input_ids, logit_for_next_step) next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (logit_for_next_step,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # Replicates the new past_key_values to match the `top_k` candidates new_key_values = [] for layer in model_kwargs["past_key_values"]: items = [] # item is either the key or the value matrix for item in layer: items.append(item.repeat_interleave(top_k, dim=0)) new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values # compute the candidate tokens by the language model and collects their hidden_states next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] # name is different for encoder-decoder and decoder-only models if self.config.is_encoder_decoder: next_hidden = outputs.decoder_hidden_states[-1] full_hidden_states = outputs.decoder_hidden_states else: next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores # (model confidence minus degeneration penalty); (6) decoder hidden_states next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) next_hidden = next_hidden[range(batch_size), selected_idx, :] last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) next_decoder_hidden_states = () for layer in full_hidden_states: layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] next_decoder_hidden_states += (layer,) # select the past_key_value new_key_values = () for layer in next_past_key_values: items = () # item is either the key or the value matrix for item in layer: item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] items += (item,) new_key_values += (items,) next_past_key_values = new_key_values logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: next_step_cross_attentions = () next_step_decoder_attentions = () if output_attentions: for layer in outputs.cross_attentions: layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] next_step_cross_attentions += (layer,) for layer in outputs.decoder_attentions: layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] next_step_decoder_attentions += (layer,) outputs = Seq2SeqLMOutput( past_key_values=next_past_key_values, decoder_hidden_states=next_decoder_hidden_states, decoder_attentions=next_step_decoder_attentions or None, cross_attentions=next_step_cross_attentions or None, ) else: next_step_attentions = () if output_attentions: for layer in outputs.attentions: layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] next_step_attentions += (layer,) outputs = CausalLMOutputWithPast( past_key_values=next_past_key_values, hidden_states=next_decoder_hidden_states, attentions=next_step_attentions or None, ) # contrastive_search main logic end if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True if this_peer_finished and not synced_gpus: break if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return ContrastiveSearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return ContrastiveSearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return input_ids def greedy_search( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_tokens_scores = logits_processor(input_ids, next_token_logits) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_tokens_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True if this_peer_finished and not synced_gpus: break if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return GreedySearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return input_ids def sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only # auto-regressive generation while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # # finished sentences should have their next token be a padding token # if eos_token_id is not None: # if pad_token_id is None: # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True if this_peer_finished and not synced_gpus: break if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return input_ids def beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return sequence_outputs["sequences"] def beam_sample( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867 next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (logits_warper(input_ids, next_token_scores_processed),) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) next_tokens = torch.gather(next_tokens, -1, _indices) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSampleEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return BeamSampleDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return sequence_outputs["sequences"] def group_beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, **model_kwargs, ): # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups device = input_ids.device batch_beam_size, cur_len = input_ids.shape if return_dict_in_generate and output_scores: beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] else: beam_indices = None if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in # the same group don't produce same tokens everytime. beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need if output_scores: processed_score = torch.zeros_like(outputs.logits[:, -1, :]) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of current group only next_token_logits = outputs.logits[batch_group_indices, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * group_size, vocab_size) vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) if output_scores: processed_score[batch_group_indices] = next_token_scores_processed # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=process_beam_indices, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] if return_dict_in_generate and output_scores: beam_indices[beam_group_idx] = tuple( beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) ) input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (processed_score,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache( model_kwargs["past_key_values"], reordering_indices ) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return sequence_outputs["sequences"] def constrained_beam_search( self, input_ids: torch.LongTensor, constrained_beam_scorer: ConstrainedBeamSearchScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = None, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) batch_size = len(constrained_beam_scorer._beam_hyps) num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) scores_for_all_vocab = next_token_scores.clone() # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = (next_tokens / vocab_size).long() next_tokens = next_tokens % vocab_size # stateless beam_outputs = constrained_beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, scores_for_all_vocab, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) # increase cur_len cur_len = cur_len + 1 if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = constrained_beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return sequence_outputs["sequences"] def assisted_decoding( self, input_ids: torch.LongTensor, assistant_model: "PreTrainedModel", do_sample: bool = False, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ): # Assistant: initialize assistant-related variables if not hasattr(assistant_model, "max_assistant_tokens"): assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if eos_token_id is not None and pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) # other auxiliary variables max_len = stopping_criteria[0].max_length this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # Assistant: main logic start cur_len = input_ids.shape[-1] assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1 # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids for _ in range(int(assistant_model.max_assistant_tokens)): # 1.1. use the assistant model to obtain the next candidate logits if "assistant_past_key_values" in model_kwargs: prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) new_token_len = candidate_input_ids.shape[1] - prev_seq_len assist_inputs = candidate_input_ids[:, -new_token_len:] assist_attn = torch.ones_like(candidate_input_ids) # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 if assistant_model.config.is_encoder_decoder: assistant_model_outputs = assistant_model( decoder_input_ids=assist_inputs, decoder_attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: assistant_model_outputs = assistant_model( assist_inputs, attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], ) else: if assistant_model.config.is_encoder_decoder: assistant_model_outputs = assistant_model( decoder_input_ids=candidate_input_ids, encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: assistant_model_outputs = assistant_model(candidate_input_ids) # 1.2. greedily select the next candidate token model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values if len(logits_processor) > 0: assistant_model_outputs.logits[:, -1, :] = logits_processor( candidate_input_ids, assistant_model_outputs.logits[:, -1, :] ) new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) # 1.3. stop assistant generation on EOS if eos_token_id_tensor is not None: last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) last_assistant_token_is_eos = ( ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() ) if last_assistant_token_is_eos: break else: last_assistant_token_is_eos = False candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, # we use this forward pass to also pick the subsequent logits in the original model. # 2.1. Run a forward pass on the candidate sequence if "past_key_values" in model_kwargs: model_attn = torch.ones_like(candidate_input_ids) model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] if self.config.is_encoder_decoder: outputs = self( decoder_input_ids=model_input_ids, decoder_attention_mask=model_attn, past_key_values=model_kwargs["past_key_values"], encoder_outputs=model_kwargs["encoder_outputs"], output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) else: outputs = self( model_input_ids, attention_mask=model_attn, past_key_values=model_kwargs["past_key_values"], output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) else: if self.config.is_encoder_decoder: outputs = self( decoder_input_ids=candidate_input_ids, encoder_outputs=model_kwargs["encoder_outputs"], output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) else: outputs = self( candidate_input_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # 2.2. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) if len(logits_warper) > 0: for i in range(candidate_length): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) # 3. Obtain the next tokens from the original model logits. if do_sample: probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] else: selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1) # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. candidate_new_tokens = candidate_input_ids[:, -candidate_length:] n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there # is no match. # 5.1. Ensure we don't generate beyond max_len or an EOS token if last_assistant_token_is_eos and n_matches == candidate_length: n_matches -= 1 n_matches = min(n_matches, max_len - cur_len - 1) # 5.2. Get the valid continuation, after the matching tokens valid_tokens = selected_tokens[:, : n_matches + 1] input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) new_cur_len = input_ids.shape[-1] # 5.3. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) model_kwargs["assistant_past_key_values"] = _crop_past_key_values( assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 ) # the assistant does not have the token after the last match, hence the -1 # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. if n_matches == int(assistant_model.max_assistant_tokens): assistant_model.max_assistant_tokens += 2.0 else: assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) # Assistant: main logic end if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need # Store scores, attentions and hidden_states when required # Assistant: modified to append one tuple element per token, as in the other generation methods. if return_dict_in_generate: if output_scores: scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) if "past_key_values" not in model_kwargs: added_len = new_cur_len else: added_len = n_matches + 1 if output_attentions: if self.config.is_encoder_decoder: cross_attentions = _split_model_outputs( cross_attentions, outputs.cross_attentions, cur_len, added_len ) decoder_attentions = _split_model_outputs( decoder_attentions, outputs.decoder_attentions, cur_len, added_len, is_decoder_attention=True, ) else: decoder_attentions = _split_model_outputs( decoder_attentions, outputs.attentions, cur_len, added_len, is_decoder_attention=True, ) if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states = _split_model_outputs( decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len ) else: decoder_hidden_states = _split_model_outputs( decoder_hidden_states, outputs.hidden_states, cur_len, added_len ) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( input_ids[:, -1] .tile(eos_token_id_tensor.shape[0], 1) .ne(eos_token_id_tensor.unsqueeze(1)) .prod(dim=0) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True if this_peer_finished and not synced_gpus: break if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return GreedySearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, ) else: return input_ids class LlamaForCausalLM(GenerationMixin_Cache, transformers.LlamaForCausalLM): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: past_seq_length = past_key_values[0][0].shape[2] input_ids = input_ids[:, past_seq_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, past_seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs class MistralForCausalLM(GenerationMixin_Cache, transformers.MistralForCausalLM): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: past_seq_length = past_key_values[0][0].shape[2] input_ids = input_ids[:, past_seq_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs class PreTrainedTokenizer(transformers.PreTrainedTokenizer): def get_continued_input_ids(self, text, right_padding=False, return_tensors=False): raise NotImplementedError class GemmaTokenizer(PreTrainedTokenizer, transformers.GemmaTokenizer): def add_placeholder(self, text: str): if len(text) == 0 or text[0] != '#': return '####' + text return '$$' + text def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): if isinstance(text, str): text = self.add_placeholder(text) else: text = [self.add_placeholder(x) for x in text] input_ids = self(text, add_special_tokens=False).input_ids if isinstance(text, str): input_ids = input_ids[1:] else: input_ids = [ids[1:] for ids in input_ids] if right_padding and isinstance(text, list): max_length = max([len(ids) for ids in input_ids]) input_ids = [ ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids ] if return_tensors: input_ids = torch.tensor(input_ids) return input_ids class PreTrainedTokenizer(transformers.PreTrainedTokenizer): def get_continued_input_ids(self, text, right_padding=False, return_tensors=False): raise NotImplementedError class LlamaTokenizer(PreTrainedTokenizer, transformers.LlamaTokenizer): def add_placeholder(self, text: str): if len(text) == 0 or text[0] != '#': return '####' + text return '$$' + text def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): if isinstance(text, str): text = self.add_placeholder(text) else: text = [self.add_placeholder(x) for x in text] input_ids = self(text, add_special_tokens=False).input_ids if isinstance(text, str): input_ids = input_ids[1:] else: input_ids = [ids[1:] for ids in input_ids] if right_padding and isinstance(text, list): max_length = max([len(ids) for ids in input_ids]) input_ids = [ ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids ] if return_tensors: input_ids = torch.tensor(input_ids) return input_ids class CodeLlamaTokenizer(PreTrainedTokenizer, transformers.CodeLlamaTokenizer): def add_placeholder(self, text: str): if len(text) == 0 or text[0] != '#': return '####' + text return '$$' + text def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): if isinstance(text, str): text = self.add_placeholder(text) else: text = [self.add_placeholder(x) for x in text] input_ids = self(text, add_special_tokens=False).input_ids if isinstance(text, str): input_ids = input_ids[1:] else: input_ids = [ids[1:] for ids in input_ids] if right_padding and isinstance(text, list): max_length = max([len(ids) for ids in input_ids]) input_ids = [ ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids ] if return_tensors: input_ids = torch.tensor(input_ids) return input_ids def build_transformers_mapping_to_cached_models(): transformers.models.auto.auto_factory._get_model_class = _get_model_class def build_transformers_mapping_to_custom_tokenizers(): transformers.models.auto.tokenization_auto.tokenizer_class_from_name = tokenizer_class_from_name