import torch
import torch.nn.functional as F
from torch import nn
from esm.pretrained import ESM3_sm_open_v0
from transformers import PreTrainedModel, PretrainedConfig


class StabilityPredictionConfig(PretrainedConfig):
    def __init__(self, embed_dim=1536, *args, **kwargs):
        # Forward the argument (not a hard-coded literal) so non-default
        # embedding dims are honored; PretrainedConfig stores extra kwargs
        # as attributes, making config.embed_dim available downstream.
        super().__init__(*args, embed_dim=embed_dim, **kwargs)


class SingleMutationPooler(nn.Module):
    """Pools the wild-type and mutant embeddings at the mutated position
    into a single vector via a learned, per-channel weighted combination."""

    def __init__(self, embed_dim=1536):
        super().__init__()
        # Initialized to +1 / -1 so the pooler starts out computing a plain
        # difference (wt - mut) at the mutated residue.
        self.wt_weight = nn.Parameter(torch.ones((1, embed_dim)), requires_grad=True)
        self.mut_weight = nn.Parameter(-1 * torch.ones((1, embed_dim)), requires_grad=True)
        self.norm = nn.LayerNorm(embed_dim, bias=False)

    def forward(self, wt_embedding, mut_embedding, positions):
        embed_dim = wt_embedding.shape[-1]
        # Expand each 0-indexed mutation position across the embedding dim
        # for gather; the +1 skips the BOS token that ESM3 prepends.
        positions = positions.view(-1, 1).unsqueeze(2).repeat(1, 1, embed_dim) + 1
        wt_residues = torch.gather(wt_embedding, 1, positions).squeeze(1)
        mut_residues = torch.gather(mut_embedding, 1, positions).squeeze(1)
        # Weight each side, sum, and normalize.
        wt_residues = wt_residues * self.wt_weight
        mut_residues = mut_residues * self.mut_weight
        return self.norm(wt_residues + mut_residues)


class StabilityPrediction(PreTrainedModel):
    """ESM3-backed regressor predicting a scalar stability change for a
    single point mutation (one wild-type/mutant sequence pair)."""

    config_class = StabilityPredictionConfig

    def __init__(self, config=None):
        # Avoid a mutable default argument: build the default config here.
        config = config if config is not None else StabilityPredictionConfig()
        super().__init__(config=config)
        # An optional `device` attribute on the config selects where the
        # ESM3 backbone loads its weights (defaults to CPU).
        self.backbone = ESM3_sm_open_v0(getattr(config, "device", "cpu"))
        self.pooler = SingleMutationPooler(config.embed_dim)
        self.regressor = nn.Linear(config.embed_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        # Start the regressor near zero so early predictions are small.
        nn.init.uniform_(self.regressor.weight, -0.01, 0.01)
        nn.init.zeros_(self.regressor.bias)

    def compute_loss(self, logits, labels):
        if labels is None:
            return None
        # Drop the trailing regression dim so logits match labels of shape
        # (batch,) and mse_loss does not silently broadcast.
        return F.mse_loss(logits.squeeze(-1), labels)
    
    def forward(self, wt_input_ids, mut_input_ids, positions, labels=None):
        # Embed both sequences with the shared ESM3 backbone.
        wt_embeddings = self.backbone(sequence_tokens=wt_input_ids).embeddings
        mut_embeddings = self.backbone(sequence_tokens=mut_input_ids).embeddings
        # Pool each pair down to one vector at the mutated position, then
        # regress to a single stability score.
        aggregated_embeddings = self.pooler(wt_embeddings, mut_embeddings, positions)
        logits = self.regressor(aggregated_embeddings)
        loss = self.compute_loss(logits, labels)
        return {
            "loss": loss,
            "logits": logits,
        }
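

# --- Usage sketch (illustrative, not part of the model definition) ---------
# A minimal smoke test under stated assumptions: inputs are ESM3-style
# sequence tokens of shape (batch, seq_len + 2) with BOS/EOS added, and
# `positions` holds the 0-indexed mutated residue (the pooler shifts by +1
# for BOS). Random token ids stand in for real tokenizer output here; the
# token values, sequence lengths, and ddG-style labels below are made up.
# Note: constructing the model loads ESM3_sm_open_v0 weights.
if __name__ == "__main__":
    model = StabilityPrediction()
    wt_ids = torch.randint(0, 64, (2, 52))  # 2 sequences: 50 residues + BOS/EOS
    mut_ids = wt_ids.clone()
    mut_ids[0, 11] = 5                      # hypothetical point mutation at residue 10
    mut_ids[1, 31] = 7                      # hypothetical point mutation at residue 30
    positions = torch.tensor([10, 30])      # 0-indexed residue positions
    labels = torch.tensor([0.3, -1.2])      # made-up stability targets
    out = model(wt_ids, mut_ids, positions, labels=labels)
    print(out["logits"].shape, out["loss"].item())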