# RDPD-mini / modelLM.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Define your custom language model class
class OBILanguageModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Token embeddings sized to the SentencePiece vocabulary.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute position embeddings, one per position up to block_size.
        self.position_embedding_table = nn.Embedding(config.block_size, config.hidden_size)
        self.transformer = nn.Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            num_decoder_layers=config.num_hidden_layers,
            dim_feedforward=4 * config.hidden_size,
            dropout=config.hidden_dropout_prob,
            activation='gelu',
            batch_first=True,  # forward() passes (batch, seq, hidden) tensors
        )
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        # Project hidden states back to SentencePiece vocab logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
    def forward(self, idx, targets=None):
        # idx: (batch, seq_len) token ids.
        tok_emb = self.token_embedding_table(idx)
        try:
            pos_emb = self.position_embedding_table(
                torch.arange(idx.size(1), device=idx.device)
            )
        except IndexError as e:
            # Fall back to zero position embeddings if the sequence is longer
            # than block_size (i.e. the embedding table is too small).
            print(f"IndexError: {e}")
            print(f"idx.size(1): {idx.size(1)}")
            print(f"Positional embedding table shape: {self.position_embedding_table.weight.shape}")
            pos_emb = torch.zeros((idx.size(1), self.config.hidden_size), device=idx.device)
        x = tok_emb + pos_emb  # broadcast position embeddings over the batch
        x = self.transformer(x, x)
        x = self.ln1(x)
        x = self.ln2(x)
        logits = self.lm_head(x)
        # Compute the loss only when targets are provided; otherwise return None.
        loss = (
            F.cross_entropy(logits.view(-1, self.config.vocab_size), targets.view(-1))
            if targets is not None
            else None
        )
        return logits, loss
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # Autoregressively sample max_new_tokens tokens, conditioning on at
        # most the last block_size tokens at each step.
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the final time step
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
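

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test showing how the model might be instantiated and run.
# It assumes a plain transformers.PretrainedConfig carrying the attributes the
# class reads (vocab_size, block_size, hidden_size, num_attention_heads,
# num_hidden_layers, hidden_dropout_prob); the concrete values below are
# placeholders, not the repository's real configuration.
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        vocab_size=32000,            # assumed SentencePiece vocab size
        block_size=128,              # maximum context length
        hidden_size=256,
        num_attention_heads=4,       # must divide hidden_size evenly
        num_hidden_layers=2,
        hidden_dropout_prob=0.1,
    )
    model = OBILanguageModel(config).to(device)

    # Random token ids of shape (batch, seq_len) stand in for real data.
    idx = torch.randint(0, config.vocab_size, (2, 16), device=device)
    logits, loss = model(idx, targets=idx)
    print(logits.shape, loss.item())

    generated = model.generate(idx, max_new_tokens=8)
    print(generated.shape)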