# Spaces: Running on Zero
# File size: 9,507 Bytes
# Commits: 1f1f2d8, f74eb0b
import gradio as gr
import py3Dmol
import io
import numpy as np
import os
import traceback
import spaces
from esm.sdk import client
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
from Bio.Data import PDBData
import biotite.structure as bs
from biotite.structure.io import pdb
from esm.utils import residue_constants as RC
# Initialize the model
# The API token is read from the environment so that it never lives in the
# source; fail fast at import time when it is missing.
token = os.environ.get("ESM_API_TOKEN")
if not token:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Client for the hosted ESM3 model; shared by every prediction in this module.
model = client(
    model="esm3-medium-2024-03",
    url="https://forge.evolutionaryscale.ai",
    token=token,
)
# Three-letter residue names of the 20 standard amino acids mapped to their
# one-letter codes. Residues whose name is not a key here are ignored when a
# PDB file is converted to a sequence.
amino3to1 = {
    "ALA": "A",
    "CYS": "C",
    "ASP": "D",
    "GLU": "E",
    "PHE": "F",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "ASN": "N",
    "PRO": "P",
    "GLN": "Q",
    "ARG": "R",
    "SER": "S",
    "THR": "T",
    "VAL": "V",
    "TRP": "W",
    "TYR": "Y",
}
def read_pdb_io(pdb_file):
    """Load PDB text from an in-memory buffer, a path, or an upload object.

    Args:
        pdb_file: One of
            * an ``io.StringIO`` holding the PDB text,
            * a ``str`` or ``os.PathLike`` path to a PDB file, or
            * any other object exposing the file path as a ``name`` attribute.

    Returns:
        A ``(pdb_io, pdb_content)`` tuple: a fresh ``io.StringIO`` over the
        text, and the text itself.

    Raises:
        ValueError: If ``pdb_file`` is of an unsupported type, or the content
            is empty or whitespace-only.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        if isinstance(pdb_file, os.PathLike):
            # Checked before ``name``: pathlib.Path.name is only the basename,
            # which would open the wrong file.
            path = pdb_file
        elif hasattr(pdb_file, "name"):
            path = pdb_file.name
        elif isinstance(pdb_file, str):
            path = pdb_file
        else:
            raise ValueError("Unsupported file type")
        with open(path, "r") as handle:
            pdb_content = handle.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    return io.StringIO(pdb_content), pdb_content
def get_protein(pdb_file) -> ESMProtein:
    """Build an ``ESMProtein`` from the standard residues of a PDB file.

    The file is parsed with biotite. Residues whose name is not a key of
    ``amino3to1`` are dropped, and all remaining residues are packed into a
    single ``ProteinChain`` with chain id ``"A"``. Atoms whose name appears
    in ``RC.atom_order`` are copied into the atom37 arrays; every other
    position stays NaN with a ``False`` mask.

    Args:
        pdb_file: Any input accepted by ``read_pdb_io``.

    Returns:
        The protein produced by ``ESMProtein.from_protein_chain``.

    Raises:
        ValueError: If the file cannot be read or parsed, contains no atoms,
            or contains no residue listed in ``amino3to1``. Any underlying
            exception is attached as ``__cause__``.
    """
    try:
        # read_pdb_io already rejects empty / whitespace-only content.
        pdb_io, _ = read_pdb_io(pdb_file)

        # Ask for model 1 explicitly so biotite returns an AtomArray. Without
        # ``model=`` it returns an AtomArrayStack, whose iteration yields whole
        # models rather than atoms, so only the first atom of every residue
        # would be copied into the atom37 arrays below.
        structure = pdb.PDBFile.read(pdb_io).get_structure(model=1)

        # Check if the structure contains any atoms
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Keep only residues with a known one-letter code. Each residue is an
        # AtomArray whose per-atom annotations are identical, so element 0 is
        # representative.
        valid_residues = [
            res
            for res in bs.residue_iter(structure)
            if res.res_name[0] in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = "".join(amino3to1[res.res_name[0]] for res in valid_residues)
        residue_indices = [res.res_id[0] for res in valid_residues]
        num_residues = len(valid_residues)

        # Create a ProteinChain object with empty (NaN / masked-out) atoms
        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(num_residues, "", dtype="<U4"),
            atom37_positions=np.full((num_residues, 37, 3), np.nan),
            atom37_mask=np.zeros((num_residues, 37), dtype=bool),
            confidence=np.ones(num_residues, dtype=np.float32),
        )

        # Fill in atom positions and mask
        for i, res in enumerate(valid_residues):
            for atom_name, coord in zip(res.atom_name, res.coord):
                if atom_name in RC.atom_order:
                    idx = RC.atom_order[atom_name]
                    protein_chain.atom37_positions[i, idx] = coord
                    protein_chain.atom37_mask[i, idx] = True

        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e
def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Add Gaussian noise to the atom positions of the protein.

    Args:
        protein: Protein whose ``coordinates`` are perturbed. The input
            object is not modified.
        noise_level: Standard deviation of the zero-mean Gaussian noise added
            independently to every coordinate value.

    Returns:
        A new ``ESMProtein`` carrying the original sequence and the noisy
        coordinates; no other attribute of ``protein`` is copied.

    Raises:
        ValueError: If ``protein`` has no coordinates.
    """
    coordinates = protein.coordinates
    if coordinates is None:
        raise ValueError("Protein has no coordinates to add noise to")

    noise = np.random.randn(*coordinates.shape) * noise_level

    # NOTE(review): ESMProtein.coordinates is presumably a torch.Tensor -
    # confirm. Adding a raw NumPy array to a tensor can change the container
    # type/dtype of the result, so convert the noise to the tensor's own
    # dtype and device first. Plain ndarrays take the direct path.
    if not isinstance(coordinates, np.ndarray) and hasattr(coordinates, "new_tensor"):
        noise = coordinates.new_tensor(noise)

    noisy_coordinates = coordinates + noise
    return ESMProtein(sequence=protein.sequence, coordinates=noisy_coordinates)
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int):
    """Predict structures from noisy copies of a PDB and show the best one.

    For each of ``num_frames`` frames a noisy copy of the input protein is
    created, ``num_runs`` structure predictions are run on it, and each
    prediction is aligned to the original protein. The prediction with the
    lowest cRMSD is rendered next to the original.

    Args:
        pdb_file: Any input accepted by ``get_protein``.
        num_runs: Predictions per frame (must be >= 1).
        noise_level: Noise level passed to ``add_noise_to_coordinates``.
        num_frames: Number of noisy frames to generate (must be >= 1).

    Returns:
        A ``(view, text)`` tuple: the viewer returned by
        ``visualize_after_pred`` and a ``"Best cRMSD: ..."`` summary string.

    Raises:
        ValueError: If ``num_runs`` or ``num_frames`` is below 1, or the PDB
            file cannot be processed.
    """
    # UI sliders may deliver whole numbers as floats; range() needs ints.
    num_runs = int(num_runs)
    num_frames = int(num_frames)
    if num_runs < 1 or num_frames < 1:
        raise ValueError("num_runs and num_frames must both be at least 1")

    protein = get_protein(pdb_file)

    runs = []
    for _ in range(num_frames):
        noisy_protein = add_noise_to_coordinates(protein, noise_level)
        for _ in range(num_runs):
            structure_prediction = run_structure_prediction(noisy_protein)
            aligned, crmsd = align_after_prediction(protein, structure_prediction)
            runs.append((crmsd, aligned))

    # Select by cRMSD only: sorting the raw tuples would fall back to
    # comparing the protein objects whenever two cRMSD values tie.
    best_crmsd, best_aligned = min(runs, key=lambda run: run[0])

    view = visualize_after_pred(protein, best_aligned)
    return view, f"Best cRMSD: {best_crmsd:.4f}"
def run_structure_prediction(
    protein: ESMProtein,
    num_steps: int = 40,
    temperature: float = 0.7,
) -> ESMProtein:
    """Run one structure-track generation on ``protein`` with the shared model.

    Args:
        protein: Input protein passed to ``model.generate``.
        num_steps: ``num_steps`` for the ``GenerationConfig``.
        temperature: ``temperature`` for the ``GenerationConfig``.

    Returns:
        The protein returned by ``model.generate``.

    Raises:
        RuntimeError: If ``model.generate`` returns something other than an
            ``ESMProtein`` (for example an error object), so the failure
            surfaces here instead of as an attribute error further down.
    """
    structure_prediction_config = GenerationConfig(
        track="structure",
        num_steps=num_steps,
        temperature=temperature,
    )
    structure_prediction = model.generate(protein, structure_prediction_config)
    if not isinstance(structure_prediction, ESMProtein):
        raise RuntimeError(f"Structure prediction failed: {structure_prediction!r}")
    return structure_prediction
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Align a predicted structure onto the reference and measure its cRMSD.

    Args:
        protein: Reference protein (alignment target).
        structure_prediction: Predicted protein (the mobile structure).

    Returns:
        A ``(aligned, crmsd)`` tuple: the prediction aligned onto the
        reference, and the RMSD between prediction and reference as a float.

    Raises:
        ValueError: If the two proteins have different sequence lengths, since
            the same residue indices are used on both sides.
    """
    structure_prediction_chain = structure_prediction.to_protein_chain()
    protein_chain = protein.to_protein_chain()

    num_residues = len(structure_prediction_chain.sequence)
    if num_residues != len(protein_chain.sequence):
        raise ValueError(
            "Predicted and reference structures have different lengths: "
            f"{num_residues} vs {len(protein_chain.sequence)}"
        )

    # Every residue takes part, paired index-for-index on both chains.
    structure_indices = np.arange(0, num_residues)
    aligned_chain = structure_prediction_chain.align(
        protein_chain,
        mobile_inds=structure_indices,
        target_inds=structure_indices,
    )
    crmsd = structure_prediction_chain.rmsd(
        protein_chain,
        mobile_inds=structure_indices,
        target_inds=structure_indices,
    )
    # float() keeps the declared return type regardless of the scalar type
    # that rmsd() hands back.
    return ESMProtein.from_protein_chain(aligned_chain), float(crmsd)
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Render the reference and aligned prediction in one py3Dmol view.

    Args:
        protein: Reference structure, drawn as a light grey cartoon.
        aligned: Aligned prediction, drawn as a light green cartoon.

    Returns:
        The 800x600 ``py3Dmol`` view with both models added and zoomed to fit.
    """
    view = py3Dmol.view(width=800, height=600)

    # Models are numbered in the order they are added: 0 = reference,
    # 1 = prediction.
    view.addModel(protein.to_pdb_string(), "pdb")
    view.setStyle({"model": 0}, {"cartoon": {"color": "lightgrey"}})

    view.addModel(aligned.to_pdb_string(), "pdb")
    view.setStyle({"model": 1}, {"cartoon": {"color": "lightgreen"}})

    view.zoomTo()
    return view
@spaces.GPU()
def run_prediction(pdb_file, num_runs, noise_level, num_frames):
    """Run the prediction pipeline and return HTML plus a status string.

    Never raises: any exception is rendered as a red error panel so it can be
    shown in the UI.

    Args:
        pdb_file: Uploaded PDB file, or ``None`` when nothing was uploaded.
        num_runs: Number of predictions per frame.
        noise_level: Noise level applied to the coordinates.
        num_frames: Number of noisy frames.

    Returns:
        A ``(html, text)`` tuple: the viewer HTML (or a message / error panel)
        and the cRMSD summary (or a short status message).
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded"

        view, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames)
        viewer_html = view._make_html()
        return (
            '\n<div style="height: 600px;">\n'
            f"{viewer_html}\n"
            "</div>\n"
        ), crmsd_text
    except Exception as e:
        # Escape before embedding: messages and tracebacks routinely contain
        # "<...>" fragments (e.g. "<module>") that would otherwise be parsed
        # as markup and vanish or break the page.
        error_message = escape(str(e))
        stack_trace = escape(traceback.format_exc())
        return (
            "\n<div style='color: red;'>\n"
            "<h3>Error:</h3>\n"
            f"<p>{error_message}</p>\n"
            "<h4>Stack Trace:</h4>\n"
            f"<pre>{stack_trace}</pre>\n"
            "</div>\n"
        ), "Error occurred"
def create_demo():
    """Build the Gradio Blocks UI for the prediction demo.

    Layout: a row with a narrow input column (file upload, three sliders,
    run button) and a wider output column (HTML viewer, result textbox),
    followed by usage notes. Clicking the button calls ``run_prediction``.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")

        with gr.Row():
            with gr.Column(scale=1):
                pdb_file = gr.File(label="Upload PDB file")
                num_runs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_level = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                num_frames = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                run_button = gr.Button("Run Prediction")
            with gr.Column(scale=2):
                visualization = gr.HTML(label="3D Visualization")
                alignment_result = gr.Textbox(label="Alignment Result")

        # Input order must match run_prediction's parameter order.
        run_button.click(
            fn=run_prediction,
            inputs=[pdb_file, num_runs, noise_level, num_frames],
            outputs=[visualization, alignment_result],
        )

        gr.Markdown(
            "\n## How to use\n"
            "1. Upload a PDB file using the file uploader.\n"
            "2. Adjust the number of prediction runs per frame using the slider.\n"
            "3. Set the noise level to add random perturbations to the structure.\n"
            "4. Choose the number of MD frames to simulate.\n"
            '5. Click the "Run Prediction" button to start the process.\n'
            "6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).\n"
            "7. The alignment result will display the best cRMSD (lower is better).\n"
            "## About\n"
            "This demo uses the ESM3 model to predict protein structures from PDB files.\n"
            "It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.\n"
        )

    return demo
# Script entry point: build the UI, enable request queuing, start the server.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()