from transformers import PreTrainedModel, RoFormerConfig, RoFormerModel

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class CRISPRTransformerConfig(RoFormerConfig):
    """RoFormer configuration extended with the two reference lengths and a weight-init seed."""

    model_type = "CRISPR_transformer"
    # name of the field that forward() treats as the label
    label_names = ["observation"]

    def __init__(
        self,
        vocab_size=4,
        hidden_size=256,
        num_hidden_layers=3,
        num_attention_heads=4,
        intermediate_size=1024,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=256,
        ref1len=127,  # length of reference sequence 1
        ref2len=127,  # length of reference sequence 2
        seed=63036,   # seed for the weight-initialization RNG
        **kwargs
    ):
        self.ref1len = ref1len
        self.ref2len = ref2len
        self.seed = seed
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            **kwargs
        )


class CRISPRTransformerModel(PreTrainedModel):
    config_class = CRISPRTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # dedicated RNG so that weight initialization is reproducible
        self.generator = torch.Generator().manual_seed(config.seed)
        self.model = RoFormerModel(config)
        # project the final hidden state onto a (ref2len + 1) x (ref1len + 1) grid of logits
        self.mlp = nn.Linear(
            in_features=config.hidden_size,
            out_features=(config.ref1len + 1) * (config.ref2len + 1)
        )
        self.initialize_weights()

    def initialize_weights(self):
        # re-initialize every linear layer (including those inside RoFormerModel)
        # from N(0, 1) with the seeded generator; biases are zeroed
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=1, generator=self.generator)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, refcode: torch.Tensor, observation: torch.Tensor = None):
        # refcode holds the token ids of the two concatenated reference sequences, so its
        # length must equal ref1len + ref2len to match the all-ones attention mask built below
        batch_size = refcode.shape[0]
        # encode with RoFormer, take the hidden state of the last position as the sequence
        # summary, and project it to a (ref2len + 1) x (ref1len + 1) grid of logits
        logit = self.mlp(
            self.model(
                input_ids=refcode,
                attention_mask=torch.ones(
                    batch_size,
                    self.config.ref1len + self.config.ref2len,
                    dtype=torch.int64,
                    device=self.model.device
                )
            ).last_hidden_state[:, -1, :]
        ).view(batch_size, self.config.ref2len + 1, self.config.ref1len + 1)
        if observation is not None:
            # also return the negative log-likelihood of the observations under the
            # softmax taken over all (ref2len + 1) * (ref1len + 1) outcomes
            return {
                "logit": logit,
                "loss": -(
                    observation.flatten(start_dim=1) *
                    F.log_softmax(logit.flatten(start_dim=1), dim=1)
                ).sum()
            }
        return {"logit": logit}
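

# Minimal usage sketch (not part of the original module): the dummy shapes below are
# assumptions inferred from the config defaults and the forward pass above; with
# ref1len = ref2len = 127 the logit grid is 128 x 128.
if __name__ == "__main__":
    config = CRISPRTransformerConfig()
    model = CRISPRTransformerModel(config)
    # one token id per base of the two concatenated reference sequences
    refcode = torch.randint(0, config.vocab_size, (2, config.ref1len + config.ref2len))
    # placeholder observation tensor over the (ref2len + 1) x (ref1len + 1) outcome grid
    observation = torch.rand(2, config.ref2len + 1, config.ref1len + 1)
    outputs = model(refcode=refcode, observation=observation)
    print(outputs["logit"].shape)  # torch.Size([2, 128, 128])
    print(outputs["loss"].item())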