|
""" |
|
Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/generation/utils.py
|
|
|
General: |
|
1. Enable `model.generate(inputs)` to return "past_key_values" in its output (see the usage sketch at the end of this docstring)
|
|
|
Model-Specific: |
|
1. Enable Llama, CodeLlama, and Mistral to reuse the cache when there is more than one new token: for input_ids.shape == [bsz, n_seq], the cache may have shape [bsz, n_cache] with n_cache != n_seq - 1
|
2. Add a function for Llama and CodeLlama to encode text incrementally: given the text "a b c", it can encode the substring "b c" as a continuation of earlier text, rather than treating "b" as the first word
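
Usage sketch for General 1 and Model-Specific 1 (illustrative only: the checkpoint name is
a placeholder, and it assumes this module's patched Llama classes are already registered
with the auto classes):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    inputs = tokenizer("a b", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=4, use_cache=True,
                         return_dict_in_generate=True)
    past = out.past_key_values  # now populated (General 1)

    # Append several new tokens at once and reuse the shorter cache
    # (Model-Specific 1: n_cache != n_seq - 1).
    more = tokenizer(" c d", return_tensors="pt", add_special_tokens=False).input_ids
    next_ids = torch.cat([out.sequences, more], dim=-1)
    out2 = model.generate(next_ids, past_key_values=past, max_new_tokens=4,
                          use_cache=True, return_dict_in_generate=True)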
|
""" |
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload |
|
from dataclasses import dataclass |
|
import torch |
|
from torch import nn |
|
import torch.distributed as dist |
|
import transformers |
|
import copy |
|
import os |
|
import inspect |
|
import importlib |
|
import warnings |
|
from transformers.generation.utils import GenerationConfig, logging, is_deepspeed_zero3_enabled, _ranking_fast, _crop_past_key_values, _split_model_outputs |
|
from transformers.generation.utils import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer |
|
from transformers.generation.utils import DisjunctiveConstraint, PhrasalConstraint |
|
from transformers.generation.utils import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria |
|
from transformers.generation.utils import ( |
|
EncoderRepetitionPenaltyLogitsProcessor, |
|
EpsilonLogitsWarper, |
|
EtaLogitsWarper, |
|
ExponentialDecayLengthPenalty, |
|
ForcedBOSTokenLogitsProcessor, |
|
ForcedEOSTokenLogitsProcessor, |
|
ForceTokensLogitsProcessor, |
|
HammingDiversityLogitsProcessor, |
|
InfNanRemoveLogitsProcessor, |
|
LogitNormalization, |
|
LogitsProcessorList, |
|
MinLengthLogitsProcessor, |
|
MinNewTokensLengthLogitsProcessor, |
|
NoBadWordsLogitsProcessor, |
|
NoRepeatNGramLogitsProcessor, |
|
PrefixConstrainedLogitsProcessor, |
|
RepetitionPenaltyLogitsProcessor, |
|
SuppressTokensAtBeginLogitsProcessor, |
|
SuppressTokensLogitsProcessor, |
|
TemperatureLogitsWarper, |
|
TopKLogitsWarper, |
|
TopPLogitsWarper, |
|
TypicalLogitsWarper, |
|
) |
|
from transformers.generation.utils import ( |
|
ModelOutput, |
|
CausalLMOutputWithPast, |
|
Seq2SeqLMOutput, |
|
) |
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
import transformers.models.auto.auto_factory |
|
import transformers.models.auto.tokenization_auto |
|
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, TOKENIZER_MAPPING_NAMES, model_type_to_module_name |
|
|
|
|
|
def _get_model_class(config, model_mapping): |
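    """Resolve the concrete model class for `config`.

    Unlike the stock auto-factory helper, llama and mistral are routed to this project's
    patched classes (LlamaForCausalLM / MistralForCausalLM, assumed to be defined elsewhere
    in this module); every other model type falls back to the standard resolution below.
    """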
|
model_type = model_mapping._reverse_config_mapping[type(config).__name__] |
|
if model_type == 'llama': |
|
return LlamaForCausalLM |
|
if model_type == 'mistral': |
|
return MistralForCausalLM |
|
|
|
supported_models = model_mapping[type(config)] |
|
if not isinstance(supported_models, (list, tuple)): |
|
return supported_models |
|
|
|
name_to_model = {model.__name__: model for model in supported_models} |
|
architectures = getattr(config, "architectures", []) |
|
for arch in architectures: |
|
if arch in name_to_model: |
|
return name_to_model[arch] |
|
elif f"TF{arch}" in name_to_model: |
|
return name_to_model[f"TF{arch}"] |
|
elif f"Flax{arch}" in name_to_model: |
|
return name_to_model[f"Flax{arch}"] |
|
|
|
|
|
|
|
return supported_models[0] |
|
|
|
|
|
def tokenizer_class_from_name(class_name: str): |
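    """Resolve a tokenizer class from its name.

    Unlike the stock helper in transformers, the Llama, CodeLlama, and Gemma tokenizer names
    are routed to this project's patched tokenizer classes (assumed to be defined elsewhere
    in this module); all other names fall back to the standard lookup below.
    """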
|
if class_name == "LlamaTokenizer": |
|
return LlamaTokenizer |
|
if class_name == "CodeLlamaTokenizer": |
|
return CodeLlamaTokenizer |
|
if class_name == "GemmaTokenizer": |
|
return GemmaTokenizer |
|
|
|
if class_name == "PreTrainedTokenizerFast": |
|
return transformers.PreTrainedTokenizerFast |
|
|
|
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): |
|
if class_name in tokenizers: |
|
module_name = model_type_to_module_name(module_name) |
|
|
|
module = importlib.import_module(f".{module_name}", "transformers.models") |
|
try: |
|
return getattr(module, class_name) |
|
except AttributeError: |
|
continue |
|
|
|
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items(): |
|
for tokenizer in tokenizers: |
|
if getattr(tokenizer, "__name__", None) == class_name: |
|
return tokenizer |
|
|
|
|
|
|
|
main_module = importlib.import_module("transformers") |
|
if hasattr(main_module, class_name): |
|
return getattr(main_module, class_name) |
|
|
|
return None |
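

# The two helpers above mirror their counterparts in transformers so that the auto classes
# resolve to this project's patched model and tokenizer classes. A minimal wiring sketch
# (an assumption about how they are meant to be installed, not something this block does
# by itself):
#
#     transformers.models.auto.auto_factory._get_model_class = _get_model_class
#     transformers.models.auto.tokenization_auto.tokenizer_class_from_name = tokenizer_class_from_name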
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class GreedySearchDecoderOnlyOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class ContrastiveSearchEncoderDecoderOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class ContrastiveSearchDecoderOnlyOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class GreedySearchEncoderDecoderOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class SampleDecoderOnlyOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class SampleEncoderDecoderOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class BeamSearchDecoderOnlyOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
sequences_scores: Optional[torch.FloatTensor] = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
beam_indices: Optional[torch.LongTensor] = None |
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class BeamSearchEncoderDecoderOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
sequences_scores: Optional[torch.FloatTensor] = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
beam_indices: Optional[torch.LongTensor] = None |
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class BeamSampleDecoderOnlyOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
sequences_scores: Optional[torch.FloatTensor] = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
beam_indices: Optional[torch.LongTensor] = None |
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class BeamSampleEncoderDecoderOutput(ModelOutput): |
|
sequences: torch.LongTensor = None |
|
sequences_scores: Optional[torch.FloatTensor] = None |
|
scores: Optional[Tuple[torch.FloatTensor]] = None |
|
beam_indices: Optional[torch.LongTensor] = None |
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] |
|
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] |
|
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] |
|
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] |
|
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] |
|
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] |
|
|
|
|
|
class GenerationMixin_Cache(transformers.generation.GenerationMixin): |
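    """GenerationMixin variant whose decoding loops also return "past_key_values".

    The search/sample methods below are copied from transformers v4.30.0; the intended
    difference from upstream is that, when caching is enabled and
    `return_dict_in_generate=True`, the returned output object also carries the key/value
    cache under `past_key_values`.
    """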
|
|
|
@torch.no_grad() |
|
def contrastive_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
top_k: Optional[int] = 1, |
|
penalty_alpha: Optional[float] = 0, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**model_kwargs, |
|
) -> Union[ContrastiveSearchOutput, torch.LongTensor]: |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |
|
|
|
this_peer_finished = False |
|
batch_size = input_ids.shape[0] |
|
|
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
|
|
if model_kwargs.get("past_key_values") is None: |
|
|
|
model_kwargs["use_cache"] = True |
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
|
|
|
|
outputs = self( |
|
**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions |
|
) |
|
|
|
|
|
|
|
if self.config.is_encoder_decoder: |
|
last_hidden_states = outputs.decoder_hidden_states[-1] |
|
else: |
|
last_hidden_states = outputs.hidden_states[-1] |
|
|
|
logit_for_next_step = outputs.logits[:, -1, :] |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, |
|
model_kwargs, |
|
is_encoder_decoder=self.config.is_encoder_decoder, |
|
standardize_cache_format=True, |
|
) |
|
|
|
|
|
_, model_kwargs = self._expand_inputs_for_generation( |
|
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
|
) |
|
|
|
past_key_values = model_kwargs.get("past_key_values") |
|
if past_key_values is None: |
|
raise ValueError( |
|
f"{self.__class__.__name__} does not support caching and therefore **can't** be used " |
|
"for contrastive search." |
|
) |
|
elif ( |
|
not isinstance(past_key_values[0], (tuple, torch.Tensor)) |
|
or past_key_values[0][0].shape[0] != batch_size |
|
): |
|
raise ValueError( |
|
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " |
|
"used for contrastive search without further modifications." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
logit_for_next_step = logits_processor(input_ids, logit_for_next_step) |
|
logit_for_next_step = logits_warper(input_ids, logit_for_next_step) |
|
next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) |
|
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (logit_for_next_step,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
new_key_values = [] |
|
for layer in model_kwargs["past_key_values"]: |
|
items = [] |
|
|
|
for item in layer: |
|
items.append(item.repeat_interleave(top_k, dim=0)) |
|
new_key_values.append(items) |
|
model_kwargs["past_key_values"] = new_key_values |
|
|
|
|
|
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) |
|
outputs = self( |
|
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions |
|
) |
|
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) |
|
|
|
logits = outputs.logits[:, -1, :] |
|
|
|
if self.config.is_encoder_decoder: |
|
next_hidden = outputs.decoder_hidden_states[-1] |
|
full_hidden_states = outputs.decoder_hidden_states |
|
else: |
|
next_hidden = outputs.hidden_states[-1] |
|
full_hidden_states = outputs.hidden_states |
|
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) |
|
|
|
|
|
|
|
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) |
|
|
|
|
|
|
|
|
|
next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] |
|
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) |
|
next_hidden = next_hidden[range(batch_size), selected_idx, :] |
|
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) |
|
|
|
next_decoder_hidden_states = () |
|
for layer in full_hidden_states: |
|
layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] |
|
next_decoder_hidden_states += (layer,) |
|
|
|
|
|
new_key_values = () |
|
for layer in next_past_key_values: |
|
items = () |
|
|
|
for item in layer: |
|
item = torch.stack(torch.split(item, top_k, dim=0)) |
|
item = item[range(batch_size), selected_idx, ...] |
|
items += (item,) |
|
new_key_values += (items,) |
|
next_past_key_values = new_key_values |
|
|
|
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] |
|
|
|
|
|
if self.config.is_encoder_decoder: |
|
next_step_cross_attentions = () |
|
next_step_decoder_attentions = () |
|
if output_attentions: |
|
for layer in outputs.cross_attentions: |
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] |
|
next_step_cross_attentions += (layer,) |
|
for layer in outputs.decoder_attentions: |
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] |
|
next_step_decoder_attentions += (layer,) |
|
outputs = Seq2SeqLMOutput( |
|
past_key_values=next_past_key_values, |
|
decoder_hidden_states=next_decoder_hidden_states, |
|
decoder_attentions=next_step_decoder_attentions or None, |
|
cross_attentions=next_step_cross_attentions or None, |
|
) |
|
else: |
|
next_step_attentions = () |
|
if output_attentions: |
|
for layer in outputs.attentions: |
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] |
|
next_step_attentions += (layer,) |
|
outputs = CausalLMOutputWithPast( |
|
past_key_values=next_past_key_values, |
|
hidden_states=next_decoder_hidden_states, |
|
attentions=next_step_attentions or None, |
|
) |
|
|
|
|
|
if synced_gpus and this_peer_finished: |
|
continue |
|
|
|
|
|
if eos_token_id is not None: |
|
if pad_token_id is None: |
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
|
if streamer is not None: |
|
streamer.put(next_tokens.cpu()) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
|
|
|
|
if eos_token_id_tensor is not None: |
|
unfinished_sequences = unfinished_sequences.mul( |
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) |
|
) |
|
|
|
|
|
if unfinished_sequences.max() == 0: |
|
this_peer_finished = True |
|
|
|
|
|
if stopping_criteria(input_ids, scores): |
|
this_peer_finished = True |
|
|
|
if this_peer_finished and not synced_gpus: |
|
break |
|
|
|
if streamer is not None: |
|
streamer.end() |
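
        # Unlike upstream, the output objects below also carry the cache under
        # `past_key_values` when `use_cache` is enabled.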
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return ContrastiveSearchEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return ContrastiveSearchDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return input_ids |
|
|
|
def greedy_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**model_kwargs, |
|
) -> Union[GreedySearchOutput, torch.LongTensor]: |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use" |
|
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_tokens_scores = logits_processor(input_ids, next_token_logits) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_tokens_scores,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
next_tokens = torch.argmax(next_tokens_scores, dim=-1) |
|
|
|
|
|
if eos_token_id is not None: |
|
if pad_token_id is None: |
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
|
if streamer is not None: |
|
streamer.put(next_tokens.cpu()) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
|
|
|
|
if eos_token_id_tensor is not None: |
|
unfinished_sequences = unfinished_sequences.mul( |
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) |
|
) |
|
|
|
|
|
if unfinished_sequences.max() == 0: |
|
this_peer_finished = True |
|
|
|
|
|
if stopping_criteria(input_ids, scores): |
|
this_peer_finished = True |
|
|
|
if this_peer_finished and not synced_gpus: |
|
break |
|
|
|
if streamer is not None: |
|
streamer.end() |
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return GreedySearchEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return GreedySearchDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return input_ids |
|
|
|
def sample( |
|
self, |
|
input_ids: torch.LongTensor, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**model_kwargs, |
|
) -> Union[SampleOutput, torch.LongTensor]: |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
            warnings.warn(

                "`max_length` is deprecated in this function, use"

                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |
|
|
|
this_peer_finished = False |
|
|
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_token_scores = logits_processor(input_ids, next_token_logits) |
|
next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1) |
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
|
if streamer is not None: |
|
streamer.put(next_tokens.cpu()) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
|
|
|
|
if eos_token_id_tensor is not None: |
|
unfinished_sequences = unfinished_sequences.mul( |
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) |
|
) |
|
|
|
|
|
if unfinished_sequences.max() == 0: |
|
this_peer_finished = True |
|
|
|
|
|
if stopping_criteria(input_ids, scores): |
|
this_peer_finished = True |
|
|
|
if this_peer_finished and not synced_gpus: |
|
break |
|
|
|
if streamer is not None: |
|
streamer.end() |
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return SampleEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return SampleDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return input_ids |
|
|
|
def beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
**model_kwargs, |
|
) -> Union[BeamSearchOutput, torch.LongTensor]: |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
            warnings.warn(

                "`max_length` is deprecated in this function, use"

                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
if len(stopping_criteria) == 0: |
|
            warnings.warn("You have not defined any stopping_criteria; this will likely loop forever", UserWarning)
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
beam_indices = ( |
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None |
|
) |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
|
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores[:, 1:] = -1e9 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores_processed,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
|
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True |
|
) |
|
|
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past_key_values"] is not None: |
|
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def beam_sample( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
**model_kwargs, |
|
) -> Union[BeamSampleOutput, torch.LongTensor]: |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
            warnings.warn(

                "`max_length` is deprecated in this function, use"

                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
beam_indices = ( |
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None |
|
) |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
|
|
|
|
|
next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (logits_warper(input_ids, next_token_scores_processed),) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1) |
|
|
|
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) |
|
next_token_scores = torch.gather(next_token_scores, -1, next_tokens) |
|
|
|
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) |
|
next_tokens = torch.gather(next_tokens, -1, _indices) |
|
|
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
beam_indices=beam_indices, |
|
) |
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past_key_values"] is not None: |
|
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
beam_indices=beam_indices, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSampleEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return BeamSampleDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def group_beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
**model_kwargs, |
|
): |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
            warnings.warn(

                "`max_length` is deprecated in this function, use"

                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
num_beam_groups = beam_scorer.num_beam_groups |
|
num_sub_beams = num_beams // num_beam_groups |
|
device = input_ids.device |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] |
|
else: |
|
beam_indices = None |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
|
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) |
|
beam_scores[:, ::num_sub_beams] = 0 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) |
|
|
|
|
|
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
if output_scores: |
|
processed_score = torch.zeros_like(outputs.logits[:, -1, :]) |
|
|
|
for beam_group_idx in range(num_beam_groups): |
|
group_start_idx = beam_group_idx * num_sub_beams |
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams) |
|
group_size = group_end_idx - group_start_idx |
|
|
|
|
|
batch_group_indices = [] |
|
|
|
for batch_idx in range(batch_size): |
|
batch_group_indices.extend( |
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] |
|
) |
|
group_input_ids = input_ids[batch_group_indices] |
|
|
|
|
|
next_token_logits = outputs.logits[batch_group_indices, -1, :] |
|
|
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
vocab_size = next_token_scores.shape[-1] |
|
|
|
next_token_scores_processed = logits_processor( |
|
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx |
|
) |
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) |
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed) |
|
|
|
if output_scores: |
|
processed_score[batch_group_indices] = next_token_scores_processed |
|
|
|
|
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) |
|
|
|
|
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True |
|
) |
|
|
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None |
|
beam_outputs = beam_scorer.process( |
|
group_input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
beam_indices=process_beam_indices, |
|
) |
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices[beam_group_idx] = tuple( |
|
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) |
|
) |
|
|
|
input_ids[batch_group_indices] = group_input_ids[beam_idx] |
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
current_tokens[batch_group_indices] = group_input_ids[:, -1] |
|
|
|
|
|
|
|
reordering_indices[batch_group_indices] = ( |
|
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") |
|
+ group_start_idx |
|
+ (beam_idx % group_size) |
|
) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (processed_score,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past_key_values"] is not None: |
|
model_kwargs["past_key_values"] = self._reorder_cache( |
|
model_kwargs["past_key_values"], reordering_indices |
|
) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None |
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
beam_indices=final_beam_indices, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=sequence_outputs["beam_indices"], |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def constrained_beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
constrained_beam_scorer: ConstrainedBeamSearchScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = None, |
|
**model_kwargs, |
|
) -> Union[BeamSearchOutput, torch.LongTensor]: |
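        # Beam search with lexical constraints, driven by `constrained_beam_scorer`;
        # `past_key_values` are additionally surfaced in the returned output object when
        # `return_dict_in_generate=True` and caching is enabled.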
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use" |
|
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
if len(stopping_criteria) == 0: |
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
batch_size = len(constrained_beam_scorer._beam_hyps) |
|
num_beams = constrained_beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
|
|
|
|
|
|
|
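        # Only the first beam of each batch item starts at score 0; the others start at -1e9
        # so the first top-k selection does not pick `num_beams` identical continuations.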
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores[:, 1:] = -1e9 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
|
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
|
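            # Keep a copy of the full-vocabulary scores: the constrained scorer also considers
            # tokens that advance unmet constraints, which may fall outside the top-k below.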
scores_for_all_vocab = next_token_scores.clone() |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
|
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True |
|
) |
|
|
|
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = constrained_beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
scores_for_all_vocab, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
) |
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past_key_values"] is not None: |
|
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = constrained_beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def assisted_decoding( |
|
self, |
|
input_ids: torch.LongTensor, |
|
assistant_model: "PreTrainedModel", |
|
do_sample: bool = False, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[Union[int, List[int]]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: bool = False, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**model_kwargs, |
|
): |
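        # Assisted ("speculative") decoding: the assistant model greedily proposes up to
        # `max_assistant_tokens` candidate tokens, the main model scores them in a single
        # forward pass, and the longest prefix on which both models agree is accepted.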
|
|
|
if not hasattr(assistant_model, "max_assistant_tokens"): |
|
assistant_model.max_assistant_tokens = 5 |
|
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id |
|
if eos_token_id is not None and pad_token_id is None: |
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
|
if isinstance(eos_token_id, int): |
|
eos_token_id = [eos_token_id] |
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None |
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate |
|
if return_dict_in_generate is not None |
|
else self.generation_config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
|
|
|
|
|
max_len = stopping_criteria[0].max_length |
|
|
|
this_peer_finished = False |
|
while True: |
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
cur_len = input_ids.shape[-1] |
|
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1 |
|
|
|
|
|
|
|
|
|
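            # Candidate generation: let the assistant model extend the current sequence
            # greedily, reusing its own cache (`assistant_past_key_values`) when available.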
candidate_input_ids = input_ids |
|
for _ in range(int(assistant_model.max_assistant_tokens)): |
|
|
|
if "assistant_past_key_values" in model_kwargs: |
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] |
|
|
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len |
|
assist_inputs = candidate_input_ids[:, -new_token_len:] |
|
assist_attn = torch.ones_like(candidate_input_ids) |
|
|
|
if assistant_model.config.is_encoder_decoder: |
|
assistant_model_outputs = assistant_model( |
|
decoder_input_ids=assist_inputs, |
|
decoder_attention_mask=assist_attn, |
|
past_key_values=model_kwargs["assistant_past_key_values"], |
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"], |
|
) |
|
else: |
|
assistant_model_outputs = assistant_model( |
|
assist_inputs, |
|
attention_mask=assist_attn, |
|
past_key_values=model_kwargs["assistant_past_key_values"], |
|
) |
|
else: |
|
if assistant_model.config.is_encoder_decoder: |
|
assistant_model_outputs = assistant_model( |
|
decoder_input_ids=candidate_input_ids, |
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"], |
|
) |
|
else: |
|
assistant_model_outputs = assistant_model(candidate_input_ids) |
|
|
|
|
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values |
|
if len(logits_processor) > 0: |
|
assistant_model_outputs.logits[:, -1, :] = logits_processor( |
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :] |
|
) |
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) |
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) |
|
|
|
|
|
if eos_token_id_tensor is not None: |
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) |
|
last_assistant_token_is_eos = ( |
|
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() |
|
) |
|
if last_assistant_token_is_eos: |
|
break |
|
else: |
|
last_assistant_token_is_eos = False |
|
|
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
if "past_key_values" in model_kwargs: |
|
model_attn = torch.ones_like(candidate_input_ids) |
|
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] |
|
if self.config.is_encoder_decoder: |
|
outputs = self( |
|
decoder_input_ids=model_input_ids, |
|
decoder_attention_mask=model_attn, |
|
past_key_values=model_kwargs["past_key_values"], |
|
encoder_outputs=model_kwargs["encoder_outputs"], |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
else: |
|
outputs = self( |
|
model_input_ids, |
|
attention_mask=model_attn, |
|
past_key_values=model_kwargs["past_key_values"], |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
else: |
|
if self.config.is_encoder_decoder: |
|
outputs = self( |
|
decoder_input_ids=candidate_input_ids, |
|
encoder_outputs=model_kwargs["encoder_outputs"], |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
else: |
|
outputs = self( |
|
candidate_input_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
|
|
new_logits = outputs.logits[:, -candidate_length - 1 :] |
|
if len(logits_processor) > 0: |
|
for i in range(candidate_length): |
|
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) |
|
if len(logits_warper) > 0: |
|
for i in range(candidate_length): |
|
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) |
|
|
|
|
|
if do_sample: |
|
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) |
|
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] |
|
else: |
|
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1) |
|
|
|
|
|
|
|
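            # Compare the assistant's proposals with the main model's own selections; the
            # matching prefix, plus one extra token chosen by the main model, is accepted.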
candidate_new_tokens = candidate_input_ids[:, -candidate_length:] |
|
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if last_assistant_token_is_eos and n_matches == candidate_length: |
|
n_matches -= 1 |
|
n_matches = min(n_matches, max_len - cur_len - 1) |
|
|
|
|
|
valid_tokens = selected_tokens[:, : n_matches + 1] |
|
input_ids = torch.cat((input_ids, valid_tokens), dim=-1) |
|
if streamer is not None: |
|
streamer.put(valid_tokens.cpu()) |
|
new_cur_len = input_ids.shape[-1] |
|
|
|
|
|
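            # Crop both key/value caches back to the accepted length so the next iteration
            # resumes from a consistent state.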
new_cache_size = new_cur_len - 1 |
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) |
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values( |
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 |
|
) |
|
|
|
|
|
|
|
|
|
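            # Heuristic for the next round: propose more candidates if every proposal was
            # accepted, otherwise propose fewer (but always at least one).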
if n_matches == int(assistant_model.max_assistant_tokens): |
|
assistant_model.max_assistant_tokens += 2.0 |
|
else: |
|
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) |
|
|
|
|
|
|
|
if synced_gpus and this_peer_finished: |
|
continue |
|
|
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) |
|
|
|
if "past_key_values" not in model_kwargs: |
|
added_len = new_cur_len |
|
else: |
|
added_len = n_matches + 1 |
|
|
|
if output_attentions: |
|
if self.config.is_encoder_decoder: |
|
cross_attentions = _split_model_outputs( |
|
cross_attentions, outputs.cross_attentions, cur_len, added_len |
|
) |
|
decoder_attentions = _split_model_outputs( |
|
decoder_attentions, |
|
outputs.decoder_attentions, |
|
cur_len, |
|
added_len, |
|
is_decoder_attention=True, |
|
) |
|
else: |
|
decoder_attentions = _split_model_outputs( |
|
decoder_attentions, |
|
outputs.attentions, |
|
cur_len, |
|
added_len, |
|
is_decoder_attention=True, |
|
) |
|
if output_hidden_states: |
|
if self.config.is_encoder_decoder: |
|
decoder_hidden_states = _split_model_outputs( |
|
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len |
|
) |
|
else: |
|
decoder_hidden_states = _split_model_outputs( |
|
decoder_hidden_states, outputs.hidden_states, cur_len, added_len |
|
) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
|
|
|
|
if eos_token_id_tensor is not None: |
|
unfinished_sequences = unfinished_sequences.mul( |
|
input_ids[:, -1] |
|
.tile(eos_token_id_tensor.shape[0], 1) |
|
.ne(eos_token_id_tensor.unsqueeze(1)) |
|
.prod(dim=0) |
|
) |
|
|
|
|
|
if unfinished_sequences.max() == 0: |
|
this_peer_finished = True |
|
|
|
|
|
if stopping_criteria(input_ids, scores): |
|
this_peer_finished = True |
|
|
|
if this_peer_finished and not synced_gpus: |
|
break |
|
|
|
if streamer is not None: |
|
streamer.end() |
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return GreedySearchEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return GreedySearchDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
past_key_values=model_kwargs.get('past_key_values', None) if model_kwargs["use_cache"] else None, |
|
) |
|
else: |
|
return input_ids |
|
|
|
|
|
|
|
|
|
|
|
class LlamaForCausalLM(GenerationMixin_Cache, transformers.LlamaForCausalLM): |
|
def prepare_inputs_for_generation( |
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs |
|
): |
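        # Unlike the stock implementation, which keeps only the last token once a cache
        # exists, keep every token beyond the cached length so the cache can be reused even
        # when several new tokens have been appended since it was built.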
|
if past_key_values: |
|
past_seq_length = past_key_values[0][0].shape[2] |
|
input_ids = input_ids[:, past_seq_length:] |
|
|
|
position_ids = kwargs.get("position_ids", None) |
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
if past_key_values: |
|
position_ids = position_ids[:, past_seq_length:] |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"position_ids": position_ids, |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
return model_inputs |
|
|
|
|
|
class MistralForCausalLM(GenerationMixin_Cache, transformers.MistralForCausalLM): |
|
def prepare_inputs_for_generation( |
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs |
|
): |
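        # Same cache-reuse behaviour as `LlamaForCausalLM.prepare_inputs_for_generation` above.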
|
if past_key_values: |
|
past_seq_length = past_key_values[0][0].shape[2] |
|
input_ids = input_ids[:, past_seq_length:] |
|
|
|
position_ids = kwargs.get("position_ids", None) |
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
if past_key_values: |
|
                # Keep the positions of every uncached token, mirroring the input_ids slicing above.
                position_ids = position_ids[:, past_seq_length:]
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"position_ids": position_ids, |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
return model_inputs |
|
|
|
|
|
class PreTrainedTokenizer(transformers.PreTrainedTokenizer): |
|
|
|
def get_continued_input_ids(self, text, right_padding=False, return_tensors=False): |
|
raise NotImplementedError |
|
|
|
|
|
class GemmaTokenizer(PreTrainedTokenizer, transformers.GemmaTokenizer): |
|
|
|
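    # Continued encoding: prepend a short placeholder so the tokenizer treats `text` as a
    # continuation of earlier text rather than the start of a new word, then drop the leading
    # placeholder token below. This relies on the placeholder being encoded as a single
    # leading token.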
def add_placeholder(self, text: str): |
|
if len(text) == 0 or text[0] != '#': |
|
return '####' + text |
|
return '$$' + text |
|
|
|
def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): |
|
if isinstance(text, str): |
|
text = self.add_placeholder(text) |
|
else: |
|
text = [self.add_placeholder(x) for x in text] |
|
|
|
input_ids = self(text, add_special_tokens=False).input_ids |
|
if isinstance(text, str): |
|
input_ids = input_ids[1:] |
|
else: |
|
input_ids = [ids[1:] for ids in input_ids] |
|
|
|
if right_padding and isinstance(text, list): |
|
max_length = max([len(ids) for ids in input_ids]) |
|
input_ids = [ |
|
ids |
|
+ [self.pad_token_id] * (max_length - len(ids)) |
|
for ids in input_ids |
|
] |
|
if return_tensors: |
|
input_ids = torch.tensor(input_ids) |
|
return input_ids |
|
|
|
|
|
|
|
|
class LlamaTokenizer(PreTrainedTokenizer, transformers.LlamaTokenizer): |
|
|
|
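    # Same continued-encoding placeholder trick as `GemmaTokenizer` above.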
def add_placeholder(self, text: str): |
|
if len(text) == 0 or text[0] != '#': |
|
return '####' + text |
|
return '$$' + text |
|
|
|
def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): |
|
if isinstance(text, str): |
|
text = self.add_placeholder(text) |
|
else: |
|
text = [self.add_placeholder(x) for x in text] |
|
|
|
input_ids = self(text, add_special_tokens=False).input_ids |
|
if isinstance(text, str): |
|
input_ids = input_ids[1:] |
|
else: |
|
input_ids = [ids[1:] for ids in input_ids] |
|
|
|
if right_padding and isinstance(text, list): |
|
max_length = max([len(ids) for ids in input_ids]) |
|
input_ids = [ |
|
ids |
|
+ [self.pad_token_id] * (max_length - len(ids)) |
|
for ids in input_ids |
|
] |
|
if return_tensors: |
|
input_ids = torch.tensor(input_ids) |
|
return input_ids |
|
|
|
|
|
class CodeLlamaTokenizer(PreTrainedTokenizer, transformers.CodeLlamaTokenizer): |
|
|
|
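    # Same continued-encoding placeholder trick as `GemmaTokenizer` above.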
def add_placeholder(self, text: str): |
|
if len(text) == 0 or text[0] != '#': |
|
return '####' + text |
|
return '$$' + text |
|
|
|
def get_continued_input_ids(self, text: Union[str, List[str]], right_padding=False, return_tensors=False): |
|
if isinstance(text, str): |
|
text = self.add_placeholder(text) |
|
else: |
|
text = [self.add_placeholder(x) for x in text] |
|
|
|
input_ids = self(text, add_special_tokens=False).input_ids |
|
if isinstance(text, str): |
|
input_ids = input_ids[1:] |
|
else: |
|
input_ids = [ids[1:] for ids in input_ids] |
|
|
|
if right_padding and isinstance(text, list): |
|
max_length = max([len(ids) for ids in input_ids]) |
|
input_ids = [ |
|
ids |
|
+ [self.pad_token_id] * (max_length - len(ids)) |
|
for ids in input_ids |
|
] |
|
if return_tensors: |
|
input_ids = torch.tensor(input_ids) |
|
return input_ids |
|
|
|
|
|
|
|
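# Monkey-patch the transformers auto factories so that `AutoModelForCausalLM` and
# `AutoTokenizer` resolve to the patched model and tokenizer classes defined in this module.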
def build_transformers_mapping_to_cached_models(): |
|
transformers.models.auto.auto_factory._get_model_class = _get_model_class |
|
|
|
|
|
def build_transformers_mapping_to_custom_tokenizers(): |
|
transformers.models.auto.tokenization_auto.tokenizer_class_from_name = tokenizer_class_from_name |
|
|
|
|
|
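# Illustrative usage sketch (not executed here): the checkpoint name, prompt, and generation
# arguments below are placeholders, and `use_fast=False` is assumed so the custom slow
# tokenizer classes above are selected.
#
#   build_transformers_mapping_to_cached_models()
#   build_transformers_mapping_to_custom_tokenizers()
#   tok = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
#   model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#   out = model.generate(**tok("a b", return_tensors="pt"), max_new_tokens=8,
#                        return_dict_in_generate=True, use_cache=True)
#   cache = out.past_key_values                    # exposed by the patched search loops
#   next_ids = tok.get_continued_input_ids("c d")  # encode "c d" as a continuation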
|