.gitattributes CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- model.TGT filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -60,82 +60,7 @@ Please refer to `Appendix D: Model Card` of the [preprint](https://arxiv.org/abs
60
 
61
  ### Usage Instructions
62
 
63
- Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) for a detail description on how to use HF compatible IndicTrans2 models for inference.
64
-
65
- ```python
66
- import torch
67
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
68
- from IndicTransToolkit import IndicProcessor
69
- # recommended to run this on a gpu with flash_attn installed
70
- # don't set attn_implemetation if you don't have flash_attn
71
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
72
-
73
- src_lang, tgt_lang = "eng_Latn", "hin_Deva"
74
- model_name = "ai4bharat/indictrans2-en-indic-1B"
75
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
76
-
77
- model = AutoModelForSeq2SeqLM.from_pretrained(
78
- model_name,
79
- trust_remote_code=True,
80
- torch_dtype=torch.float16, # performance might slightly vary for bfloat16
81
- attn_implementation="flash_attention_2"
82
- ).to(DEVICE)
83
-
84
- ip = IndicProcessor(inference=True)
85
-
86
- input_sentences = [
87
- "When I was young, I used to go to the park every day.",
88
- "We watched a new movie last week, which was very inspiring.",
89
- "If you had met me at that time, we would have gone out to eat.",
90
- "My friend has invited me to his birthday party, and I will give him a gift.",
91
- ]
92
-
93
- batch = ip.preprocess_batch(
94
- input_sentences,
95
- src_lang=src_lang,
96
- tgt_lang=tgt_lang,
97
- )
98
-
99
- # Tokenize the sentences and generate input encodings
100
- inputs = tokenizer(
101
- batch,
102
- truncation=True,
103
- padding="longest",
104
- return_tensors="pt",
105
- return_attention_mask=True,
106
- ).to(DEVICE)
107
-
108
- # Generate translations using the model
109
- with torch.no_grad():
110
- generated_tokens = model.generate(
111
- **inputs,
112
- use_cache=True,
113
- min_length=0,
114
- max_length=256,
115
- num_beams=5,
116
- num_return_sequences=1,
117
- )
118
-
119
- # Decode the generated tokens into text
120
- with tokenizer.as_target_tokenizer():
121
- generated_tokens = tokenizer.batch_decode(
122
- generated_tokens.detach().cpu().tolist(),
123
- skip_special_tokens=True,
124
- clean_up_tokenization_spaces=True,
125
- )
126
-
127
- # Postprocess the translations, including entity replacement
128
- translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
129
-
130
- for input_sentence, translation in zip(input_sentences, translations):
131
- print(f"{src_lang}: {input_sentence}")
132
- print(f"{tgt_lang}: {translation}")
133
- ```
134
-
135
- ### 📢 Long Context IT2 Models
136
- - New RoPE based IndicTrans2 models which are capable of handling sequence lengths **upto 2048 tokens** are available [here](https://huggingface.co/collections/prajdabre/indictrans2-rope-6742ddac669a05db0804db35)
137
- - These models can be used by just changing the `model_name` parameter. Please read the model card of the RoPE-IT2 models for more information about the generation.
138
- - It is recommended to run these models with `flash_attention_2` for efficient generation.
139
 
140
 
141
  ### Citation
 
60
 
61
  ### Usage Instructions
62
 
63
+ Please refer to the [GitHub repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_inference) for a detailed description of how to use HF-compatible IndicTrans2 models for inference.
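
For quick reference, below is a minimal inference sketch based on the example previously included in this README. It assumes `IndicTransToolkit` is installed and uses a CUDA device when available (falling back to CPU otherwise); see the linked repository for the authoritative, up-to-date version.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
src_lang, tgt_lang = "eng_Latn", "hin_Deva"
model_name = "ai4bharat/indictrans2-en-indic-1B"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # results may vary slightly with bfloat16
).to(DEVICE)

ip = IndicProcessor(inference=True)
input_sentences = ["When I was young, I used to go to the park every day."]

# Add language tags and normalize the input before tokenization
batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)

# Beam-search generation, capped at 256 target tokens
with torch.no_grad():
    generated_tokens = model.generate(
        **inputs, use_cache=True, min_length=0, max_length=256, num_beams=5
    )

# Decode with the target-side vocabulary, then undo the preprocessing (entity placeholders, etc.)
with tokenizer.as_target_tokenizer():
    decoded = tokenizer.batch_decode(
        generated_tokens.detach().cpu().tolist(), skip_special_tokens=True
    )
translations = ip.postprocess_batch(decoded, lang=tgt_lang)

for sentence, translation in zip(input_sentences, translations):
    print(f"{src_lang}: {sentence}")
    print(f"{tgt_lang}: {translation}")
```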
 
64
 
65
 
66
  ### Citation
config.json CHANGED
@@ -9,7 +9,6 @@
9
  "AutoConfig": "configuration_indictrans.IndicTransConfig",
10
  "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
11
  },
12
- "tokenizer_class": "IndicTransTokenizer",
13
  "attention_dropout": 0.0,
14
  "bos_token_id": 0,
15
  "decoder_attention_heads": 16,
@@ -41,6 +40,5 @@
41
  "share_decoder_input_output_embed": false,
42
  "torch_dtype": "float32",
43
  "transformers_version": "4.32.1",
44
- "use_cache": true,
45
- "attn_implementation": "eager"
46
  }
 
9
  "AutoConfig": "configuration_indictrans.IndicTransConfig",
10
  "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
11
  },
 
12
  "attention_dropout": 0.0,
13
  "bos_token_id": 0,
14
  "decoder_attention_heads": 16,
 
40
  "share_decoder_input_output_embed": false,
41
  "torch_dtype": "float32",
42
  "transformers_version": "4.32.1",
43
+ "use_cache": true
 
44
  }
configuration_indictrans.py CHANGED
@@ -118,7 +118,6 @@ class IndicTransConfig(PretrainedConfig):
118
  pad_token_id=1,
119
  bos_token_id=0,
120
  eos_token_id=2,
121
- attn_implementation="eager",
122
  **kwargs,
123
  ):
124
  self.encoder_vocab_size = encoder_vocab_size
@@ -147,8 +146,7 @@ class IndicTransConfig(PretrainedConfig):
147
  self.num_hidden_layers = encoder_layers
148
  self.scale_embedding = scale_embedding
149
  self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
- self.attn_implementation = attn_implementation
151
-
152
  super().__init__(
153
  pad_token_id=pad_token_id,
154
  bos_token_id=bos_token_id,
 
118
  pad_token_id=1,
119
  bos_token_id=0,
120
  eos_token_id=2,
 
121
  **kwargs,
122
  ):
123
  self.encoder_vocab_size = encoder_vocab_size
 
146
  self.num_hidden_layers = encoder_layers
147
  self.scale_embedding = scale_embedding
148
  self.share_decoder_input_output_embed = share_decoder_input_output_embed
149
+
 
150
  super().__init__(
151
  pad_token_id=pad_token_id,
152
  bos_token_id=bos_token_id,
dict.SRC.json DELETED
The diff for this file is too large to render. See raw diff
 
dict.TGT.json DELETED
The diff for this file is too large to render. See raw diff
 
model.SRC DELETED
Binary file (759 kB)
 
model.TGT DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
- size 3256903
 
model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:35d28fe035cd6ac026536b555558b07762425c8b930670219063e4fc3666c96d
3
- size 4462265272
 
modeling_indictrans.py CHANGED
@@ -23,57 +23,25 @@ import torch.nn as nn
23
  from torch.nn import functional as F
24
 
25
  from transformers.activations import ACT2FN
26
-
27
- from transformers.modeling_attn_mask_utils import (
28
- _prepare_4d_attention_mask,
29
- _prepare_4d_attention_mask_for_sdpa,
30
- _prepare_4d_causal_attention_mask,
31
- _prepare_4d_causal_attention_mask_for_sdpa,
32
- )
33
-
34
  from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
  from transformers.modeling_outputs import (
36
  BaseModelOutput,
37
  BaseModelOutputWithPastAndCrossAttentions,
38
  Seq2SeqLMOutput,
39
- Seq2SeqModelOutput
40
  )
41
 
42
- from transformers.utils import (
43
- logging,
44
- is_flash_attn_2_available,
45
- is_flash_attn_greater_or_equal_2_10,
46
- )
47
-
48
  from transformers.modeling_utils import PreTrainedModel
49
- from transformers.generation.utils import GenerationMixin
50
 
51
  from .configuration_indictrans import IndicTransConfig
52
 
53
 
54
  logger = logging.get_logger(__name__)
55
 
56
- INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
57
 
58
- try:
59
- if is_flash_attn_2_available():
60
- from flash_attn import flash_attn_func, flash_attn_varlen_func
61
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
62
- except:
63
- pass
64
-
65
-
66
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
67
- def _get_unpad_data(attention_mask):
68
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
69
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
70
- max_seqlen_in_batch = seqlens_in_batch.max().item()
71
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
72
- return (
73
- indices,
74
- cu_seqlens,
75
- max_seqlen_in_batch,
76
- )
77
 
78
 
79
  # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -95,6 +63,54 @@ def shift_tokens_right(
95
  return shifted_input_ids
96
 
97
 
98
  def create_position_ids_from_input_ids(
99
  input_ids, padding_idx, past_key_values_length=0
100
  ):
@@ -231,15 +247,12 @@ class IndicTransAttention(nn.Module):
231
  dropout: float = 0.0,
232
  is_decoder: bool = False,
233
  bias: bool = True,
234
- is_causal: bool = False,
235
- config: Optional[IndicTransConfig] = None,
236
  ):
237
  super().__init__()
238
  self.embed_dim = embed_dim
239
  self.num_heads = num_heads
240
  self.dropout = dropout
241
  self.head_dim = embed_dim // num_heads
242
- self.config = config
243
 
244
  if (self.head_dim * num_heads) != self.embed_dim:
245
  raise ValueError(
@@ -248,7 +261,6 @@ class IndicTransAttention(nn.Module):
248
  )
249
  self.scaling = self.head_dim**-0.5
250
  self.is_decoder = is_decoder
251
- self.is_causal = is_causal
252
 
253
  self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
  self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -390,345 +402,17 @@ class IndicTransAttention(nn.Module):
390
  attn_output = self.out_proj(attn_output)
391
 
392
  return attn_output, attn_weights_reshaped, past_key_value
393
-
394
-
395
- class IndicTransFlashAttention2(IndicTransAttention):
396
- """
397
- IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
398
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
399
- flash attention and deal with padding tokens in case the input contains any of them.
400
- """
401
-
402
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
403
- def __init__(self, *args, **kwargs):
404
- super().__init__(*args, **kwargs)
405
-
406
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
407
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
408
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
409
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
410
-
411
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
412
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
413
-
414
- def forward(
415
- self,
416
- hidden_states: torch.Tensor,
417
- key_value_states: Optional[torch.Tensor] = None,
418
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
419
- attention_mask: Optional[torch.Tensor] = None,
420
- layer_head_mask: Optional[torch.Tensor] = None,
421
- output_attentions: bool = False,
422
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
423
- # IndicTransFlashAttention2 attention does not support output_attentions
424
- if output_attentions:
425
- raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
426
-
427
- # if key_value_states are provided this layer is used as a cross-attention layer
428
- # for the decoder
429
- is_cross_attention = key_value_states is not None
430
-
431
- bsz, q_len, _ = hidden_states.size()
432
-
433
- # get query proj
434
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
435
- # get key, value proj
436
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
437
- # is checking that the `sequence_length` of the `past_key_value` is the same as
438
- # the provided `key_value_states` to support prefix tuning
439
- if (
440
- is_cross_attention
441
- and past_key_value is not None
442
- and past_key_value[0].shape[2] == key_value_states.shape[1]
443
- ):
444
- # reuse k,v, cross_attentions
445
- key_states = past_key_value[0].transpose(1, 2)
446
- value_states = past_key_value[1].transpose(1, 2)
447
- elif is_cross_attention:
448
- # cross_attentions
449
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
450
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
451
- elif past_key_value is not None:
452
- # reuse k, v, self_attention
453
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
454
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
455
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
456
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
457
- else:
458
- # self_attention
459
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
460
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
461
-
462
- if self.is_decoder:
463
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
464
- # Further calls to cross_attention layer can then reuse all cross-attention
465
- # key/value_states (first "if" case)
466
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
467
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
468
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
469
- # if encoder bi-directional self-attention `past_key_value` is always `None`
470
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
471
-
472
- kv_seq_len = key_states.shape[-2]
473
- if past_key_value is not None:
474
- kv_seq_len += past_key_value[0].shape[-2]
475
-
476
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
477
- # therefore the input hidden states gets silently casted in float32. Hence, we need
478
- # cast them back in the correct dtype just to be sure everything works as expected.
479
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
480
- # in fp32. (LlamaRMSNorm handles it correctly)
481
-
482
- input_dtype = query_states.dtype
483
- if input_dtype == torch.float32:
484
- if torch.is_autocast_enabled():
485
- target_dtype = torch.get_autocast_gpu_dtype()
486
- # Handle the case where the model is quantized
487
- elif hasattr(self.config, "_pre_quantization_dtype"):
488
- target_dtype = self.config._pre_quantization_dtype
489
- else:
490
- target_dtype = self.q_proj.weight.dtype
491
-
492
- logger.warning_once(
493
- f"The input hidden states seems to be silently casted in float32, this might be related to"
494
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
495
- f" {target_dtype}."
496
- )
497
-
498
- query_states = query_states.to(target_dtype)
499
- key_states = key_states.to(target_dtype)
500
- value_states = value_states.to(target_dtype)
501
-
502
- attn_output = self._flash_attention_forward(
503
- query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
504
- )
505
-
506
- attn_output = attn_output.reshape(bsz, q_len, -1)
507
- attn_output = self.out_proj(attn_output)
508
-
509
- if not output_attentions:
510
- attn_weights = None
511
-
512
- return attn_output, attn_weights, past_key_value
513
-
514
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
515
- def _flash_attention_forward(
516
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
517
- ):
518
- """
519
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
520
- first unpad the input, then computes the attention scores and pad the final attention scores.
521
-
522
- Args:
523
- query_states (`torch.Tensor`):
524
- Input query states to be passed to Flash Attention API
525
- key_states (`torch.Tensor`):
526
- Input key states to be passed to Flash Attention API
527
- value_states (`torch.Tensor`):
528
- Input value states to be passed to Flash Attention API
529
- attention_mask (`torch.Tensor`):
530
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
531
- position of padding tokens and 1 for the position of non-padding tokens.
532
- dropout (`float`):
533
- Attention dropout
534
- softmax_scale (`float`, *optional*):
535
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
536
- """
537
- if not self._flash_attn_uses_top_left_mask:
538
- causal = self.is_causal
539
- else:
540
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
541
- causal = self.is_causal and query_length != 1
542
-
543
- # Contains at least one padding token in the sequence
544
- if attention_mask is not None:
545
- batch_size = query_states.shape[0]
546
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
547
- query_states, key_states, value_states, attention_mask, query_length
548
- )
549
-
550
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
551
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
552
-
553
- attn_output_unpad = flash_attn_varlen_func(
554
- query_states,
555
- key_states,
556
- value_states,
557
- cu_seqlens_q=cu_seqlens_q,
558
- cu_seqlens_k=cu_seqlens_k,
559
- max_seqlen_q=max_seqlen_in_batch_q,
560
- max_seqlen_k=max_seqlen_in_batch_k,
561
- dropout_p=dropout,
562
- softmax_scale=softmax_scale,
563
- causal=causal,
564
- )
565
-
566
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
567
- else:
568
- attn_output = flash_attn_func(
569
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
570
- )
571
-
572
- return attn_output
573
-
574
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
575
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
576
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
577
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
578
-
579
- key_layer = index_first_axis(
580
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
581
- )
582
- value_layer = index_first_axis(
583
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
584
- )
585
- if query_length == kv_seq_len:
586
- query_layer = index_first_axis(
587
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
588
- )
589
- cu_seqlens_q = cu_seqlens_k
590
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
591
- indices_q = indices_k
592
- elif query_length == 1:
593
- max_seqlen_in_batch_q = 1
594
- cu_seqlens_q = torch.arange(
595
- batch_size + 1, dtype=torch.int32, device=query_layer.device
596
- ) # There is a memcpy here, that is very bad.
597
- indices_q = cu_seqlens_q[:-1]
598
- query_layer = query_layer.squeeze(1)
599
- else:
600
- # The -q_len: slice assumes left padding.
601
- attention_mask = attention_mask[:, -query_length:]
602
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
603
-
604
- return (
605
- query_layer,
606
- key_layer,
607
- value_layer,
608
- indices_q,
609
- (cu_seqlens_q, cu_seqlens_k),
610
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
611
- )
612
-
613
-
614
- class IndicTransSdpaAttention(IndicTransAttention):
615
- def forward(
616
- self,
617
- hidden_states: torch.Tensor,
618
- key_value_states: Optional[torch.Tensor] = None,
619
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
620
- attention_mask: Optional[torch.Tensor] = None,
621
- layer_head_mask: Optional[torch.Tensor] = None,
622
- output_attentions: bool = False,
623
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
624
- """Input shape: Batch x Time x Channel"""
625
- if output_attentions or layer_head_mask is not None:
626
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
627
- logger.warning_once(
628
- "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
629
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
630
- )
631
- return super().forward(
632
- hidden_states,
633
- key_value_states=key_value_states,
634
- past_key_value=past_key_value,
635
- attention_mask=attention_mask,
636
- layer_head_mask=layer_head_mask,
637
- output_attentions=output_attentions,
638
- )
639
-
640
- # if key_value_states are provided this layer is used as a cross-attention layer
641
- # for the decoder
642
- is_cross_attention = key_value_states is not None
643
-
644
- bsz, tgt_len, _ = hidden_states.size()
645
-
646
- # get query proj
647
- query_states = self.q_proj(hidden_states)
648
- # get key, value proj
649
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
650
- # is checking that the `sequence_length` of the `past_key_value` is the same as
651
- # the provided `key_value_states` to support prefix tuning
652
- if (
653
- is_cross_attention
654
- and past_key_value is not None
655
- and past_key_value[0].shape[2] == key_value_states.shape[1]
656
- ):
657
- # reuse k,v, cross_attentions
658
- key_states = past_key_value[0]
659
- value_states = past_key_value[1]
660
- elif is_cross_attention:
661
- # cross_attentions
662
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
663
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
664
- elif past_key_value is not None:
665
- # reuse k, v, self_attention
666
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
667
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
668
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
669
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
670
- else:
671
- # self_attention
672
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
673
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
674
-
675
- if self.is_decoder:
676
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
677
- # Further calls to cross_attention layer can then reuse all cross-attention
678
- # key/value_states (first "if" case)
679
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
680
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
681
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
682
- # if encoder bi-directional self-attention `past_key_value` is always `None`
683
- past_key_value = (key_states, value_states)
684
-
685
- query_states = self._shape(query_states, tgt_len, bsz)
686
-
687
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
688
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
689
- attn_output = F.scaled_dot_product_attention(
690
- query_states,
691
- key_states,
692
- value_states,
693
- attn_mask=attention_mask,
694
- dropout_p=self.dropout if self.training else 0.0,
695
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
696
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
697
- )
698
-
699
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
700
- raise ValueError(
701
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
702
- f" {attn_output.size()}"
703
- )
704
-
705
- attn_output = attn_output.transpose(1, 2)
706
-
707
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
708
- # partitioned across GPUs when using tensor-parallelism.
709
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
710
-
711
- attn_output = self.out_proj(attn_output)
712
-
713
- return attn_output, None, past_key_value
714
 
715
 
716
- INDICTRANS_ATTENTION_CLASSES = {
717
- "eager": IndicTransAttention,
718
- "sdpa": IndicTransSdpaAttention,
719
- "flash_attention_2": IndicTransFlashAttention2,
720
- }
721
-
722
  # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
723
  class IndicTransEncoderLayer(nn.Module):
724
  def __init__(self, config: IndicTransConfig):
725
  super().__init__()
726
  self.embed_dim = config.encoder_embed_dim
727
- self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
728
  embed_dim=self.embed_dim,
729
  num_heads=config.encoder_attention_heads,
730
  dropout=config.attention_dropout,
731
- config=config,
732
  )
733
  self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
734
  self.dropout = config.dropout
@@ -806,25 +490,22 @@ class IndicTransDecoderLayer(nn.Module):
806
  super().__init__()
807
  self.embed_dim = config.decoder_embed_dim
808
 
809
- self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
810
  embed_dim=self.embed_dim,
811
  num_heads=config.decoder_attention_heads,
812
  dropout=config.attention_dropout,
813
  is_decoder=True,
814
- is_causal=True,
815
- config=config,
816
  )
817
  self.dropout = config.dropout
818
  self.activation_fn = ACT2FN[config.activation_function]
819
  self.activation_dropout = config.activation_dropout
820
 
821
  self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
822
- self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
823
  self.embed_dim,
824
  config.decoder_attention_heads,
825
  dropout=config.attention_dropout,
826
  is_decoder=True,
827
- config=config,
828
  )
829
  self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
830
  self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
@@ -1012,9 +693,6 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
1012
  nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1013
  )
1014
 
1015
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1016
- self._use_sdpa = config._attn_implementation == "sdpa"
1017
-
1018
  self.gradient_checkpointing = False
1019
  # Initialize weights and apply final processing
1020
  self.post_init()
@@ -1101,21 +779,13 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
1101
 
1102
  hidden_states = inputs_embeds + embed_pos
1103
  if self.layernorm_embedding is not None:
1104
- hidden_states = self.layernorm_embedding(hidden_states)
1105
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1106
 
 
1107
  if attention_mask is not None:
1108
- if self._use_flash_attention_2:
1109
- attention_mask = attention_mask if 0 in attention_mask else None
1110
- elif self._use_sdpa and head_mask is None and not output_attentions:
1111
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1112
- # the manual implementation that requires a 4D causal mask in all cases.
1113
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1114
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
1115
- else:
1116
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1117
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1118
-
1119
 
1120
  encoder_states = () if output_hidden_states else None
1121
  all_attentions = () if output_attentions else None
@@ -1239,9 +909,6 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
1239
  nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1240
  )
1241
 
1242
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1243
- self._use_sdpa = config._attn_implementation == "sdpa"
1244
-
1245
  self.gradient_checkpointing = False
1246
  # Initialize weights and apply final processing
1247
  self.post_init()
@@ -1364,43 +1031,29 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
1364
  if inputs_embeds is None:
1365
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1366
 
1367
-
1368
- if self._use_flash_attention_2:
1369
- # 2d mask is passed through the layers
1370
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1371
- elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1372
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1373
- # the manual implementation that requires a 4D causal mask in all cases.
1374
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1375
- attention_mask,
1376
  input_shape,
1377
- inputs_embeds,
1378
- past_key_values_length,
 
1379
  )
1380
- else:
1381
- # 4d mask is passed through the layers
1382
- attention_mask = _prepare_4d_causal_attention_mask(
1383
- attention_mask, input_shape, inputs_embeds, past_key_values_length
 
1384
  )
1385
 
1386
  # expand encoder attention mask
1387
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1388
- if self._use_flash_attention_2:
1389
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1390
- elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
1391
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1392
- # the manual implementation that requires a 4D causal mask in all cases.
1393
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1394
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1395
- encoder_attention_mask,
1396
- inputs_embeds.dtype,
1397
- tgt_len=input_shape[-1],
1398
- )
1399
- else:
1400
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1401
- encoder_attention_mask = _prepare_4d_attention_mask(
1402
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1403
- )
1404
 
1405
  # embed positions
1406
  positions = self.embed_positions(
@@ -1471,7 +1124,7 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
1471
  layer_outputs = torch.utils.checkpoint.checkpoint(
1472
  create_custom_forward(decoder_layer),
1473
  hidden_states,
1474
- attention_mask,
1475
  encoder_hidden_states,
1476
  encoder_attention_mask,
1477
  head_mask[idx] if head_mask is not None else None,
@@ -1483,7 +1136,7 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
1483
  else:
1484
  layer_outputs = decoder_layer(
1485
  hidden_states,
1486
- attention_mask=attention_mask,
1487
  encoder_hidden_states=encoder_hidden_states,
1488
  encoder_attention_mask=encoder_attention_mask,
1489
  layer_head_mask=(
@@ -1642,9 +1295,9 @@ class IndicTransModel(IndicTransPreTrainedModel):
1642
 
1643
 
1644
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1645
- class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMixin):
1646
  base_model_prefix = "model"
1647
- _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
1648
  _label_smoothing = 0.0
1649
 
1650
  def __init__(self, config: IndicTransConfig):
@@ -1654,20 +1307,19 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
1654
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1655
  )
1656
 
 
 
 
1657
  self.post_init()
1658
 
1659
  def tie_weights(self):
1660
- if self.config.share_decoder_input_output_embed:
1661
- self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
1662
-
1663
  def get_encoder(self):
1664
- return self.model.encoder
1665
 
1666
  def get_decoder(self):
1667
- return self.model.decoder
1668
-
1669
- def get_input_embeddings(self):
1670
- return self.model.encoder.embed_tokens
1671
 
1672
  def get_output_embeddings(self):
1673
  return self.lm_head
@@ -1677,6 +1329,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
1677
 
1678
  def set_label_smoothing(self, label_smoothing):
1679
  self._label_smoothing = label_smoothing
 
1680
  def forward(
1681
  self,
1682
  input_ids: Optional[torch.LongTensor] = None,
@@ -1740,7 +1393,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel, GenerationMi
1740
  masked_lm_loss = F.cross_entropy(
1741
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
1742
  target=labels.view(-1),
1743
- ignore_index=-100,
1744
  label_smoothing=self._label_smoothing,
1745
  )
1746
 
 
23
  from torch.nn import functional as F
24
 
25
  from transformers.activations import ACT2FN
 
26
  from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
27
  from transformers.modeling_outputs import (
28
  BaseModelOutput,
29
  BaseModelOutputWithPastAndCrossAttentions,
30
  Seq2SeqLMOutput,
31
+ Seq2SeqModelOutput,
32
  )
33
 
34
+ from transformers.utils import logging
 
 
 
 
 
35
  from transformers.modeling_utils import PreTrainedModel
 
36
 
37
  from .configuration_indictrans import IndicTransConfig
38
 
39
 
40
  logger = logging.get_logger(__name__)
41
 
42
+ _CONFIG_FOR_DOC = "IndicTransConfig"
43
 
44
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
 
45
 
46
 
47
  # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
 
63
  return shifted_input_ids
64
 
65
 
66
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
67
+ def _make_causal_mask(
68
+ input_ids_shape: torch.Size,
69
+ dtype: torch.dtype,
70
+ device: torch.device,
71
+ past_key_values_length: int = 0,
72
+ ):
73
+ """
74
+ Make causal mask used for bi-directional self-attention.
75
+ """
76
+ bsz, tgt_len = input_ids_shape
77
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
78
+ mask_cond = torch.arange(mask.size(-1), device=device)
79
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
80
+ mask = mask.to(dtype)
81
+
82
+ if past_key_values_length > 0:
83
+ mask = torch.cat(
84
+ [
85
+ torch.zeros(
86
+ tgt_len, past_key_values_length, dtype=dtype, device=device
87
+ ),
88
+ mask,
89
+ ],
90
+ dim=-1,
91
+ )
92
+ return mask[None, None, :, :].expand(
93
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
94
+ )
95
+
96
+
97
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
98
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
99
+ """
100
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
101
+ """
102
+ bsz, src_len = mask.size()
103
+ tgt_len = tgt_len if tgt_len is not None else src_len
104
+
105
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
106
+
107
+ inverted_mask = 1.0 - expanded_mask
108
+
109
+ return inverted_mask.masked_fill(
110
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
111
+ )
112
+
113
+
114
  def create_position_ids_from_input_ids(
115
  input_ids, padding_idx, past_key_values_length=0
116
  ):
 
247
  dropout: float = 0.0,
248
  is_decoder: bool = False,
249
  bias: bool = True,
 
 
250
  ):
251
  super().__init__()
252
  self.embed_dim = embed_dim
253
  self.num_heads = num_heads
254
  self.dropout = dropout
255
  self.head_dim = embed_dim // num_heads
 
256
 
257
  if (self.head_dim * num_heads) != self.embed_dim:
258
  raise ValueError(
 
261
  )
262
  self.scaling = self.head_dim**-0.5
263
  self.is_decoder = is_decoder
 
264
 
265
  self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
266
  self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
402
  attn_output = self.out_proj(attn_output)
403
 
404
  return attn_output, attn_weights_reshaped, past_key_value
 
405
 
406
 
407
  # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
408
  class IndicTransEncoderLayer(nn.Module):
409
  def __init__(self, config: IndicTransConfig):
410
  super().__init__()
411
  self.embed_dim = config.encoder_embed_dim
412
+ self.self_attn = IndicTransAttention(
413
  embed_dim=self.embed_dim,
414
  num_heads=config.encoder_attention_heads,
415
  dropout=config.attention_dropout,
 
416
  )
417
  self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
418
  self.dropout = config.dropout
 
490
  super().__init__()
491
  self.embed_dim = config.decoder_embed_dim
492
 
493
+ self.self_attn = IndicTransAttention(
494
  embed_dim=self.embed_dim,
495
  num_heads=config.decoder_attention_heads,
496
  dropout=config.attention_dropout,
497
  is_decoder=True,
 
 
498
  )
499
  self.dropout = config.dropout
500
  self.activation_fn = ACT2FN[config.activation_function]
501
  self.activation_dropout = config.activation_dropout
502
 
503
  self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
504
+ self.encoder_attn = IndicTransAttention(
505
  self.embed_dim,
506
  config.decoder_attention_heads,
507
  dropout=config.attention_dropout,
508
  is_decoder=True,
 
509
  )
510
  self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
511
  self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
 
693
  nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
694
  )
695
 
 
 
 
696
  self.gradient_checkpointing = False
697
  # Initialize weights and apply final processing
698
  self.post_init()
 
779
 
780
  hidden_states = inputs_embeds + embed_pos
781
  if self.layernorm_embedding is not None:
782
+ x = self.layernorm_embedding(hidden_states)
783
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
784
 
785
+ # expand attention_mask
786
  if attention_mask is not None:
787
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
788
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
 
789
 
790
  encoder_states = () if output_hidden_states else None
791
  all_attentions = () if output_attentions else None
 
909
  nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
910
  )
911
 
 
 
 
912
  self.gradient_checkpointing = False
913
  # Initialize weights and apply final processing
914
  self.post_init()
 
1031
  if inputs_embeds is None:
1032
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1033
 
1034
+ # create causal mask
1035
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1036
+ combined_attention_mask = None
1037
+ if input_shape[-1] > 1:
1038
+ combined_attention_mask = _make_causal_mask(
 
 
 
 
1039
  input_shape,
1040
+ inputs_embeds.dtype,
1041
+ device=inputs_embeds.device,
1042
+ past_key_values_length=past_key_values_length,
1043
  )
1044
+
1045
+ if attention_mask is not None and combined_attention_mask is not None:
1046
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1047
+ combined_attention_mask = combined_attention_mask + _expand_mask(
1048
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1049
  )
1050
 
1051
  # expand encoder attention mask
1052
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1053
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1054
+ encoder_attention_mask = _expand_mask(
1055
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1056
+ )
 
1057
 
1058
  # embed positions
1059
  positions = self.embed_positions(
 
1124
  layer_outputs = torch.utils.checkpoint.checkpoint(
1125
  create_custom_forward(decoder_layer),
1126
  hidden_states,
1127
+ combined_attention_mask,
1128
  encoder_hidden_states,
1129
  encoder_attention_mask,
1130
  head_mask[idx] if head_mask is not None else None,
 
1136
  else:
1137
  layer_outputs = decoder_layer(
1138
  hidden_states,
1139
+ attention_mask=combined_attention_mask,
1140
  encoder_hidden_states=encoder_hidden_states,
1141
  encoder_attention_mask=encoder_attention_mask,
1142
  layer_head_mask=(
 
1295
 
1296
 
1297
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1298
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1299
  base_model_prefix = "model"
1300
+ _tied_weights_keys = None
1301
  _label_smoothing = 0.0
1302
 
1303
  def __init__(self, config: IndicTransConfig):
 
1307
  config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1308
  )
1309
 
1310
+ if config.share_decoder_input_output_embed:
1311
+ self.lm_head.weight = self.model.decoder.embed_tokens.weight
1312
+
1313
  self.post_init()
1314
 
1315
  def tie_weights(self):
1316
+ pass
1317
+
 
1318
  def get_encoder(self):
1319
+ return self.model.get_encoder()
1320
 
1321
  def get_decoder(self):
1322
+ return self.model.get_decoder()
 
 
 
1323
 
1324
  def get_output_embeddings(self):
1325
  return self.lm_head
 
1329
 
1330
  def set_label_smoothing(self, label_smoothing):
1331
  self._label_smoothing = label_smoothing
1332
+
1333
  def forward(
1334
  self,
1335
  input_ids: Optional[torch.LongTensor] = None,
 
1393
  masked_lm_loss = F.cross_entropy(
1394
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
1395
  target=labels.view(-1),
1396
+ ignore_index=self.config.pad_token_id,
1397
  label_smoothing=self._label_smoothing,
1398
  )
1399
 
special_tokens_map.json DELETED
@@ -1,6 +0,0 @@
1
- {
2
- "bos_token": "<s>",
3
- "eos_token": "</s>",
4
- "pad_token": "<pad>",
5
- "unk_token": "<unk>"
6
- }
 
tokenization_indictrans.py DELETED
@@ -1,261 +0,0 @@
1
- import os
2
- import json
3
-
4
- from typing import Dict, List, Optional, Union, Tuple
5
-
6
- from transformers.utils import logging
7
- from sentencepiece import SentencePieceProcessor
8
- from transformers.tokenization_utils import PreTrainedTokenizer
9
-
10
-
11
- logger = logging.get_logger(__name__)
12
-
13
- SPIECE_UNDERLINE = "▁"
14
-
15
- SPECIAL_TAGS = {
16
- "_bt_",
17
- "_ft_",
18
- "asm_Beng",
19
- "awa_Deva",
20
- "ben_Beng",
21
- "bho_Deva",
22
- "brx_Deva",
23
- "doi_Deva",
24
- "eng_Latn",
25
- "gom_Deva",
26
- "gon_Deva",
27
- "guj_Gujr",
28
- "hin_Deva",
29
- "hne_Deva",
30
- "kan_Knda",
31
- "kas_Arab",
32
- "kas_Deva",
33
- "kha_Latn",
34
- "lus_Latn",
35
- "mag_Deva",
36
- "mai_Deva",
37
- "mal_Mlym",
38
- "mar_Deva",
39
- "mni_Beng",
40
- "mni_Mtei",
41
- "npi_Deva",
42
- "ory_Orya",
43
- "pan_Guru",
44
- "san_Deva",
45
- "sat_Olck",
46
- "snd_Arab",
47
- "snd_Deva",
48
- "tam_Taml",
49
- "tel_Telu",
50
- "urd_Arab",
51
- "unr_Deva",
52
- }
53
-
54
- VOCAB_FILES_NAMES = {
55
- "src_vocab_fp": "dict.SRC.json",
56
- "tgt_vocab_fp": "dict.TGT.json",
57
- "src_spm_fp": "model.SRC",
58
- "tgt_spm_fp": "model.TGT",
59
- }
60
-
61
-
62
- class IndicTransTokenizer(PreTrainedTokenizer):
63
- _added_tokens_encoder = {}
64
- _added_tokens_decoder = {}
65
-
66
- vocab_files_names = VOCAB_FILES_NAMES
67
- model_input_names = ["input_ids", "attention_mask"]
68
-
69
- def __init__(
70
- self,
71
- src_vocab_fp=None,
72
- tgt_vocab_fp=None,
73
- src_spm_fp=None,
74
- tgt_spm_fp=None,
75
- unk_token="<unk>",
76
- bos_token="<s>",
77
- eos_token="</s>",
78
- pad_token="<pad>",
79
- do_lower_case=False,
80
- **kwargs,
81
- ):
82
-
83
- self.src = True
84
-
85
- self.src_vocab_fp = src_vocab_fp
86
- self.tgt_vocab_fp = tgt_vocab_fp
87
- self.src_spm_fp = src_spm_fp
88
- self.tgt_spm_fp = tgt_spm_fp
89
-
90
- self.unk_token = unk_token.content
91
- self.pad_token = pad_token.content
92
- self.eos_token = eos_token.content
93
- self.bos_token = bos_token.content
94
-
95
- self.encoder = self._load_json(self.src_vocab_fp)
96
- if self.unk_token not in self.encoder:
97
- raise KeyError("<unk> token must be in vocab")
98
- assert self.pad_token in self.encoder
99
- self.encoder_rev = {v: k for k, v in self.encoder.items()}
100
-
101
- self.decoder = self._load_json(self.tgt_vocab_fp)
102
- if self.unk_token not in self.encoder:
103
- raise KeyError("<unk> token must be in vocab")
104
- assert self.pad_token in self.encoder
105
- self.decoder_rev = {v: k for k, v in self.decoder.items()}
106
-
107
- # load SentencePiece model for pre-processing
108
- self.src_spm = self._load_spm(self.src_spm_fp)
109
- self.tgt_spm = self._load_spm(self.tgt_spm_fp)
110
-
111
- self.current_spm = self.src_spm
112
- self.current_encoder = self.encoder
113
- self.current_encoder_rev = self.encoder_rev
114
-
115
- self.unk_token_id = self.encoder[self.unk_token]
116
- self.pad_token_id = self.encoder[self.pad_token]
117
- self.eos_token_id = self.encoder[self.eos_token]
118
- self.bos_token_id = self.encoder[self.bos_token]
119
-
120
- super().__init__(
121
- src_vocab_file=self.src_vocab_fp,
122
- tgt_vocab_file=self.src_vocab_fp,
123
- do_lower_case=do_lower_case,
124
- unk_token=unk_token,
125
- bos_token=bos_token,
126
- eos_token=eos_token,
127
- pad_token=pad_token,
128
- **kwargs,
129
- )
130
-
131
- def add_new_special_tags(self, new_tags: List[str]):
132
- SPECIAL_TAGS.update(new_tags)
133
-
134
- def _switch_to_input_mode(self):
135
- self.src = True
136
- self.padding_side = "left"
137
- self.current_spm = self.src_spm
138
- self.current_encoder = self.encoder
139
- self.current_encoder_rev = self.encoder_rev
140
-
141
- def _switch_to_target_mode(self):
142
- self.src = False
143
- self.padding_side = "right"
144
- self.current_spm = self.tgt_spm
145
- self.current_encoder = self.decoder
146
- self.current_encoder_rev = self.decoder_rev
147
-
148
- def _load_spm(self, path: str) -> SentencePieceProcessor:
149
- return SentencePieceProcessor(model_file=path)
150
-
151
- def _save_json(self, data, path: str) -> None:
152
- with open(path, "w", encoding="utf-8") as f:
153
- json.dump(data, f, indent=2)
154
-
155
- def _load_json(self, path: str) -> Union[Dict, List]:
156
- with open(path, "r", encoding="utf-8") as f:
157
- return json.load(f)
158
-
159
- def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
160
- tags = [token for token in tokens if token in SPECIAL_TAGS]
161
- tokens = [token for token in tokens if token not in SPECIAL_TAGS]
162
- return tags, tokens
163
-
164
- def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
165
- pads = [token for token in tokens if token == self.pad_token]
166
- tokens = [token for token in tokens if token != self.pad_token]
167
- return pads, tokens
168
-
169
- @property
170
- def src_vocab_size(self) -> int:
171
- return len(self.encoder)
172
-
173
- @property
174
- def tgt_vocab_size(self) -> int:
175
- return len(self.decoder)
176
-
177
- def get_src_vocab(self) -> Dict[str, int]:
178
- return dict(self.encoder, **self.added_tokens_encoder)
179
-
180
- def get_tgt_vocab(self) -> Dict[str, int]:
181
- return dict(self.decoder, **self.added_tokens_decoder)
182
-
183
- # hack override
184
- def get_vocab(self) -> Dict[str, int]:
185
- return self.get_src_vocab()
186
-
187
- # hack override
188
- @property
189
- def vocab_size(self) -> int:
190
- return self.src_vocab_size
191
-
192
- def _convert_token_to_id(self, token: str) -> int:
193
- """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
194
- return self.current_encoder.get(token, self.current_encoder[self.unk_token])
195
-
196
- def _convert_id_to_token(self, index: int) -> str:
197
- """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
198
- return self.current_encoder_rev.get(index, self.unk_token)
199
-
200
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
201
- """Uses sentencepiece model for detokenization"""
202
- pads, tokens = self._split_pads(tokens)
203
-
204
- if self.src:
205
-
206
- tags, non_tags = self._split_tags(tokens)
207
-
208
- return (
209
- " ".join(pads)
210
- + " "
211
- + " ".join(tags)
212
- + " "
213
- + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
214
- )
215
-
216
- return (
217
- "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
218
- + " "
219
- + " ".join(pads)
220
- )
221
-
222
- def _tokenize(self, text) -> List[str]:
223
- if self.src:
224
- tokens = text.split(" ")
225
- tags, non_tags = self._split_tags(tokens)
226
- text = " ".join(non_tags)
227
- tokens = self.current_spm.EncodeAsPieces(text)
228
- return tags + tokens
229
- else:
230
- return self.current_spm.EncodeAsPieces(text)
231
-
232
- def build_inputs_with_special_tokens(
233
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
- ) -> List[int]:
235
- if token_ids_1 is None:
236
- return token_ids_0 + [self.eos_token_id]
237
- # We don't expect to process pairs, but leave the pair logic for API consistency
238
- return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
239
-
240
- def save_vocabulary(
241
- self, save_directory: str, filename_prefix: Optional[str] = None
242
- ) -> Tuple[str]:
243
- if not os.path.isdir(save_directory):
244
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
245
- return
246
-
247
- src_spm_fp = os.path.join(save_directory, "model.SRC")
248
- tgt_spm_fp = os.path.join(save_directory, "model.TGT")
249
- src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
250
- tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
251
-
252
- self._save_json(self.encoder, src_vocab_fp)
253
- self._save_json(self.decoder, tgt_vocab_fp)
254
-
255
- with open(src_spm_fp, "wb") as f:
256
- f.write(self.src_spm.serialized_model_proto())
257
-
258
- with open(tgt_spm_fp, "wb") as f:
259
- f.write(self.tgt_spm.serialized_model_proto())
260
-
261
- return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
 
tokenizer_config.json DELETED
@@ -1,51 +0,0 @@
1
- {
2
- "added_tokens_decoder": {
3
- "0": {
4
- "content": "<s>",
5
- "lstrip": false,
6
- "normalized": false,
7
- "rstrip": false,
8
- "single_word": false,
9
- "special": true
10
- },
11
- "1": {
12
- "content": "<pad>",
13
- "lstrip": false,
14
- "normalized": false,
15
- "rstrip": false,
16
- "single_word": false,
17
- "special": true
18
- },
19
- "2": {
20
- "content": "</s>",
21
- "lstrip": false,
22
- "normalized": false,
23
- "rstrip": false,
24
- "single_word": false,
25
- "special": true
26
- },
27
- "3": {
28
- "content": "<unk>",
29
- "lstrip": false,
30
- "normalized": false,
31
- "rstrip": false,
32
- "single_word": false,
33
- "special": true
34
- }
35
- },
36
- "bos_token": "<s>",
37
- "clean_up_tokenization_spaces": true,
38
- "do_lower_case": false,
39
- "eos_token": "</s>",
40
- "model_max_length": 256,
41
- "pad_token": "<pad>",
42
- "name_or_path": "ai4bharat/indictrans2-en-indic-1B",
43
- "tokenizer_class": "IndicTransTokenizer",
44
- "auto_map": {
45
- "AutoTokenizer": [
46
- "tokenization_indictrans.IndicTransTokenizer",
47
- null
48
- ]
49
- },
50
- "unk_token": "<unk>"
51
- }