# auto-info / PSV / utils / cached_models.py
# rookiemango's picture
# Upload folder using huggingface_hub
# da66274 verified
"""
Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/generation/utils.py
General:
1. Enable outputting "past_key_values" from `model.generate(inputs)`
Model-Specific:
1. Enable Llama, CodeLlama and Mistral to reuse the cache when there is more than one new token, i.e. given input_ids.shape == [bsz, n_seq], allow cache.shape == [bsz, n_cache] where n_cache != n_seq - 1
2. Add a function for Llama and CodeLlama to encode text continuously: given the text "a b c", the function can encode the substring 'b c' knowing that text precedes it, rather than assuming 'b' is the first word
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
from dataclasses import dataclass
import torch
from torch import nn
import torch.distributed as dist
import transformers
import copy
import os
import inspect
import importlib
import warnings
from transformers.generation.utils import GenerationConfig, logging, is_deepspeed_zero3_enabled, _ranking_fast, _crop_past_key_values, _split_model_outputs
from transformers.generation.utils import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.utils import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.utils import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria
from transformers.generation.utils import (
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from transformers.generation.utils import (
ModelOutput,
CausalLMOutputWithPast,
Seq2SeqLMOutput,
)
logger = logging.get_logger(__name__)
import transformers.models.auto.auto_factory
import transformers.models.auto.tokenization_auto
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, TOKENIZER_MAPPING_NAMES, model_type_to_module_name
def _get_model_class(config, model_mapping):
    """Resolve the model class to instantiate for ``config``.

    Configs whose model type is ``llama`` or ``mistral`` are routed to the
    module-level ``LlamaForCausalLM`` / ``MistralForCausalLM`` classes (defined
    outside this function). Every other model type uses the upstream resolution
    logic: take the class(es) registered for ``type(config)`` and, if several are
    registered, prefer the one named by ``config.architectures`` (also trying the
    ``TF``/``Flax`` prefixed names), falling back to the first registered class.

    Args:
        config: A model configuration instance. Its type name is looked up in
            ``model_mapping._reverse_config_mapping``.
        model_mapping: An auto-model mapping that supports
            ``model_mapping[type(config)]`` and exposes ``_reverse_config_mapping``.

    Returns:
        The model class to use.

    Raises:
        KeyError: presumably, when ``type(config).__name__`` is missing from the
            reverse config mapping (assuming it is a plain dict).
    """
    model_type = model_mapping._reverse_config_mapping[type(config).__name__]
    if model_type == "llama":
        return LlamaForCausalLM
    if model_type == "mistral":
        return MistralForCausalLM

    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    name_to_model = {model.__name__: model for model in supported_models}
    # `architectures` can be missing *or* explicitly None on a config; iterating None
    # would raise TypeError, so treat both cases as "no architecture given".
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        # Same priority as upstream: exact name, then TF-, then Flax-prefixed variant.
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]
    # If no architecture is set in the config, or none matches the supported models,
    # the first element of the tuple is the default.
    return supported_models[0]
def tokenizer_class_from_name(class_name: str):
    """Look up a tokenizer class by its class name.

    ``LlamaTokenizer``, ``CodeLlamaTokenizer`` and ``GemmaTokenizer`` resolve to the
    module-level classes of those names. ``PreTrainedTokenizerFast`` comes straight
    from ``transformers``. Any other name is searched, in order:

    1. In the built-in ``TOKENIZER_MAPPING_NAMES`` registry, importing the matching
       ``transformers.models`` submodule.
    2. Among tokenizers registered at runtime in ``TOKENIZER_MAPPING._extra_content``.
    3. In the top-level ``transformers`` namespace.

    Returns:
        The tokenizer class, or ``None`` when no match is found.
    """
    # Checked one at a time so that only the matching override name is ever evaluated.
    if class_name == "LlamaTokenizer":
        return LlamaTokenizer
    elif class_name == "CodeLlamaTokenizer":
        return CodeLlamaTokenizer
    elif class_name == "GemmaTokenizer":
        return GemmaTokenizer
    elif class_name == "PreTrainedTokenizerFast":
        return transformers.PreTrainedTokenizerFast

    # Built-in registry: each key is a model type, each value the tokenizer names it ships.
    for model_type, candidate_names in TOKENIZER_MAPPING_NAMES.items():
        if class_name not in candidate_names:
            continue
        package = importlib.import_module(
            f".{model_type_to_module_name(model_type)}", "transformers.models"
        )
        try:
            return getattr(package, class_name)
        except AttributeError:
            continue

    # Tokenizers registered at runtime; the config keys are not needed for this lookup.
    for registered in TOKENIZER_MAPPING._extra_content.values():
        for tokenizer in registered:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    # We did not find the class, but maybe it's because a dependency is missing. In that
    # case the class lives in the main init, and returning that dummy gives a proper error.
    root = importlib.import_module("transformers")
    if hasattr(root, class_name):
        return getattr(root, class_name)
    return None
@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """Output of ``greedy_search`` for decoder-only models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Cache left in model_kwargs after generation; greedy_search sets it to None when use_cache is off.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
    """Output of ``contrastive_search`` for encoder-decoder models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Cache left in model_kwargs after generation; None when use_cache is off.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
    """Output of ``contrastive_search`` for decoder-only models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Cache left in model_kwargs after generation; None when use_cache is off.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
    """Output of ``greedy_search`` for encoder-decoder models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Cache left in model_kwargs after generation; greedy_search sets it to None when use_cache is off.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class SampleDecoderOnlyOutput(ModelOutput):
    """Sampling output for decoder-only models (part of ``sample``'s ``SampleOutput`` return type).

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """Sampling output for encoder-decoder models (part of ``sample``'s ``SampleOutput`` return type).

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field so that generation can hand its cache back to the caller.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSearchDecoderOnlyOutput(ModelOutput):
    """Beam-search output for decoder-only models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field. The method that builds this object is outside this view.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """Beam-search output for encoder-decoder models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field. The method that builds this object is outside this view.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSampleDecoderOnlyOutput(ModelOutput):
    """Beam-sample output for decoder-only models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field. The method that builds this object is outside this view.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class BeamSampleEncoderDecoderOutput(ModelOutput):
    """Beam-sample output for encoder-decoder models.

    Mirrors the upstream ``transformers`` class of the same name and adds a
    ``past_key_values`` field. The method that builds this object is outside this view.
    """

    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Generation cache handed back to the caller (added by this module).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Each alias groups the encoder-decoder and decoder-only variant of one output class;
# the generation methods use them in their return annotations (e.g. `greedy_search`).
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
# Union of every generation output type above.
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
class GenerationMixin_Cache(transformers.generation.GenerationMixin):
@torch.no_grad()
def contrastive_search(
    self,
    input_ids: torch.LongTensor,
    top_k: Optional[int] = 1,
    penalty_alpha: Optional[float] = 0,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
    """Generate tokens with contrastive search.

    At every step the next-token logits go through ``logits_processor`` and
    ``logits_warper``. The ``top_k`` most probable candidates are then run through the
    model in one batched forward pass. ``_ranking_fast`` re-ranks them from their hidden
    states, the accumulated context hidden states and ``penalty_alpha``.

    This differs from the upstream ``transformers`` implementation it was copied from in
    one way: the returned output object also carries ``past_key_values``.

    Args:
        input_ids: Prompt token ids; ``input_ids.shape[0]`` is used as the batch size.
        top_k: Number of candidate tokens evaluated per step.
        penalty_alpha: Degeneration-penalty weight passed to ``_ranking_fast``.
        logits_processor, logits_warper, stopping_criteria: Default to empty lists.
        pad_token_id, eos_token_id: Default to ``self.generation_config``. An int
            ``eos_token_id`` is wrapped in a list. Finished sequences are filled with
            ``pad_token_id``.
        output_attentions, output_hidden_states, output_scores, return_dict_in_generate:
            Default to ``self.generation_config``. Per-step values are only collected
            when ``return_dict_in_generate`` is true.
        synced_gpus: Keep running forward passes until every peer has finished,
            coordinated through ``dist.all_reduce``.
        streamer: If given, receives each step's tokens via ``put`` and ``end()`` once
            generation stops.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation`` and updated after
            every step. ``use_cache`` is forced to True on the first step.

    Returns:
        ``ContrastiveSearchEncoderDecoderOutput`` / ``ContrastiveSearchDecoderOnlyOutput``
        when ``return_dict_in_generate`` is true; otherwise ``input_ids`` with the
        generated tokens appended.

    Raises:
        ValueError: If any of the following holds:
            - the model returns no cache;
            - the cache's first entry is not a tuple/tensor;
            - the cache's leading dimension differs from the batch size;
            - ``eos_token_id`` is set without ``pad_token_id``.

    NOTE(review): The first-step branch only runs when ``model_kwargs`` has no
    ``past_key_values``. If a caller passes a pre-filled cache, ``last_hidden_states``,
    ``logit_for_next_step`` and ``outputs`` are never assigned, and the first iteration
    fails with ``UnboundLocalError``. Confirm callers never do this.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )
    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )
    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only
    batch_size = input_ids.shape[0]
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break
        # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
        # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
        if model_kwargs.get("past_key_values") is None:
            # prepare inputs
            model_kwargs["use_cache"] = True
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
            # the `encoder_outputs`
            outputs = self(
                **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
            )
            # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
            # previous tokens)
            if self.config.is_encoder_decoder:
                last_hidden_states = outputs.decoder_hidden_states[-1]
            else:
                last_hidden_states = outputs.hidden_states[-1]
            # next logit for contrastive search to select top-k candidate tokens
            logit_for_next_step = outputs.logits[:, -1, :]
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                standardize_cache_format=True,
            )
            # Expands model inputs top_k times, for batched forward passes (akin to beam search).
            _, model_kwargs = self._expand_inputs_for_generation(
                expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
            # The cache is validated once here; the rest of the loop indexes it as
            # per-layer (key, value) entries with the batch on dim 0.
            past_key_values = model_kwargs.get("past_key_values")
            if past_key_values is None:
                raise ValueError(
                    f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
                    "for contrastive search."
                )
            elif (
                not isinstance(past_key_values[0], (tuple, torch.Tensor))
                or past_key_values[0][0].shape[0] != batch_size
            ):
                raise ValueError(
                    f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
                    "used for contrastive search without further modifications."
                )
        # contrastive_search main logic start:
        # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
        # degeneration penalty
        logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
        logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
        next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
        top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (logit_for_next_step,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )
        # Replicates the new past_key_values to match the `top_k` candidates
        new_key_values = []
        for layer in model_kwargs["past_key_values"]:
            items = []
            # item is either the key or the value matrix
            for item in layer:
                items.append(item.repeat_interleave(top_k, dim=0))
            new_key_values.append(items)
        model_kwargs["past_key_values"] = new_key_values
        # compute the candidate tokens by the language model and collects their hidden_states
        # (all batch_size * top_k candidates are fed as one new token each)
        next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
        outputs = self(
            **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
        )
        next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
        logits = outputs.logits[:, -1, :]
        # name is different for encoder-decoder and decoder-only models
        if self.config.is_encoder_decoder:
            next_hidden = outputs.decoder_hidden_states[-1]
            full_hidden_states = outputs.decoder_hidden_states
        else:
            next_hidden = outputs.hidden_states[-1]
            full_hidden_states = outputs.hidden_states
        context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
        # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
        # model confidence
        selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
        # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
        # (model confidence minus degeneration penalty); (6) decoder hidden_states
        next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
        # split the flat batch_size * top_k rows back into [batch, top_k] and keep the winner per batch row
        next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
        next_hidden = next_hidden[range(batch_size), selected_idx, :]
        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
        next_decoder_hidden_states = ()
        for layer in full_hidden_states:
            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
            next_decoder_hidden_states += (layer,)
        # select the past_key_value
        new_key_values = ()
        for layer in next_past_key_values:
            items = ()
            # item is either the key or the value matrix
            for item in layer:
                item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
                item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
                items += (item,)
            new_key_values += (items,)
        next_past_key_values = new_key_values
        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
        # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
        if self.config.is_encoder_decoder:
            next_step_cross_attentions = ()
            next_step_decoder_attentions = ()
            if output_attentions:
                for layer in outputs.cross_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_cross_attentions += (layer,)
                for layer in outputs.decoder_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_decoder_attentions += (layer,)
            outputs = Seq2SeqLMOutput(
                past_key_values=next_past_key_values,
                decoder_hidden_states=next_decoder_hidden_states,
                decoder_attentions=next_step_decoder_attentions or None,
                cross_attentions=next_step_cross_attentions or None,
            )
        else:
            next_step_attentions = ()
            if output_attentions:
                for layer in outputs.attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_attentions += (layer,)
            outputs = CausalLMOutputWithPast(
                past_key_values=next_past_key_values,
                hidden_states=next_decoder_hidden_states,
                attentions=next_step_attentions or None,
            )
        # contrastive_search main logic end
        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need
        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True
        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True
        if this_peer_finished and not synced_gpus:
            break
    if streamer is not None:
        streamer.end()
    if return_dict_in_generate:
        # past_key_values is the extra field this module adds to the upstream outputs.
        if self.config.is_encoder_decoder:
            return ContrastiveSearchEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None,
            )
        else:
            return ContrastiveSearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None,
            )
    else:
        return input_ids
def greedy_search(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
    """Generate tokens by greedy decoding: take the argmax of the processed logits at each step.

    This differs from the upstream ``transformers`` implementation it was copied from in
    one way: the returned output object also carries ``past_key_values``, so the cache
    can be reused by a later call.

    Args:
        input_ids: Prompt token ids; ``input_ids.shape[0]`` is used as the batch size.
        logits_processor: Applied to the next-token logits before the argmax.
            Defaults to an empty list.
        stopping_criteria: Checked after every step. Defaults to an empty list.
        max_length: Deprecated. If given, a warning is emitted and it is folded into
            ``stopping_criteria`` via ``validate_stopping_criteria``.
        pad_token_id, eos_token_id: Default to ``self.generation_config``. An int
            ``eos_token_id`` is wrapped in a list. Finished sequences are filled with
            ``pad_token_id``.
        output_attentions, output_hidden_states, output_scores, return_dict_in_generate:
            Default to ``self.generation_config``. Per-step values are only collected
            when ``return_dict_in_generate`` is true.
        synced_gpus: Keep running forward passes until every peer has finished,
            coordinated through ``dist.all_reduce``.
        streamer: If given, receives each step's tokens via ``put`` and ``end()`` once
            generation stops.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation`` and updated after
            every step with ``_update_model_kwargs_for_generation``.

    Returns:
        ``GreedySearchEncoderDecoderOutput`` / ``GreedySearchDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true; otherwise ``input_ids`` with the generated
        tokens appended.

    Raises:
        ValueError: If ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_tokens_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_tokens_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # argmax
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    # `use_cache` may be absent from model_kwargs when this method is called directly;
    # indexing it would raise KeyError. Treat "not specified" as "return whatever cache
    # is stored in model_kwargs", and return None only when caching was explicitly disabled.
    past_key_values = model_kwargs.get("past_key_values") if model_kwargs.get("use_cache", True) else None
    if self.config.is_encoder_decoder:
        return GreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return GreedySearchDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Generate tokens by multinomial sampling, optionally returning the KV cache.

    Adapted from transformers v4.30 ``GenerationMixin.sample``. Differences visible here:

    * With ``return_dict_in_generate`` the output also carries ``past_key_values`` (the cache
      left in ``model_kwargs`` after the last step), unless ``use_cache`` was explicitly False.
    * Rows that already produced an EOS token are NOT overwritten with ``pad_token_id``.
      They keep sampling until every row is finished or a stopping criterion fires, so
      callers must cut each sequence at its first EOS themselves.

    Args:
        input_ids: prompt token ids; each step appends one token along the last dim.
        logits_processor: applied to the last-position logits first.
        stopping_criteria: evaluated after every step.
        logits_warper: applied after ``logits_processor``, before the softmax.
        max_length: deprecated; folded into ``stopping_criteria`` with a warning.
        pad_token_id: unused (finished rows are not padded); kept for signature compatibility.
        eos_token_id: a single id or a list of ids that mark a row as finished.
        output_attentions, output_hidden_states, output_scores: per-step values to collect;
            only honoured when ``return_dict_in_generate`` is true.
        return_dict_in_generate: return a ``Sample*Output`` instead of the bare ids.
        synced_gpus: keep calling ``forward`` until all ranks report completion.
        streamer: gets ``put`` with every new token batch and ``end`` once at the end.
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``; updated every step.

    Returns:
        ``SampleEncoderDecoderOutput`` / ``SampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true, otherwise the extended ``input_ids``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished (1 = still running, 0 = hit EOS)
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # NOTE: unlike upstream (and the greedy path in this file), finished rows are NOT
        # replaced with `pad_token_id` here; tokens sampled after EOS are kept as-is.

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            # a row stays unfinished only if its new token differs from every EOS id
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    # Expose the cache from the last step. `use_cache` is normally injected by `generate`;
    # when this method is called directly without it, return whatever cache exists instead of
    # raising KeyError (the cache is None anyway if the model did not produce one).
    past_key_values = model_kwargs.get("past_key_values") if model_kwargs.get("use_cache", True) else None
    if self.config.is_encoder_decoder:
        return SampleEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return SampleDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
def beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    """Generate with (deterministic, top-k) beam search, optionally returning the KV cache.

    Adapted from transformers v4.30 ``GenerationMixin.beam_search``. With
    ``return_dict_in_generate`` the output also carries ``past_key_values`` (the cache left in
    ``model_kwargs`` after the last step) unless ``use_cache`` was explicitly False.
    That cache is reordered with ``beam_idx`` every step, so its rows follow the running beams
    of ``input_ids``. They need not match the order of the finalized ``sequences``.

    Args:
        input_ids: ids of shape ``(batch_size * num_beams, seq_len)``; validated below.
        beam_scorer: tracks hypotheses; provides ``num_beams``, ``process`` and ``finalize``.
        logits_processor: applied to the log-softmaxed last-position scores.
        stopping_criteria: evaluated after every step (warns if empty).
        max_length: deprecated; folded into ``stopping_criteria`` with a warning.
        pad_token_id, eos_token_id: forwarded to ``beam_scorer``.
        output_attentions, output_hidden_states, output_scores: per-step values to collect;
            only honoured when ``return_dict_in_generate`` is true.
        return_dict_in_generate: return a ``BeamSearch*Output`` instead of bare sequences.
        synced_gpus: keep calling ``forward`` until all ranks report completion.
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``; updated every step.

    Returns:
        ``BeamSearchEncoderDecoderOutput`` / ``BeamSearchDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true, otherwise the finalized sequences.

    Raises:
        ValueError: if the batch dimension of ``input_ids`` is not ``num_beams * batch_size``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        # add each beam's running log-prob so candidates are ranked by total sequence score
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search: rank candidates across all beams of the same batch item
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # split flat indices back into (source beam, token id)
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
        )

        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # keep the cache rows aligned with the reordered beams in input_ids
        if model_kwargs.get("past_key_values") is not None:
            model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
    )

    if not return_dict_in_generate:
        return sequence_outputs["sequences"]

    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    # `use_cache` is normally injected by `generate`; avoid a KeyError on direct calls.
    past_key_values = model_kwargs.get("past_key_values") if model_kwargs.get("use_cache", True) else None
    if self.config.is_encoder_decoder:
        return BeamSearchEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return BeamSearchDecoderOnlyOutput(
        sequences=sequence_outputs["sequences"],
        sequences_scores=sequence_outputs["sequence_scores"],
        scores=scores,
        beam_indices=sequence_outputs["beam_indices"],
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
def beam_sample(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    **model_kwargs,
) -> Union[BeamSampleOutput, torch.LongTensor]:
    """Generate with beam search whose candidates are drawn by multinomial sampling.

    Adapted from transformers v4.30 ``GenerationMixin.beam_sample``. With
    ``return_dict_in_generate`` the output also carries ``past_key_values`` (the cache left in
    ``model_kwargs`` after the last step, reordered with ``beam_idx`` each step) unless
    ``use_cache`` was explicitly False.

    Args:
        input_ids: ids of shape ``(batch_size * num_beams, seq_len)``; validated below.
        beam_scorer: tracks hypotheses; provides ``num_beams``, ``process`` and ``finalize``.
        logits_processor: applied to the log-softmaxed last-position scores.
        stopping_criteria: evaluated after every step.
        logits_warper: applied after the running beam scores are added (see note in body).
        max_length: deprecated; folded into ``stopping_criteria`` with a warning.
        pad_token_id, eos_token_id: forwarded to ``beam_scorer``.
        output_attentions, output_hidden_states, output_scores: per-step values to collect;
            only honoured when ``return_dict_in_generate`` is true.
        return_dict_in_generate: return a ``BeamSample*Output`` instead of bare sequences.
        synced_gpus: keep calling ``forward`` until all ranks report completion.
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``; updated every step.

    Returns:
        ``BeamSampleEncoderDecoderOutput`` / ``BeamSampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true, otherwise the finalized sequences.

    Raises:
        ValueError: if the batch dimension of ``input_ids`` is not ``num_beams * batch_size``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # Same validation as beam_search / group_beam_search: a mismatch would otherwise only
    # surface later as an opaque `view` shape error.
    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # unlike beam_search, every beam starts with a running score of 0
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
        # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
        # (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
        # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (logits_warper(input_ids, next_token_scores_processed),)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search: sample candidates across all beams of the same batch item
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        probs = nn.functional.softmax(next_token_scores, dim=-1)

        # draw 2 * num_beams candidates, then sort them by score (descending) for the scorer
        next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, _indices)

        # split flat indices back into (source beam, token id)
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # keep the cache rows aligned with the reordered beams in input_ids
        if model_kwargs.get("past_key_values") is not None:
            model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
    )

    if not return_dict_in_generate:
        return sequence_outputs["sequences"]

    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    # `use_cache` is normally injected by `generate`; avoid a KeyError on direct calls.
    past_key_values = model_kwargs.get("past_key_values") if model_kwargs.get("use_cache", True) else None
    if self.config.is_encoder_decoder:
        return BeamSampleEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return BeamSampleDecoderOnlyOutput(
        sequences=sequence_outputs["sequences"],
        sequences_scores=sequence_outputs["sequence_scores"],
        scores=scores,
        beam_indices=sequence_outputs["beam_indices"],
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
def group_beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    **model_kwargs,
):
    """Generate with group (diverse) beam search, optionally returning the KV cache.

    Adapted from transformers v4.30 ``GenerationMixin.group_beam_search``. The ``num_beams``
    beams are split into ``num_beam_groups`` groups of ``num_beams // num_beam_groups`` beams.
    All beams share one forward pass per step, and the groups are then extended one after
    another. ``logits_processor`` receives ``current_tokens`` and ``beam_group_idx``, presumably
    so a diversity processor can penalize tokens already picked by earlier groups.

    With ``return_dict_in_generate`` the output also carries ``past_key_values`` (the cache
    left in ``model_kwargs`` after the last step, reordered with ``reordering_indices``)
    unless ``use_cache`` was explicitly False.

    Args:
        input_ids: ids of shape ``(batch_size * num_beams, seq_len)``; validated below.
        beam_scorer: provides ``num_beams``, ``num_beam_groups``, ``process`` and ``finalize``.
        logits_processor: applied per group to the log-softmaxed last-position scores.
        stopping_criteria: evaluated after every step.
        max_length: deprecated; folded into ``stopping_criteria`` with a warning.
        pad_token_id, eos_token_id: forwarded to ``beam_scorer``.
        output_attentions, output_hidden_states, output_scores: per-step values to collect;
            only stored when ``return_dict_in_generate`` is true.
        return_dict_in_generate: return a ``BeamSearch*Output`` instead of bare sequences.
        synced_gpus: keep calling ``forward`` until all ranks report completion.
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``; updated every step.

    Returns:
        ``BeamSearchEncoderDecoderOutput`` / ``BeamSearchDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true, otherwise the finalized sequences.

    Raises:
        ValueError: if the batch dimension of ``input_ids`` is not ``num_beams * batch_size``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups
    device = input_ids.device

    batch_beam_size, cur_len = input_ids.shape

    # one beam-history tuple list per group (only tracked when scores are returned)
    if return_dict_in_generate and output_scores:
        beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
    else:
        beam_indices = None

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
    # the same group don't produce same tokens every time.
    beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
    beam_scores[:, ::num_sub_beams] = 0
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # predicted tokens in cur_len step
        current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

        # indices which will form the beams in the next time step
        reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

        # do one decoder step on all beams of all sentences in batch
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        if output_scores:
            processed_score = torch.zeros_like(outputs.logits[:, -1, :])

        for beam_group_idx in range(num_beam_groups):
            group_start_idx = beam_group_idx * num_sub_beams
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            group_size = group_end_idx - group_start_idx

            # indices of beams of current group among all sentences in batch
            batch_group_indices = []

            for batch_idx in range(batch_size):
                batch_group_indices.extend(
                    [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                )
            group_input_ids = input_ids[batch_group_indices]

            # select outputs of beams of current group only
            next_token_logits = outputs.logits[batch_group_indices, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * group_size, vocab_size)
            vocab_size = next_token_scores.shape[-1]

            next_token_scores_processed = logits_processor(
                group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
            )
            next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
            next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

            if output_scores:
                processed_score[batch_group_indices] = next_token_scores_processed

            # reshape for beam search
            next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
            )

            # split flat indices back into (beam within group, token id)
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=process_beam_indices,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            if return_dict_in_generate and output_scores:
                beam_indices[beam_group_idx] = tuple(
                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                )

            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            current_tokens[batch_group_indices] = group_input_ids[:, -1]

            # (beam_idx // group_size) -> batch_idx
            # (beam_idx % group_size) -> offset of idx inside the group
            reordering_indices[batch_group_indices] = (
                num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                + group_start_idx
                + (beam_idx % group_size)
            )

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (processed_score,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # keep the cache rows aligned with the reordered beams of every group
        if model_kwargs.get("past_key_values") is not None:
            model_kwargs["past_key_values"] = self._reorder_cache(
                model_kwargs["past_key_values"], reordering_indices
            )

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=final_beam_indices,
    )

    if not return_dict_in_generate:
        return sequence_outputs["sequences"]

    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    # `use_cache` is normally injected by `generate`; avoid a KeyError on direct calls.
    past_key_values = model_kwargs.get("past_key_values") if model_kwargs.get("use_cache", True) else None
    if self.config.is_encoder_decoder:
        return BeamSearchEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=past_key_values,
        )
    return BeamSearchDecoderOnlyOutput(
        sequences=sequence_outputs["sequences"],
        sequences_scores=sequence_outputs["sequence_scores"],
        scores=scores,
        beam_indices=sequence_outputs["beam_indices"],
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=past_key_values,
    )
def constrained_beam_search(
    self,
    input_ids: torch.LongTensor,
    constrained_beam_scorer: ConstrainedBeamSearchScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    """Run constrained beam search, optionally returning the model cache alongside the outputs.

    Follows ``transformers`` v4.30 ``GenerationMixin.constrained_beam_search``. The difference is that the
    returned ``BeamSearch*Output`` also carries ``past_key_values`` when ``model_kwargs["use_cache"]`` is
    truthy.

    Args:
        input_ids: Prompt token ids; must have ``batch_size * num_beams`` rows.
        constrained_beam_scorer: Scorer holding the beam hypotheses and constraints.
        logits_processor: Applied to the log-softmaxed next-token scores at every step.
        stopping_criteria: Checked after every step, together with ``constrained_beam_scorer.is_done``.
        max_length: Deprecated. It is folded into ``stopping_criteria`` and emits a ``UserWarning``.
        pad_token_id, eos_token_id: Default to ``self.generation_config`` values. An int eos id is wrapped in a list.
        output_attentions, output_hidden_states, output_scores, return_dict_in_generate:
            Default to the matching ``self.generation_config`` values.
        synced_gpus: Under ``torch.distributed``, keep running forward passes until every rank has finished.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation`` and updated after every step.

    Returns:
        ``BeamSearchEncoderDecoderOutput`` / ``BeamSearchDecoderOnlyOutput`` when ``return_dict_in_generate``
        is set, otherwise the finalized ``sequences`` tensor.
        NOTE(review): the returned ``past_key_values`` is the cache after the last per-step beam reordering.
        It follows the live beams, not the hypotheses chosen by ``finalize``. Confirm this is what callers expect.

    Raises:
        ValueError: if the batch dimension of ``input_ids`` is not ``batch_size * num_beams``.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape
    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # The constrained scorer needs the full-vocabulary scores, before the top-k reduction below.
        scores_for_all_vocab = next_token_scores.clone()

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # Exact integer floor division. The former `(next_tokens / vocab_size).long()` went through float32,
        # which cannot represent every index once num_beams * vocab_size exceeds 2**24.
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # Keep the cache rows aligned with the reordered beams in `input_ids`.
        if model_kwargs["past_key_values"] is not None:
            model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

        # increase cur_len
        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if not return_dict_in_generate:
        return sequence_outputs["sequences"]

    if not output_scores:
        sequence_outputs["sequence_scores"] = None
    # `.get` so that calling this method directly without `use_cache` in model_kwargs does not raise KeyError.
    returned_cache = model_kwargs.get("past_key_values", None) if model_kwargs.get("use_cache") else None
    if self.config.is_encoder_decoder:
        return BeamSearchEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=returned_cache,
        )
    return BeamSearchDecoderOnlyOutput(
        sequences=sequence_outputs["sequences"],
        sequences_scores=sequence_outputs["sequence_scores"],
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=returned_cache,
    )
def assisted_decoding(
    self,
    input_ids: torch.LongTensor,
    assistant_model: "PreTrainedModel",
    do_sample: bool = False,
    logits_processor: Optional[LogitsProcessorList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
):
    """Assisted decoding: ``assistant_model`` drafts tokens that this model then verifies in one forward pass.

    Follows ``transformers`` v4.30 ``GenerationMixin.assisted_decoding``. The output object also carries
    ``past_key_values`` when ``model_kwargs["use_cache"]`` is truthy.

    Each loop iteration does the following:
      1. The assistant greedily drafts up to ``int(assistant_model.max_assistant_tokens)`` tokens. It stops
         early on EOS.
      2. This model scores the whole draft in one forward pass.
      3. The longest draft prefix that agrees with this model's choices is accepted, plus one token picked
         by this model.
      4. Both caches are cropped to the accepted length.
      5. ``max_assistant_tokens`` is increased by 2 when every drafted token was accepted. Otherwise it is
         decreased by 1, never below 1. The value is stored on ``assistant_model`` and persists across calls.

    Args:
        input_ids: Prompt token ids. ``probs[0]`` in the sampling branch implies a batch size of 1.
        assistant_model: Smaller model used to draft candidate tokens.
        do_sample: Sample from this model's (processed, warped) distribution instead of taking the argmax.
        logits_processor, logits_warper: Applied to this model's logits at every verified position.
            Only the processor is applied to the assistant's logits.
        stopping_criteria: Checked after every iteration. Its ``max_length`` (if any) also caps how many
            draft tokens may be accepted.
        pad_token_id, eos_token_id: Default to ``self.generation_config``. A ``pad_token_id`` is required
            when an ``eos_token_id`` is set.
        output_attentions, output_hidden_states, output_scores, return_dict_in_generate:
            Default to the matching ``self.generation_config`` values.
        synced_gpus: Under ``torch.distributed``, keep running until every rank has finished.
        streamer: Receives each batch of accepted tokens via ``put`` and is closed with ``end``.
        **model_kwargs: Forwarded to the model. Also used to hold ``assistant_past_key_values``.

    Returns:
        ``GreedySearchEncoderDecoderOutput`` / ``GreedySearchDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set, otherwise ``input_ids`` extended with the generated tokens.

    Raises:
        ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    # Assistant: initialize assistant-related variables
    if not hasattr(assistant_model, "max_assistant_tokens"):
        assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls

    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    # other auxiliary variables
    # `StoppingCriteriaList.max_length` finds the length criterion wherever it sits in the list and returns
    # None if there is none. The former `stopping_criteria[0].max_length` raised IndexError on an empty
    # list and AttributeError when the first criterion had no `max_length`.
    max_len = stopping_criteria.max_length

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # Assistant: main logic start
        cur_len = input_ids.shape[-1]
        # Bloom stores keys with a different layout, so read the sequence length from the value tensor.
        assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1

        # 1. Forecast next N tokens using the assistant model.
        candidate_input_ids = input_ids
        for _ in range(int(assistant_model.max_assistant_tokens)):
            # 1.1. use the assistant model to obtain the next candidate logits
            if "assistant_past_key_values" in model_kwargs:
                prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                assist_inputs = candidate_input_ids[:, -new_token_len:]
                assist_attn = torch.ones_like(candidate_input_ids)
                # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
                if assistant_model.config.is_encoder_decoder:
                    assistant_model_outputs = assistant_model(
                        decoder_input_ids=assist_inputs,
                        decoder_attention_mask=assist_attn,
                        past_key_values=model_kwargs["assistant_past_key_values"],
                        encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                    )
                else:
                    assistant_model_outputs = assistant_model(
                        assist_inputs,
                        attention_mask=assist_attn,
                        past_key_values=model_kwargs["assistant_past_key_values"],
                    )
            else:
                if assistant_model.config.is_encoder_decoder:
                    assistant_model_outputs = assistant_model(
                        decoder_input_ids=candidate_input_ids,
                        encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                    )
                else:
                    assistant_model_outputs = assistant_model(candidate_input_ids)

            # 1.2. greedily select the next candidate token
            model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
            if len(logits_processor) > 0:
                assistant_model_outputs.logits[:, -1, :] = logits_processor(
                    candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
                )
            new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
            candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)

            # 1.3. stop assistant generation on EOS
            if eos_token_id_tensor is not None:
                last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
                last_assistant_token_is_eos = (
                    ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
                )
                if last_assistant_token_is_eos:
                    break
            else:
                last_assistant_token_is_eos = False

        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]

        # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
        # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
        # we use this forward pass to also pick the subsequent logits in the original model.

        # 2.1. Run a forward pass on the candidate sequence.
        # A `past_key_values` entry whose value is None counts as "no cache". Feeding only the trailing tokens
        # without a cache would silently drop the prompt context.
        # NOTE(review): the cached branch assumes the cache covers exactly `cur_len - 1` tokens.
        has_model_cache = model_kwargs.get("past_key_values") is not None
        if has_model_cache:
            model_attn = torch.ones_like(candidate_input_ids)
            model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
            if self.config.is_encoder_decoder:
                outputs = self(
                    decoder_input_ids=model_input_ids,
                    decoder_attention_mask=model_attn,
                    past_key_values=model_kwargs["past_key_values"],
                    encoder_outputs=model_kwargs["encoder_outputs"],
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )
            else:
                outputs = self(
                    model_input_ids,
                    attention_mask=model_attn,
                    past_key_values=model_kwargs["past_key_values"],
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )
        else:
            if self.config.is_encoder_decoder:
                outputs = self(
                    decoder_input_ids=candidate_input_ids,
                    encoder_outputs=model_kwargs["encoder_outputs"],
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )
            else:
                outputs = self(
                    candidate_input_ids,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )

        # 2.2. Process the new logits. All `candidate_length + 1` rows are processed: the last row selects the
        # token appended after a full match, so it must see the same processors/warpers as the others.
        new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
        if len(logits_processor) > 0:
            for i in range(candidate_length + 1):
                new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
        if len(logits_warper) > 0:
            for i in range(candidate_length + 1):
                new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

        # 3. Obtain the next tokens from the original model logits.
        if do_sample:
            probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
            # Only row 0 is sampled: this path assumes a batch size of 1.
            selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
        else:
            selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)

        # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
        # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
        candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
        n_matches = int(((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum().item())

        # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
        # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
        # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
        # is no match.

        # 5.1. Ensure we don't generate beyond max_len or an EOS token
        if last_assistant_token_is_eos and n_matches == candidate_length:
            n_matches -= 1
        if max_len is not None:
            n_matches = min(n_matches, max_len - cur_len - 1)

        # 5.2. Get the valid continuation, after the matching tokens
        valid_tokens = selected_tokens[:, : n_matches + 1]
        input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
        if streamer is not None:
            streamer.put(valid_tokens.cpu())
        new_cur_len = input_ids.shape[-1]

        # 5.3. Discard past key values relative to unused assistant tokens
        new_cache_size = new_cur_len - 1
        outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
        model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
            assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
        )  # the assistant does not have the token after the last match, hence the -1

        # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if n_matches == int(assistant_model.max_assistant_tokens):
            assistant_model.max_assistant_tokens += 2.0
        else:
            assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)

        # Assistant: main logic end

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # Store scores, attentions and hidden_states when required
        # Assistant: modified to append one tuple element per token, as in the other generation methods.
        if return_dict_in_generate:
            if output_scores:
                scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))

            added_len = n_matches + 1 if has_model_cache else new_cur_len

            if output_attentions:
                if self.config.is_encoder_decoder:
                    cross_attentions = _split_model_outputs(
                        cross_attentions, outputs.cross_attentions, cur_len, added_len
                    )
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.decoder_attentions,
                        cur_len,
                        added_len,
                        is_decoder_attention=True,
                    )
                else:
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.attentions,
                        cur_len,
                        added_len,
                        is_decoder_attention=True,
                    )
            if output_hidden_states:
                if self.config.is_encoder_decoder:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
                    )
                else:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.hidden_states, cur_len, added_len
                    )

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                input_ids[:, -1]
                .tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    # `.get` so that calling this method directly without `use_cache` in model_kwargs does not raise KeyError.
    returned_cache = model_kwargs.get("past_key_values", None) if model_kwargs.get("use_cache") else None
    if self.config.is_encoder_decoder:
        return GreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=returned_cache,
        )
    return GreedySearchDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=returned_cache,
    )
class LlamaForCausalLM(GenerationMixin_Cache, transformers.LlamaForCausalLM):
    """Llama causal LM whose generation inputs accept a cache covering any prefix of ``input_ids``."""

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the forward-pass kwargs for one generation step.

        When a cache is given, only the tokens after the cached prefix are fed. The prefix length is read
        from ``past_key_values[0][0].shape[2]``, so more than one uncached token is allowed. Position ids
        derived from ``attention_mask`` are cut the same way. ``inputs_embeds`` is used only when there is
        no cache at all.
        """
        cache_present = bool(past_key_values)
        if cache_present:
            cached_len = past_key_values[0][0].shape[2]
            input_ids = input_ids[:, cached_len:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # Positions count only attended tokens, so left padding does not shift them. Padded slots get a
            # dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if cache_present:
                position_ids = position_ids[:, cached_len:]

        feed_embeds = inputs_embeds is not None and past_key_values is None
        model_inputs = {"inputs_embeds": inputs_embeds} if feed_embeds else {"input_ids": input_ids}
        model_inputs.update(
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=kwargs.get("use_cache"),
            attention_mask=attention_mask,
        )
        return model_inputs
class MistralForCausalLM(GenerationMixin_Cache, transformers.MistralForCausalLM):
    """Mistral causal LM whose generation inputs accept a cache covering any prefix of ``input_ids``."""

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the forward-pass kwargs for one generation step.

        When a cache is given, only the tokens after the cached prefix (``past_key_values[0][0].shape[2]``
        tokens) are fed, and the position ids derived from ``attention_mask`` are sliced identically.
        Previously only the last position id was kept. That disagreed with ``input_ids`` whenever more than
        one token was uncached, which is exactly the cache-reuse case this class exists for. With a single
        uncached token the result is unchanged.

        NOTE(review): assumes the cache's sequence dimension equals the number of already-processed tokens.
        A sliding-window-truncated cache would violate this. Confirm for the transformers version in use.
        """
        if past_key_values:
            past_seq_length = past_key_values[0][0].shape[2]
            input_ids = input_ids[:, past_seq_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # Keep position_ids aligned with the (possibly multi-token) slice of input_ids above.
                position_ids = position_ids[:, past_seq_length:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
class PreTrainedTokenizer(transformers.PreTrainedTokenizer):
    """Tokenizer base that declares the ``get_continued_input_ids`` hook.

    Concrete tokenizers in this module override the hook. The base implementation only raises.
    """

    def get_continued_input_ids(self, text, right_padding=False, return_tensors=False):
        """Encode ``text`` as a continuation of earlier text. Subclasses must override this."""
        raise NotImplementedError()
class GemmaTokenizer(PreTrainedTokenizer, transformers.GemmaTokenizer):
    """Gemma tokenizer that can encode text as the continuation of some preceding text."""

    def add_placeholder(self, text: str) -> str:
        """Prepend a throw-away marker: ``'####'``, or ``'$$'`` when ``text`` itself starts with ``'#'``.

        ``get_continued_input_ids`` drops the marker's token afterwards. Switching to ``'$$'`` before a
        leading ``'#'`` presumably stops the marker from merging with the text's own ``'#'`` characters.
        """
        if len(text) == 0 or text[0] != '#':
            return '####' + text
        return '$$' + text

    def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False):
        """Encode ``text`` as if it followed earlier text, rather than as the start of a sequence.

        Args:
            text: A single string or a list of strings.
            right_padding: For list input, right-pad every row with ``self.pad_token_id`` to the longest row.
            return_tensors: Convert the result with ``torch.tensor``. Ragged rows need ``right_padding``.

        Returns:
            A list of token ids for a string, or a list of such lists for list input (or a tensor).
        """
        if isinstance(text, str):
            text = self.add_placeholder(text)
        else:
            text = [self.add_placeholder(x) for x in text]
        input_ids = self(text, add_special_tokens=False).input_ids
        # Drop the leading placeholder token.
        # NOTE(review): this relies on the marker encoding to exactly one token in this vocabulary.
        if isinstance(text, str):
            input_ids = input_ids[1:]
        else:
            input_ids = [ids[1:] for ids in input_ids]
        if right_padding and isinstance(text, list):
            # `default=0` keeps an empty batch from raising ValueError on max() of an empty sequence.
            max_length = max((len(ids) for ids in input_ids), default=0)
            # NOTE(review): rows are padded with None if `self.pad_token_id` is unset.
            input_ids = [ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids]
        if return_tensors:
            input_ids = torch.tensor(input_ids)
        return input_ids
# `PreTrainedTokenizer` is defined once, above GemmaTokenizer, and is deliberately not redefined here.
# A second identical definition would rebind the module-level name. Gemma and Llama/CodeLlama tokenizers
# would then inherit from two different base classes, and `isinstance(tok, PreTrainedTokenizer)` would
# disagree between them.
class LlamaTokenizer(PreTrainedTokenizer, transformers.LlamaTokenizer):
    """Llama tokenizer that can encode text as the continuation of some preceding text."""

    def add_placeholder(self, text: str) -> str:
        """Prepend a throw-away marker: ``'####'``, or ``'$$'`` when ``text`` itself starts with ``'#'``.

        ``get_continued_input_ids`` drops the marker's token afterwards. Switching to ``'$$'`` before a
        leading ``'#'`` presumably stops the marker from merging with the text's own ``'#'`` characters.
        """
        if len(text) == 0 or text[0] != '#':
            return '####' + text
        return '$$' + text

    def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False):
        """Encode ``text`` as if it followed earlier text, rather than as the start of a sequence.

        Args:
            text: A single string or a list of strings.
            right_padding: For list input, right-pad every row with ``self.pad_token_id`` to the longest row.
            return_tensors: Convert the result with ``torch.tensor``. Ragged rows need ``right_padding``.

        Returns:
            A list of token ids for a string, or a list of such lists for list input (or a tensor).
        """
        if isinstance(text, str):
            text = self.add_placeholder(text)
        else:
            text = [self.add_placeholder(x) for x in text]
        input_ids = self(text, add_special_tokens=False).input_ids
        # Drop the leading placeholder token.
        # NOTE(review): this relies on the marker encoding to exactly one token in this vocabulary.
        if isinstance(text, str):
            input_ids = input_ids[1:]
        else:
            input_ids = [ids[1:] for ids in input_ids]
        if right_padding and isinstance(text, list):
            # `default=0` keeps an empty batch from raising ValueError on max() of an empty sequence.
            max_length = max((len(ids) for ids in input_ids), default=0)
            # NOTE(review): rows are padded with None if `self.pad_token_id` is unset.
            input_ids = [ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids]
        if return_tensors:
            input_ids = torch.tensor(input_ids)
        return input_ids
class CodeLlamaTokenizer(PreTrainedTokenizer, transformers.CodeLlamaTokenizer):
    """CodeLlama tokenizer that can encode text as the continuation of some preceding text."""

    def add_placeholder(self, text: str) -> str:
        """Prepend a throw-away marker: ``'####'``, or ``'$$'`` when ``text`` itself starts with ``'#'``.

        ``get_continued_input_ids`` drops the marker's token afterwards. Switching to ``'$$'`` before a
        leading ``'#'`` presumably stops the marker from merging with the text's own ``'#'`` characters.
        """
        if len(text) == 0 or text[0] != '#':
            return '####' + text
        return '$$' + text

    def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False):
        """Encode ``text`` as if it followed earlier text, rather than as the start of a sequence.

        Args:
            text: A single string or a list of strings.
            right_padding: For list input, right-pad every row with ``self.pad_token_id`` to the longest row.
            return_tensors: Convert the result with ``torch.tensor``. Ragged rows need ``right_padding``.

        Returns:
            A list of token ids for a string, or a list of such lists for list input (or a tensor).
        """
        if isinstance(text, str):
            text = self.add_placeholder(text)
        else:
            text = [self.add_placeholder(x) for x in text]
        input_ids = self(text, add_special_tokens=False).input_ids
        # Drop the leading placeholder token.
        # NOTE(review): this relies on the marker encoding to exactly one token in this vocabulary.
        if isinstance(text, str):
            input_ids = input_ids[1:]
        else:
            input_ids = [ids[1:] for ids in input_ids]
        if right_padding and isinstance(text, list):
            # `default=0` keeps an empty batch from raising ValueError on max() of an empty sequence.
            max_length = max((len(ids) for ids in input_ids), default=0)
            # NOTE(review): rows are padded with None if `self.pad_token_id` is unset.
            input_ids = [ids + [self.pad_token_id] * (max_length - len(ids)) for ids in input_ids]
        if return_tensors:
            input_ids = torch.tensor(input_ids)
        return input_ids
def build_transformers_mapping_to_cached_models():
    """Monkey-patch transformers' auto-model factory so it resolves classes via this module's ``_get_model_class``."""
    auto_factory_module = transformers.models.auto.auto_factory
    auto_factory_module._get_model_class = _get_model_class
def build_transformers_mapping_to_custom_tokenizers():
    """Monkey-patch transformers' auto-tokenizer lookup to use this module's ``tokenizer_class_from_name``."""
    tokenization_auto_module = transformers.models.auto.tokenization_auto
    tokenization_auto_module.tokenizer_class_from_name = tokenizer_class_from_name