# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
1. Add clamping if the input length exceeds the max-source-tokens
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor

class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."
        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                positions = utils.make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
            # Clamp positions so inputs longer than max_positions reuse the last
            # learned embedding instead of indexing out of range.
            positions = torch.clamp(
                positions, max=self.padding_idx + self.max_positions
            )
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
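

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): illustrates the
# clamping behaviour described in the module docstring, assuming fairseq is
# installed and using an arbitrary, hypothetical configuration with
# num_embeddings=8 and padding_idx=1. With those values, max_positions is
# 8 - 1 - 1 = 6, so any sequence longer than 6 tokens has its position ids
# clamped to padding_idx + max_positions = 7 instead of indexing out of range.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    emb = LearnedPositionalEmbedding(num_embeddings=8, embedding_dim=4, padding_idx=1)

    # Batches of token ids [bsz x seqlen]; ids other than padding_idx are arbitrary.
    short = torch.tensor([[5, 6, 7, 1, 1]])  # 3 real tokens followed by 2 pads
    long = torch.randint(2, 100, (1, 10))    # 10 real tokens, more than max_positions (6)

    print(emb(short).shape)  # torch.Size([1, 5, 4])
    print(emb(long).shape)   # torch.Size([1, 10, 4]); positions past max_positions
                             # are clamped to the last learned embedding rather
                             # than raising an index error.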