import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, RoFormerConfig, RoFormerModel

class CRISPRTransformerConfig(RoFormerConfig):
    model_type = "CRISPR_transformer"
    label_names = ["observation"]

    def __init__(
        self,
        vocab_size = 4, # ACGT
        hidden_size = 256, # model embedding dimension
        num_hidden_layers = 3, # number of encoder layers
        num_attention_heads = 4, # number of attention heads
        intermediate_size = 1024, # feed-forward intermediate dimension
        hidden_dropout_prob = 0.1, # dropout probability for all fully connected layers in the embeddings, encoder, and pooler
        attention_probs_dropout_prob = 0.1, # dropout probability for the attention weights
        max_position_embeddings = 256, # maximum sequence length the model supports; must cover ref1len + ref2len
        ref1len = 127, # length of reference 1
        ref2len = 127, # length of reference 2
        seed = 63036, # random seed for weight initialization
        **kwargs
    ):
        self.ref1len = ref1len
        self.ref2len = ref2len
        self.seed = seed
        super().__init__(
            vocab_size = vocab_size,
            hidden_size = hidden_size,
            num_hidden_layers = num_hidden_layers,
            num_attention_heads = num_attention_heads,
            intermediate_size = intermediate_size,
            hidden_dropout_prob = hidden_dropout_prob,
            attention_probs_dropout_prob = attention_probs_dropout_prob,
            max_position_embeddings = max_position_embeddings,
            **kwargs
        )
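
# Note: forward() below builds an attention mask of length ref1len + ref2len, so
# refcode is expected to hold reference 1 and reference 2 back to back; with the
# defaults above that is 254 tokens, which must fit within max_position_embeddings (256).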

class CRISPRTransformerModel(PreTrainedModel):
    config_class = CRISPRTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # seeded CPU generator so that weight initialization is reproducible
        self.generator = torch.Generator().manual_seed(config.seed)
        self.model = RoFormerModel(config)
        # linear head producing one logit per (ref2 position, ref1 position) pair,
        # i.e. (ref1len + 1) * (ref2len + 1) logits in total
        self.mlp = nn.Linear(
            in_features=config.hidden_size,
            out_features=(config.ref1len + 1) * (config.ref2len + 1)
        )
        self.initialize_weights()

    def initialize_weights(self):
        # re-initialize every linear layer (including those inside RoFormerModel)
        # from N(0, 1) using the seeded generator; biases are zeroed
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=1, generator=self.generator)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, refcode: torch.Tensor, observation: torch.Tensor = None):
        # refcode: (batch_size, ref1len + ref2len) token ids of the two concatenated references
        # self.model(...).last_hidden_state: (batch_size, sequence_length, hidden_size)
        # [:, -1, :] takes the (arbitrarily chosen) last position to predict the logits from
        batch_size = refcode.shape[0]
        logit = self.mlp(
            self.model(
                input_ids=refcode,
                attention_mask=torch.ones(
                    batch_size,
                    self.config.ref1len + self.config.ref2len,
                    dtype=torch.int64,
                    device=self.model.device
                )
            ).last_hidden_state[:, -1, :]
        ).view(batch_size, self.config.ref2len + 1, self.config.ref1len + 1)
        if observation is not None:
            # loss: cross-entropy between the observed outcome counts and the predicted
            # distribution over the (ref2len + 1) x (ref1len + 1) grid, summed over the batch
            return {
                "logit": logit,
                "loss": - (
                    observation.flatten(start_dim=1) *
                    F.log_softmax(logit.flatten(start_dim=1), dim=1)
                ).sum()
            }
        return {"logit": logit}