# eeg2meg-small / hf_transformer.py

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from transformers import PreTrainedModel, PretrainedConfig

from model.architectures.transformer import EncoderDecoderTransformer
from model.architectures.crossformer import EncoderDecoderCrossFormer
from model.hf_configs import Seq2SeqConfig, Seq2SeqCrossConfig


class Seq2SeqTransformer(PreTrainedModel):
"""
Custom Transformer for Sequence to Sequence tasks.
"""
config_class = Seq2SeqConfig
base_model_prefix = "transformer"
def __init__(self, config: PretrainedConfig, device: Optional[str]=None):
super().__init__(config)
self.softmax = nn.Softmax(dim=-1)
self.transformer = EncoderDecoderTransformer(
src_vocab_size=config.vocab_size_src,
tgt_vocab_size=config.vocab_size_tgt,
embed_dim=config.d_model,
num_heads=config.n_heads,
ff_dim=config.d_ff,
num_encoder_layers=config.n_layers,
num_decoder_layers=config.n_layers,
max_seq_length=config.sequence_length
)
def _init_weights(self, module: nn.Module):
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
    def _create_padding_mask(self, ids: torch.LongTensor) -> torch.FloatTensor:
        """Creates a mask so that padded tokens do not interfere with attention."""
        # Boolean mask where True marks a padding token
        is_padding = ids.eq(self.config.pad_token_id)
        # Float mask: 1.0 at real tokens, -inf at padding positions
        mask = torch.ones_like(ids, dtype=torch.float)
        mask = mask.masked_fill(is_padding, float("-inf"))
        return mask
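    # Example of the mask layout (hypothetical values, assuming pad_token_id=0):
    #   ids  = [[5, 7, 0, 0]]
    #   mask = [[1.0, 1.0, -inf, -inf]]
    # The downstream EncoderDecoderTransformer is expected to consume this float
    # mask when attending over the padded sequence.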
def _shift_right(self, x: torch.LongTensor) -> torch.LongTensor:
"""Helper method to prepare decoder inputs (teacher forcing) by shifting right label tokens"""
shifted = torch.full(
(*x.shape[:-1], 1),
self.config.bos_token_id,
dtype=x.dtype,
device=x.device
)
shifted = torch.cat([shifted, x[:, :-1]], dim=-1)
return shifted
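    # Illustrative example (hypothetical ids, assuming bos_token_id=1):
    #   labels         = [[7, 8, 9]]
    #   decoder inputs = [[1, 7, 8]]
    # so the decoder predicts token t from tokens < t (teacher forcing).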
def _add_beginning_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
"""
Helper method to add BOS token to the beginning of input sequences
"""
bos = torch.full(
(*x.shape[:-1], 1),
self.config.bos_token_id,
dtype=x.dtype,
device=x.device
)
return torch.cat([bos, x], dim=-1)
def _add_end_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
"""Helper method to add EOS token to the end of label sequences"""
eos = torch.full(
(*x.shape[:-1], 1),
self.config.eos_token_id,
dtype=x.dtype,
device=x.device
)
return torch.cat([x, eos], dim=-1)
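    # Illustrative example (hypothetical ids, assuming bos_token_id=1 and eos_token_id=2):
    #   _add_beginning_of_stream([[7, 8, 9]]) -> [[1, 7, 8, 9]]
    #   _add_end_of_stream([[7, 8, 9]])       -> [[7, 8, 9, 2]]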
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
**kwargs
) -> Union[Tuple, dict]:
        # TODO: BOS/EOS insertion and the right shift should happen in the tokenizer, outside the model
        # Prepend the beginning-of-stream token to the encoder inputs
        input_ids = self._add_beginning_of_stream(input_ids)
        # Append the end-of-stream token to the labels (training only)
        if labels is not None:
            labels = self._add_end_of_stream(labels)
# Prepare input for the decoder
if decoder_input_ids is None and labels is not None:
decoder_input_ids = self._shift_right(labels)
src_key_padding_mask = self._create_padding_mask(input_ids)
tgt_key_padding_mask = self._create_padding_mask(decoder_input_ids)
        # Forward pass through the underlying encoder-decoder transformer
outputs = self.transformer(
src=input_ids,
tgt=decoder_input_ids,
src_mask=attention_mask,
tgt_mask=decoder_attention_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask
)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
loss = loss_fct(outputs.view(-1, self.config.vocab_size_tgt), labels.view(-1))
return dict(
loss=loss,
logits=outputs,
)
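    # Hypothetical training-style call (tensor names and shapes are assumptions):
    #   out = model(input_ids=src_ids, labels=tgt_ids)  # both (batch, seq_len) LongTensors
    #   out["loss"].backward()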
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
temperature: float = 1.0,
do_sample: bool = False,
**kwargs
) -> torch.LongTensor:
batch_size = input_ids.shape[0]
max_length = max_length or self.config.max_length or 128
decoder_input_ids = torch.full(
(batch_size, 1),
self.config.bos_token_id,
dtype=torch.long,
device=input_ids.device
)
for _ in range(max_length - 1):
outputs = self.forward(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
next_token_logits = outputs["logits"][:, -1, :]
if do_sample:
# Apply temperature scaling
scaled_logits = next_token_logits / temperature
# Convert to probabilities
next_token_probs = self.softmax(scaled_logits)
# Sample from the probability distribution
next_token = torch.multinomial(
next_token_probs, num_samples=1
).squeeze(-1)
else:
# Greedy decoding
next_token = next_token_logits.argmax(dim=-1)
decoder_input_ids = torch.cat(
[decoder_input_ids, next_token.unsqueeze(-1)],
dim=-1
)
# Stop if all sequences have generated EOS token
if (decoder_input_ids == self.config.eos_token_id).any(dim=-1).all():
break
return decoder_input_ids
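    # Hypothetical inference-style call (argument values are assumptions):
    #   generated = model.generate(input_ids=src_ids, max_length=64, do_sample=True, temperature=0.8)
    #   # generated: (batch, <= max_length) LongTensor whose first column is the BOS token

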
class Seq2SeqCrossFormer(Seq2SeqTransformer):
"""CrossFormer wrapper predicting over a discrete vocabulatory."""
config_class = Seq2SeqCrossConfig
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.softmax = nn.Softmax(dim=-1)
self.transformer = EncoderDecoderCrossFormer(
source_sequence_dimension=config.source_sequence_dimension,
target_sequence_dimension=config.target_sequence_dimension,
router_dim=config.router_dim,
src_vocab_size=config.vocab_size_src,
tgt_vocab_size=config.vocab_size_tgt,
embed_dim=config.d_model,
num_heads=config.n_heads,
ff_dim=config.d_ff,
num_encoder_layers=config.n_layers,
num_decoder_layers=config.n_layers,
max_seq_length=config.sequence_length
)
def _shift_right(self, x: torch.LongTensor) -> torch.LongTensor:
"""
Helper method to prepare decoder inputs (teacher forcing) by shifting right label tokens.
Handles 3D (B, S, C) tensors
"""
# Create shape that matches x's dimensions except for seq_len which will be 1
shape = list(x.shape)
shape[-2] = 1 # Set sequence dimension to 1
shifted = torch.full(
shape,
self.config.bos_token_id,
dtype=x.dtype,
device=x.device
)
shifted = torch.cat([shifted, x[..., :-1, :]], dim=-2)
return shifted
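    # Illustrative example for one channel c of a (B, S, C) tensor
    # (hypothetical ids, assuming bos_token_id=1):
    #   labels[b, :, c]         = [7, 8, 9]
    #   decoder_inputs[b, :, c] = [1, 7, 8]
    # i.e. the shift runs along the time axis, independently for every channel.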
def _add_beginning_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
"""
Helper method to add BOS token to the beginning of input sequences.
Handles 3D (B, S, C) tensors
"""
        shape = list(x.shape)
        shape[-2] = 1  # Set the sequence (time) dimension to 1
        bos = torch.full(
            shape,
            self.config.bos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        return torch.cat([bos, x], dim=-2)
def _add_end_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
"""
Helper method to add EOS token to the end of label sequences.
Handles 3D (B, S, C) tensors
"""
# Create shape that matches x's dimensions except for seq_len which will be 1
shape = list(x.shape)
shape[-2] = 1 # Set sequence dimension to 1
eos = torch.full(
shape,
self.config.eos_token_id,
dtype=x.dtype,
device=x.device
)
return torch.cat([x, eos], dim=-2)
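    # Shape effect of the two helpers above on a (B, S, C) tensor:
    #   _add_beginning_of_stream: (B, S, C) -> (B, S + 1, C)  (BOS row prepended along time)
    #   _add_end_of_stream:       (B, S, C) -> (B, S + 1, C)  (EOS row appended along time)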
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
**kwargs
):
        # FIXME: BOS/EOS insertion and the right shift should take place in the tokenizer, outside the model
        # (in tokenizer) prepend the beginning-of-stream token to the encoder inputs
        input_ids = self._add_beginning_of_stream(input_ids)
        # (in tokenizer) append the end-of-stream token to the labels
        if labels is not None:
            labels = self._add_end_of_stream(labels)
# Prepare input for the decoder
if decoder_input_ids is None and labels is not None:
decoder_input_ids = self._shift_right(labels)
src_src_key_padding_time_mask = rearrange(
self._create_padding_mask(
input_ids
),
'b s c -> (b c) s'
)
tgt_tgt_key_padding_time_mask = rearrange(
self._create_padding_mask(
decoder_input_ids
),
'b s c -> (b c) s'
)
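        # The per-channel padding masks above are flattened so that each (batch, channel)
        # pair becomes its own row of time positions, matching the '(b c) s' layout the
        # CrossFormer's time attention is expected to consume.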
        # Forward pass through the underlying CrossFormer
outputs = self.transformer(
src=input_ids,
tgt=decoder_input_ids,
src_src_time_mask=kwargs.get("src_src_time_mask"),
src_src_dimension_mask=kwargs.get("src_src_dimension_mask"),
src_src_key_padding_time_mask=src_src_key_padding_time_mask,
tgt_tgt_time_mask=kwargs.get("tgt_tgt_time_mask"),
tgt_tgt_dimension_mask=kwargs.get("tgt_tgt_dimension_mask"),
tgt_tgt_key_padding_time_mask=tgt_tgt_key_padding_time_mask,
tgt_src_dimension_mask=kwargs.get("tgt_src_dimension_mask")
)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(
ignore_index=self.config.pad_token_id
)
loss = loss_fct(
outputs.view(-1, self.config.vocab_size_tgt), labels.view(-1)
)
return dict(
loss=loss,
logits=outputs,
)
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor]=None,
max_length: Optional[int]=None,
temperature: float=1.0,
do_sample: bool=False,
**kwargs
) -> torch.LongTensor:
        batch_size, timesteps, channels = input_ids.shape
        max_length = max_length or self.config.max_length or 128
        decoder_input_ids = torch.full(
            (batch_size, timesteps + 1, 306),  # the decoder generates 306-channel MEG data
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device
        )
        # Set the BOS token at the first position
        decoder_input_ids[:, 0, :] = self.config.bos_token_id
        # Generate one MEG frame per input timestep, bounded by max_length and the buffer size
        for t in range(min(timesteps, max_length)):
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask
            )
            # Logits at decoder position t predict the tokens at position t + 1
            next_token_logits = outputs["logits"][:, t, :]
            if do_sample:
                scaled_logits = next_token_logits / temperature
                next_token_probs = self.softmax(scaled_logits)
                next_token = torch.multinomial(
                    next_token_probs, num_samples=1
                ).squeeze(-1)
            else:
                next_token = next_token_logits.argmax(dim=-1)
            # Place the predicted tokens at position t + 1 (position 0 holds the BOS row)
            decoder_input_ids[:, t + 1, :] = next_token
            # Stop once every sequence has produced an EOS token
            if (next_token == self.config.eos_token_id).all():
                break
        # The buffer holds exactly timesteps + 1 positions: the BOS row plus the generated tokens
        return decoder_input_ids
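

# A minimal smoke-test sketch (not part of the model API). It assumes that
# Seq2SeqCrossConfig accepts keyword arguments mirroring the attributes read in
# Seq2SeqCrossFormer.__init__ above, and that pad/bos/eos ids 0/1/2 are valid for
# the vocabulary; both are assumptions to adapt to the actual config signature.
if __name__ == "__main__":
    config = Seq2SeqCrossConfig(
        source_sequence_dimension=4,     # e.g. a small number of EEG channels (assumed)
        target_sequence_dimension=306,   # number of MEG channels
        router_dim=8,
        vocab_size_src=32,
        vocab_size_tgt=32,
        d_model=16,
        n_heads=2,
        d_ff=32,
        n_layers=1,
        sequence_length=16,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
    )
    model = Seq2SeqCrossFormer(config)
    # Dummy (batch, time, channels) token ids, avoiding the special ids 0/1/2
    src = torch.randint(3, 32, (2, 8, 4))
    tgt = torch.randint(3, 32, (2, 8, 306))
    out = model(input_ids=src, labels=tgt)
    print(out["loss"], out["logits"].shape)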