From e7261ba7ce11e358c2855b3cf4ed26aee935057a Mon Sep 17 00:00:00 2001 From: Olamide Isreal Date: Wed, 18 Mar 2026 22:31:13 +0100 Subject: [PATCH] Add LigandMPNN Nextflow pipeline for protein sequence design --- Dockerfile | 96 +++ LICENSE | 21 + README.md | 636 ++++++++++++++++ data_utils.py | 988 ++++++++++++++++++++++++ get_model_params.sh | 47 ++ main.nf | 67 ++ model_utils.py | 1772 +++++++++++++++++++++++++++++++++++++++++++ nextflow.config | 42 + params.json | 131 ++++ requirements.txt | 29 + run.py | 990 ++++++++++++++++++++++++ run_examples.sh | 244 ++++++ sc_examples.sh | 55 ++ sc_utils.py | 1158 ++++++++++++++++++++++++++++ score.py | 549 ++++++++++++++ 15 files changed, 6825 insertions(+) create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 README.md create mode 100644 data_utils.py create mode 100644 get_model_params.sh create mode 100644 main.nf create mode 100644 model_utils.py create mode 100644 nextflow.config create mode 100644 params.json create mode 100644 requirements.txt create mode 100644 run.py create mode 100644 run_examples.sh create mode 100644 sc_examples.sh create mode 100644 sc_utils.py create mode 100644 score.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2969da2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,96 @@ +FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Set working directory +WORKDIR /app + +# Install system dependencies including build tools for ProDy compilation +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends \ + wget \ + git \ + curl \ + ca-certificates \ + python3.11 \ + python3.11-venv \ + python3.11-dev \ + python3-pip \ + procps \ + build-essential \ + && rm -rf /var/lib/apt/lists/* \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 \ + && update-alternatives --install /usr/bin/python 
python /usr/bin/python3.11 1 + +# Upgrade pip +RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel + +# Clone LigandMPNN repository +RUN git clone https://github.com/dauparas/LigandMPNN.git /app/LigandMPNN + +# Set working directory to LigandMPNN +WORKDIR /app/LigandMPNN + +# Install Python dependencies +RUN pip3 install --no-cache-dir \ + biopython==1.79 \ + filelock==3.13.1 \ + fsspec==2024.3.1 \ + Jinja2==3.1.3 \ + MarkupSafe==2.1.5 \ + mpmath==1.3.0 \ + networkx==3.2.1 \ + numpy==1.23.5 \ + ProDy==2.4.1 \ + pyparsing==3.1.1 \ + scipy==1.12.0 \ + sympy==1.12 \ + typing_extensions==4.10.0 \ + ml-collections==0.1.1 \ + dm-tree==0.1.8 + +# Install PyTorch with CUDA support +RUN pip3 install --no-cache-dir \ + torch==2.2.1 \ + --index-url https://download.pytorch.org/whl/cu121 + +# Download model parameters +RUN mkdir -p /app/LigandMPNN/model_params && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_002.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_002.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_010.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_010.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_020.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_030.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_030.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_005_25.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_010_25.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_020_25.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_030_25.pt && \ + wget -q 
https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_002.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_002.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_010.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_010.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_020.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_030.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_030.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/per_residue_label_membrane_mpnn_v_48_020.pt -O /app/LigandMPNN/model_params/per_residue_label_membrane_mpnn_v_48_020.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/global_label_membrane_mpnn_v_48_020.pt -O /app/LigandMPNN/model_params/global_label_membrane_mpnn_v_48_020.pt && \ + wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_sc_v_32_002_16.pt -O /app/LigandMPNN/model_params/ligandmpnn_sc_v_32_002_16.pt && \ + ls -la /app/LigandMPNN/model_params/ + +# Create wrapper script for ligandmpnn (using absolute path) +RUN printf '#!/bin/bash\npython /app/LigandMPNN/run.py "$@"\n' > /usr/local/bin/ligandmpnn && \ + chmod +x /usr/local/bin/ligandmpnn + +# Create wrapper script for scoring (using absolute path) +RUN printf '#!/bin/bash\npython /app/LigandMPNN/score.py "$@"\n' > /usr/local/bin/ligandmpnn-score && \ + chmod +x /usr/local/bin/ligandmpnn-score + +# Create input/output directories +RUN mkdir -p /app/inputs /app/outputs + +# Set environment variables for model paths +ENV LIGANDMPNN_MODEL_DIR=/app/LigandMPNN/model_params +ENV PYTHONPATH=/app/LigandMPNN:${PYTHONPATH:-} + +WORKDIR /app/LigandMPNN + +CMD ["python", "run.py", "--help"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..715f310 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Justas Dauparas + +Permission is hereby granted, free of charge, to any person obtaining a 
copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f4a1d00 --- /dev/null +++ b/README.md @@ -0,0 +1,636 @@ +## LigandMPNN + +This package provides inference code for [LigandMPNN](https://www.biorxiv.org/content/10.1101/2023.12.22.573103v1) & [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) models. The code and model parameters are available under the MIT license. + +Third party code: side chain packing uses helper functions from [Openfold](https://github.com/aqlaboratory/openfold). 
+ +### Running the code +``` +git clone https://github.com/dauparas/LigandMPNN.git +cd LigandMPNN +bash get_model_params.sh "./model_params" + +#setup your conda/or other environment +#conda create -n ligandmpnn_env python=3.11 +#pip3 install -r requirements.txt + +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/default" +``` + +### Dependencies +To run the model you will need to have Python>=3.0, PyTorch, Numpy installed, and to read/write PDB files you will need [Prody](https://pypi.org/project/ProDy/). + +For example to make a new conda environment for LigandMPNN run: +``` +conda create -n ligandmpnn_env python=3.11 +pip3 install -r requirements.txt +``` + +### Main differences compared with [ProteinMPNN](https://github.com/dauparas/ProteinMPNN) code +- Input PDBs are parsed using [Prody](https://pypi.org/project/ProDy/) preserving protein residue indices, chain letters, and insertion codes. If there are missing residues in the input structure the output fasta file won't have added `X` to fill the gaps. The script outputs .fasta and .pdb files. It's recommended to use .pdb files since they will hold information about chain letters and residue indices. +- Adding bias, fixing residues, and selecting residues to be redesigned now can be done using residue indices directly, e.g. A23 (means chain A residue with index 23), B42D (chain B, residue 42, insertion code D). +- Model writes to fasta files: `overall_confidence`, `ligand_confidence` which reflect the average confidence/probability (with T=1.0) over the redesigned residues `overall_confidence=exp[-mean_over_residues(log_probs)]`. Higher numbers mean the model is more confident about that sequence. min_value=0.0; max_value=1.0. Sequence recovery with respect to the input sequence is calculated only over the redesigned residues. 
+ +### Model parameters +To download model parameters run: +``` +bash get_model_params.sh "./model_params" +``` + +### Available models + +To run the model of your choice specify `--model_type` and optionally the model checkpoint path. Available models: +- ProteinMPNN +``` +--model_type "protein_mpnn" +--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_002.pt" #noised with 0.02A Gaussian noise +--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_010.pt" #noised with 0.10A Gaussian noise +--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_020.pt" #noised with 0.20A Gaussian noise +--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_030.pt" #noised with 0.30A Gaussian noise +``` +- LigandMPNN +``` +--model_type "ligand_mpnn" +--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" #noised with 0.05A Gaussian noise +--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_010_25.pt" #noised with 0.10A Gaussian noise +--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_020_25.pt" #noised with 0.20A Gaussian noise +--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_030_25.pt" #noised with 0.30A Gaussian noise +``` +- SolubleMPNN +``` +--model_type "soluble_mpnn" +--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_002.pt" #noised with 0.02A Gaussian noise +--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_010.pt" #noised with 0.10A Gaussian noise +--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_020.pt" #noised with 0.20A Gaussian noise +--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_030.pt" #noised with 0.30A Gaussian noise +``` +- ProteinMPNN with global membrane label +``` +--model_type "global_label_membrane_mpnn" +--checkpoint_global_label_membrane_mpnn "./model_params/global_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise +``` +- ProteinMPNN with per residue membrane label +``` +--model_type "per_residue_label_membrane_mpnn" 
+--checkpoint_per_residue_label_membrane_mpnn "./model_params/per_residue_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise +``` +- Side chain packing model +``` +--checkpoint_path_sc "./model_params/ligandmpnn_sc_v_32_002_16.pt" +``` +## Design examples +### 1 default +Default settings will run ProteinMPNN. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/default" +``` +### 2 --temperature +`--temperature 0.05` Change sampling temperature (higher temperature gives more sequence diversity). +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --temperature 0.05 \ + --out_folder "./outputs/temperature" +``` +### 3 --seed +`--seed` Not selecting a seed will run with a random seed. Running this multiple times will give different results. +``` +python run.py \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/random_seed" +``` +### 4 --verbose +`--verbose 0` Do not print any statements. +``` +python run.py \ + --seed 111 \ + --verbose 0 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/verbose" +``` +### 5 --save_stats +`--save_stats 1` Save sequence design statistics. +``` +#['generated_sequences', 'sampling_probs', 'log_probs', 'decoding_order', 'native_sequence', 'mask', 'chain_mask', 'seed', 'temperature'] +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/save_stats" \ + --save_stats 1 +``` +### 6 --fixed_residues +`--fixed_residues` Fixing specific amino acids. This example fixes the first 10 residues in chain C and adds global bias towards A (alanine). The output should have all alanines except the first 10 residues should be the same as in the input sequence since those are fixed. 
+``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/fix_residues" \ + --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \ + --bias_AA "A:10.0" +``` + +### 7 --redesigned_residues +`--redesigned_residues` Specifying which residues need to be designed. This example redesigns the first 10 residues while fixing everything else. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/redesign_residues" \ + --redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \ + --bias_AA "A:10.0" +``` + +### 8 --number_of_batches +Design 15 sequences; with batch size 3 (can be 1 when using CPUs) and the number of batches 5. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/batch_size" \ + --batch_size 3 \ + --number_of_batches 5 +``` +### 9 --bias_AA +Global amino acid bias. In this example, output sequences are biased towards W, P, C and away from A. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \ + --out_folder "./outputs/global_bias" +``` +### 10 --bias_AA_per_residue +Specify per residue amino acid bias, e.g. make residues C1, C3, C5, and C7 to be prolines. +``` +# { +# "C1": {"G": -0.3, "C": -2.0, "P": 10.8}, +# "C3": {"P": 10.0}, +# "C5": {"G": -1.3, "P": 10.0}, +# "C7": {"G": -1.3, "P": 10.0} +# } +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \ + --out_folder "./outputs/per_residue_bias" +``` +### 11 --omit_AA +Global amino acid restrictions. This is equivalent to using `--bias_AA` and setting bias to be a large negative number. The output should be just made of E, K, A. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --omit_AA "CDFGHILMNPQRSTVWY" \ + --out_folder "./outputs/global_omit" +``` + +### 12 --omit_AA_per_residue +Per residue amino acid restrictions. 
+``` +# { +# "C1": "ACDEFGHIKLMNPQRSTVW", +# "C3": "ACDEFGHIKLMNPQRSTVW", +# "C5": "ACDEFGHIKLMNPQRSTVW", +# "C7": "ACDEFGHIKLMNPQRSTVW" +# } +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \ + --out_folder "./outputs/per_residue_omit" +``` +### 13 --symmetry_residues +### 13 --symmetry_weights +Designing sequences with symmetry, e.g. homooligomer/2-state proteins, etc. In this example make C1=C2=C3, also C4=C5, and C6=C7. +``` +#total_logits += symmetry_weights[t]*logits +#probs = torch.nn.functional.softmax((total_logits+bias_t) / temperature, dim=-1) +#total_logits_123 = 0.33*logits_1+0.33*logits_2+0.33*logits_3 +#output should be ***ooxx +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/symmetry" \ + --symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \ + --symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5" +``` + + +### 14 --homo_oligomer +Design homooligomer sequences. This automatically sets `--symmetry_residues` and `--symmetry_weights` assuming equal weighting from all chains. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/homooligomer" \ + --homo_oligomer 1 \ + --number_of_batches 2 +``` + +### 15 --file_ending +Outputs will have a specified ending; e.g. `1BC8_xyz.fa` instead of `1BC8.fa` +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/file_ending" \ + --file_ending "_xyz" +``` + +### 16 --zero_indexed +Zero indexed names in /backbones/1BC8_0.pdb, 1BC8_1.pdb, 1BC8_2.pdb etc +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/zero_indexed" \ + --zero_indexed 1 \ + --number_of_batches 2 +``` + +### 17 --chains_to_design +Specify which chains (e.g. "A,B,C") need to be redesigned, other chains will be kept fixed. 
Outputs in seqs/backbones will still have atoms/sequences for the whole input PDB. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/chains_to_design" \ + --chains_to_design "A,B" +``` +### 18 --parse_these_chains_only +Parse and design only specified chains (e.g. "A,B,C"). Outputs will have only specified chains. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/parse_these_chains_only" \ + --parse_these_chains_only "A,B" +``` + +### 19 --model_type "ligand_mpnn" +Run LigandMPNN with default settings. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_default" +``` + +### 20 --checkpoint_ligand_mpnn +Run LigandMPNN using 0.05A model by specifying `--checkpoint_ligand_mpnn` flag. +``` +python run.py \ + --checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_v_32_005_25" +``` +### 21 --ligand_mpnn_use_atom_context +Setting `--ligand_mpnn_use_atom_context 0` will mask all ligand atoms. This can be used to assess how much ligand atoms affect AA probabilities. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_no_context" \ + --ligand_mpnn_use_atom_context 0 +``` + +### 22 --ligand_mpnn_use_side_chain_context +Use fixed residue side chain atoms as extra ligand atoms. 
+``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \ + --ligand_mpnn_use_side_chain_context 1 \ + --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" +``` + +### 23 --model_type "soluble_mpnn" +Run SolubleMPNN (ProteinMPNN-like model with only soluble proteins in the training dataset). +``` +python run.py \ + --model_type "soluble_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/soluble_mpnn_default" +``` + +### 24 --model_type "global_label_membrane_mpnn" +Run global label membrane MPNN (trained with extra input - binary label soluble vs not) `--global_transmembrane_label #1 - membrane, 0 - soluble`. +``` +python run.py \ + --model_type "global_label_membrane_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/global_label_membrane_mpnn_0" \ + --global_transmembrane_label 0 +``` + +### 25 --model_type "per_residue_label_membrane_mpnn" +Run per residue label membrane MPNN (trained with extra input per residue specifying buried (hydrophobic), interface (polar), or other type residues; 3 classes). +``` +python run.py \ + --model_type "per_residue_label_membrane_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/per_residue_label_membrane_mpnn_default" \ + --transmembrane_buried "C1 C2 C3 C11" \ + --transmembrane_interface "C4 C5 C6 C22" +``` + +### 26 --fasta_seq_separation +Choose a symbol to put between different chains in fasta output format. It's recommended to PDB output format to deal with residue jumps and multiple chain parsing. +``` +python run.py \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/fasta_seq_separation" \ + --fasta_seq_separation ":" +``` + +### 27 --pdb_path_multi +Specify multiple PDB input paths. This is more efficient since the model needs to be loaded from the checkpoint once. 
+``` +#{ +#"./inputs/1BC8.pdb": "", +#"./inputs/4GYT.pdb": "" +#} +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --out_folder "./outputs/pdb_path_multi" \ + --seed 111 +``` + +### 28 --fixed_residues_multi +Specify fixed residues when using `--pdb_path_multi` flag. +``` +#{ +#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10 C22", +#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A11 A12 A13 B38" +#} +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --fixed_residues_multi "./inputs/fix_residues_multi.json" \ + --out_folder "./outputs/fixed_residues_multi" \ + --seed 111 +``` + +### 29 --redesigned_residues_multi +Specify which residues need to be redesigned when using `--pdb_path_multi` flag. +``` +#{ +#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10", +#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A12 A13 B38" +#} +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \ + --out_folder "./outputs/redesigned_residues_multi" \ + --seed 111 +``` + +### 30 --omit_AA_per_residue_multi +Specify which residues need to be omitted when using `--pdb_path_multi` flag. +``` +#{ +#"./inputs/1BC8.pdb": {"C1":"ACDEFGHILMNPQRSTVWY", "C2":"ACDEFGHILMNPQRSTVWY", "C3":"ACDEFGHILMNPQRSTVWY"}, +#"./inputs/4GYT.pdb": {"A7":"ACDEFGHILMNPQRSTVWY", "A8":"ACDEFGHILMNPQRSTVWY"} +#} +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \ + --out_folder "./outputs/omit_AA_per_residue_multi" \ + --seed 111 +``` + +### 31 --bias_AA_per_residue_multi +Specify amino acid biases per residue when using `--pdb_path_multi` flag. 
+``` +#{ +#"./inputs/1BC8.pdb": {"C1":{"A":3.0, "P":-2.0}, "C2":{"W":10.0, "G":-0.43}}, +#"./inputs/4GYT.pdb": {"A7":{"Y":5.0, "S":-2.0}, "A8":{"M":3.9, "G":-0.43}} +#} +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \ + --out_folder "./outputs/bias_AA_per_residue_multi" \ + --seed 111 +``` + +### 32 --ligand_mpnn_cutoff_for_score +This sets the cutoff distance in angstroms to select residues that are considered to be close to ligand atoms. This flag only affects the `num_ligand_res` and `ligand_confidence` in the output fasta files. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --ligand_mpnn_cutoff_for_score "6.0" \ + --out_folder "./outputs/ligand_mpnn_cutoff_for_score" +``` + +### 33 specifying residues with insertion codes +You can specify residue using chain_id + residue_number + insersion_code; e.g. redesign only residue B82, B82A, B82B, B82C. +``` +python run.py \ + --seed 111 \ + --pdb_path "./inputs/2GFB.pdb" \ + --out_folder "./outputs/insertion_code" \ + --redesigned_residues "B82 B82A B82B B82C" \ + --parse_these_chains_only "B" +``` + +### 34 parse atoms with zero occupancy +Parse atoms in the PDB files with zero occupancy too. 
+``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/parse_atoms_with_zero_occupancy" \ + --parse_atoms_with_zero_occupancy 1 +``` + +## Scoring examples +### Output dictionary +``` +out_dict = {} +out_dict["logits"] - raw logits from the model +out_dict["probs"] - softmax(logits) +out_dict["log_probs"] - log_softmax(logits) +out_dict["decoding_order"] - decoding order used (logits will depend on the decoding order) +out_dict["native_sequence"] - parsed input sequence in integers +out_dict["mask"] - mask for missing residues (usually all ones) +out_dict["chain_mask"] - controls which residues are decoded first +out_dict["alphabet"] - amino acid alphabet used +out_dict["residue_names"] - dictionary to map integers to residue_names, e.g. {0: "C10", 1: "C11"} +out_dict["sequence"] - parsed input sequence in alphabet +out_dict["mean_of_probs"] - averaged over batch_size*number_of_batches probabilities, [protein_length, 21] +out_dict["std_of_probs"] - same as above, but std +``` + +### 1 autoregressive with sequence info +Get probabilities/scores for backbone-sequence pairs using autoregressive probabilities: p(AA_1|backbone), p(AA_2|backbone, AA_1) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10. +``` +python score.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --autoregressive_score 1\ + --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \ + --out_folder "./outputs/autoregressive_score_w_seq" \ + --use_sequence 1\ + --batch_size 1 \ + --number_of_batches 10 +``` +### 2 autoregressive with backbone info only +Get probabilities/scores for backbone using probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10. 
+``` +python score.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --autoregressive_score 1\ + --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \ + --out_folder "./outputs/autoregressive_score_wo_seq" \ + --use_sequence 0\ + --batch_size 1 \ + --number_of_batches 10 +``` +### 3 single amino acid score with sequence info +Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone, AA_{all except AA_1}), p(AA_2|backbone, AA_{all except AA_2}) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10. +``` +python score.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --single_aa_score 1\ + --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \ + --out_folder "./outputs/single_aa_score_w_seq" \ + --use_sequence 1\ + --batch_size 1 \ + --number_of_batches 10 +``` +### 4 single amino acid score with backbone info only +Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10. +``` +python score.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --single_aa_score 1\ + --pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \ + --out_folder "./outputs/single_aa_score_wo_seq" \ + --use_sequence 0\ + --batch_size 1 \ + --number_of_batches 10 +``` + +## Side chain packing examples + +### 1 design a new sequence and pack side chains (return 1 side chain packing sample - fast) +Design a new sequence using any of the available models and also pack side chains of the new sequence. Return only a single solution for the side chain packing. 
+``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_default_fast" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 0 \ + --pack_with_ligand_context 1 +``` +### 2 design a new sequence and pack side chains (return 4 side chain packing samples) +Same as above, but returns 4 independent samples for side chains. b-factor shows log prob density per chi angle group. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_default" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 +``` + +### 3 fix specific residues fors sequence design and packing +This option will not repack side chains of the fixed residues, but use them as a context. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_fixed_residues" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 \ + --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \ + --repack_everything 0 +``` +### 4 fix specific residues for sequence design but repack everything +This option will repacks all the residues. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_fixed_residues_full_repack" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 \ + --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \ + --repack_everything 1 +``` + +### 5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms +You can run side chain packing without taking into account context atoms like DNA atoms. 
This most likely will results in side chain clashing with context atoms, but it might be interesting to see how model's uncertainty changes when ligand atoms are present vs not for side chain conformations. +``` +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_no_context" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 0 +``` + +### Things to add +- Support for ProteinMPNN CA-only model. +- Examples for scoring sequences only. +- Side-chain packing scripts. +- TER + + +### Citing this work +If you use the code, please cite: +``` +@article{dauparas2023atomic, + title={Atomic context-conditioned protein sequence design using LigandMPNN}, + author={Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David}, + journal={Biorxiv}, + pages={2023--12}, + year={2023}, + publisher={Cold Spring Harbor Laboratory} +} + +@article{dauparas2022robust, + title={Robust deep learning--based protein sequence design using ProteinMPNN}, + author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others}, + journal={Science}, + volume={378}, + number={6615}, + pages={49--56}, + year={2022}, + publisher={American Association for the Advancement of Science} +} +``` diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000..8380bd8 --- /dev/null +++ b/data_utils.py @@ -0,0 +1,988 @@ +from __future__ import print_function + +import numpy as np +import torch +import torch.utils +from prody import * + +confProDy(verbosity="none") + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": 
"PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", + "X": "UNK", +} +restype_str_to_int = { + "A": 0, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "V": 17, + "W": 18, + "Y": 19, + "X": 20, +} +restype_int_to_str = { + 0: "A", + 1: "C", + 2: "D", + 3: "E", + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", +} +alphabet = list(restype_str_to_int) + +element_list = [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mb", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Uut", + "Fl", + "Uup", + "Lv", + "Uus", + "Uuo", +] +element_list = [item.upper() for item in element_list] +# element_dict = dict(zip(element_list, range(1,len(element_list)))) +element_dict_rev = dict(zip(range(1, len(element_list)), element_list)) + + +def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor): + """ + S : true sequence shape=[batch, length] + S_pred : predicted sequence shape=[batch, length] + mask : mask to 
compute average over the region shape=[batch, length] + + average : averaged sequence recovery shape=[batch] + """ + match = S == S_pred + average = torch.sum(match * mask, dim=-1) / torch.sum(mask, dim=-1) + return average + + +def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor): + """ + S : true sequence shape=[batch, length] + log_probs : predicted sequence shape=[batch, length] + mask : mask to compute average over the region shape=[batch, length] + + average_loss : averaged categorical cross entropy (CCE) [batch] + loss_per_residue : per position CCE [batch, length] + """ + S_one_hot = torch.nn.functional.one_hot(S, 21) + loss_per_residue = -(S_one_hot * log_probs).sum(-1) # [B, L] + average_loss = torch.sum(loss_per_residue * mask, dim=-1) / ( + torch.sum(mask, dim=-1) + 1e-8 + ) + return average_loss, loss_per_residue + + +def write_full_PDB( + save_path: str, + X: np.ndarray, + X_m: np.ndarray, + b_factors: np.ndarray, + R_idx: np.ndarray, + chain_letters: np.ndarray, + S: np.ndarray, + other_atoms=None, + icodes=None, + force_hetatm=False, +): + """ + save_path : path where the PDB will be written to + X : protein atom xyz coordinates shape=[length, 14, 3] + X_m : protein atom mask shape=[length, 14] + b_factors: shape=[length, 14] + R_idx: protein residue indices shape=[length] + chain_letters: protein chain letters shape=[length] + S : protein amino acid sequence shape=[length] + other_atoms: other atoms parsed by prody + icodes: a list of insertion codes for the PDB; e.g. 
antibody loops + """ + + restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", + "X": "UNK", + } + restype_INTtoSTR = { + 0: "A", + 1: "C", + 2: "D", + 3: "E", + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", + } + restype_name_to_atom14_names = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "NE", + "CZ", + "NH1", + "NH2", + "", + "", + "", + ], + "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""], + "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "NE2", + "", + "", + "", + "", + "", + ], + "GLU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "OE2", + "", + "", + "", + "", + "", + ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "ND1", + "CD2", + "CE1", + "NE2", + "", + "", + "", + "", + ], + "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""], + "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""], + "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""], + "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""], + "PHE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "", + "", + "", + ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", 
"", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""], + "TRP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE2", + "CE3", + "NE1", + "CZ2", + "CZ3", + "CH2", + ], + "TYR": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "OH", + "", + "", + ], + "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], + } + + S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]] + + X_list = [] + b_factor_list = [] + atom_name_list = [] + element_name_list = [] + residue_name_list = [] + residue_number_list = [] + chain_id_list = [] + icodes_list = [] + for i, AA in enumerate(S_str): + sel = X_m[i].astype(np.int32) == 1 + total = np.sum(sel) + tmp = np.array(restype_name_to_atom14_names[AA])[sel] + X_list.append(X[i][sel]) + b_factor_list.append(b_factors[i][sel]) + atom_name_list.append(tmp) + element_name_list += [AA[:1] for AA in list(tmp)] + residue_name_list += total * [AA] + residue_number_list += total * [R_idx[i]] + chain_id_list += total * [chain_letters[i]] + icodes_list += total * [icodes[i]] + + X_stack = np.concatenate(X_list, 0) + b_factor_stack = np.concatenate(b_factor_list, 0) + atom_name_stack = np.concatenate(atom_name_list, 0) + + protein = prody.AtomGroup() + protein.setCoords(X_stack) + protein.setBetas(b_factor_stack) + protein.setNames(atom_name_stack) + protein.setResnames(residue_name_list) + protein.setElements(element_name_list) + protein.setOccupancies(np.ones([X_stack.shape[0]])) + protein.setResnums(residue_number_list) + protein.setChids(chain_id_list) + protein.setIcodes(icodes_list) + + if other_atoms: + other_atoms_g = prody.AtomGroup() + other_atoms_g.setCoords(other_atoms.getCoords()) + other_atoms_g.setNames(other_atoms.getNames()) + 
other_atoms_g.setResnames(other_atoms.getResnames()) + other_atoms_g.setElements(other_atoms.getElements()) + other_atoms_g.setOccupancies(other_atoms.getOccupancies()) + other_atoms_g.setResnums(other_atoms.getResnums()) + other_atoms_g.setChids(other_atoms.getChids()) + if force_hetatm: + other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm")) + writePDB(save_path, protein + other_atoms_g) + else: + writePDB(save_path, protein) + + +def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str): + """ + protein_atoms: prody atom group + CA_dict: mapping between chain_residue_idx_icodes and integers + atom_name: atom to be parsed; e.g. CA + """ + atom_atoms = protein_atoms.select(f"name {atom_name}") + + if atom_atoms != None: + atom_coords = atom_atoms.getCoords() + atom_resnums = atom_atoms.getResnums() + atom_chain_ids = atom_atoms.getChids() + atom_icodes = atom_atoms.getIcodes() + + atom_coords_ = np.zeros([len(CA_dict), 3], np.float32) + atom_coords_m = np.zeros([len(CA_dict)], np.int32) + if atom_atoms != None: + for i in range(len(atom_resnums)): + code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i] + if code in list(CA_dict): + atom_coords_[CA_dict[code], :] = atom_coords[i] + atom_coords_m[CA_dict[code]] = 1 + return atom_coords_, atom_coords_m + + +def parse_PDB( + input_path: str, + device: str = "cpu", + chains: list = [], + parse_all_atoms: bool = False, + parse_atoms_with_zero_occupancy: bool = False +): + """ + input_path : path for the input PDB + device: device for the torch.Tensor + chains: a list specifying which chains need to be parsed; e.g. 
["A", "B"] + parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms + parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed + """ + element_list = [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mb", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Uut", + "Fl", + "Uup", + "Lv", + "Uus", + "Uuo", + ] + element_list = [item.upper() for item in element_list] + element_dict = dict(zip(element_list, range(1, len(element_list)))) + restype_3to1 = { + "ALA": "A", + "ARG": "R", + "ASN": "N", + "ASP": "D", + "CYS": "C", + "GLN": "Q", + "GLU": "E", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LEU": "L", + "LYS": "K", + "MET": "M", + "PHE": "F", + "PRO": "P", + "SER": "S", + "THR": "T", + "TRP": "W", + "TYR": "Y", + "VAL": "V", + } + restype_STRtoINT = { + "A": 0, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "V": 17, + "W": 18, + "Y": 19, + "X": 20, + } + + atom_order = { + "N": 0, + "CA": 1, + "C": 2, + "CB": 3, + "O": 4, + "CG": 5, + "CG1": 6, + "CG2": 7, + "OG": 8, + "OG1": 9, + "SG": 10, + "CD": 11, + "CD1": 12, + "CD2": 13, + "ND1": 14, 
+ "ND2": 15, + "OD1": 16, + "OD2": 17, + "SD": 18, + "CE": 19, + "CE1": 20, + "CE2": 21, + "CE3": 22, + "NE": 23, + "NE1": 24, + "NE2": 25, + "OE1": 26, + "OE2": 27, + "CH2": 28, + "NH1": 29, + "NH2": 30, + "OH": 31, + "CZ": 32, + "CZ2": 33, + "CZ3": 34, + "NZ": 35, + "OXT": 36, + } + + if not parse_all_atoms: + atom_types = ["N", "CA", "C", "O"] + else: + atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + ] + + atoms = parsePDB(input_path) + if not parse_atoms_with_zero_occupancy: + atoms = atoms.select("occupancy > 0") + if chains: + str_out = "" + for item in chains: + str_out += " chain " + item + " or" + atoms = atoms.select(str_out[1:-3]) + + protein_atoms = atoms.select("protein") + backbone = protein_atoms.select("backbone") + other_atoms = atoms.select("not protein and not water") + water_atoms = atoms.select("water") + + CA_atoms = protein_atoms.select("name CA") + CA_resnums = CA_atoms.getResnums() + CA_chain_ids = CA_atoms.getChids() + CA_icodes = CA_atoms.getIcodes() + + CA_dict = {} + for i in range(len(CA_resnums)): + code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i] + CA_dict[code] = i + + xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32) + xyz_37_m = np.zeros([len(CA_dict), 37], np.int32) + for atom_name in atom_types: + xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name) + xyz_37[:, atom_order[atom_name], :] = xyz + xyz_37_m[:, atom_order[atom_name]] = xyz_m + + N = xyz_37[:, atom_order["N"], :] + CA = xyz_37[:, atom_order["CA"], :] + C = xyz_37[:, atom_order["C"], :] + O = xyz_37[:, atom_order["O"], :] + + N_m = xyz_37_m[:, atom_order["N"]] + CA_m = xyz_37_m[:, atom_order["CA"]] + C_m = xyz_37_m[:, atom_order["C"]] + O_m = xyz_37_m[:, 
atom_order["O"]] + + mask = N_m * CA_m * C_m * O_m # must all 4 atoms exist + + b = CA - N + c = C - CA + a = np.cross(b, c, axis=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + + chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32) + R_idx = np.array(CA_resnums, dtype=np.int32) + S = CA_atoms.getResnames() + S = [restype_3to1[AA] if AA in list(restype_3to1) else "X" for AA in list(S)] + S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32) + X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1) + + try: + Y = np.array(other_atoms.getCoords(), dtype=np.float32) + Y_t = list(other_atoms.getElements()) + Y_t = np.array( + [ + element_dict[y_t.upper()] if y_t.upper() in element_list else 0 + for y_t in Y_t + ], + dtype=np.int32, + ) + Y_m = (Y_t != 1) * (Y_t != 0) + + Y = Y[Y_m, :] + Y_t = Y_t[Y_m] + Y_m = Y_m[Y_m] + except: + Y = np.zeros([1, 3], np.float32) + Y_t = np.zeros([1], np.int32) + Y_m = np.zeros([1], np.int32) + + output_dict = {} + output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32) + output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32) + output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32) + output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32) + output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32) + + output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32) + output_dict["chain_labels"] = torch.tensor( + chain_labels, device=device, dtype=torch.int32 + ) + + output_dict["chain_letters"] = CA_chain_ids + + mask_c = [] + chain_list = list(set(output_dict["chain_letters"])) + chain_list.sort() + for chain in chain_list: + mask_c.append( + torch.tensor( + [chain == item for item in output_dict["chain_letters"]], + device=device, + dtype=bool, + ) + ) + + output_dict["mask_c"] = mask_c + output_dict["chain_list"] = chain_list + + output_dict["S"] = torch.tensor(S, device=device, 
dtype=torch.int32) + + output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32) + output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32) + + return output_dict, backbone, other_atoms, CA_icodes, water_atoms + + +def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms): + device = CB.device + mask_CBY = mask[:, None] * Y_m[None, :] # [A,B] + L2_AB = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1) + L2_AB = L2_AB * mask_CBY + (1 - mask_CBY) * 1000.0 + + nn_idx = torch.argsort(L2_AB, -1)[:, :number_of_ligand_atoms] + L2_AB_nn = torch.gather(L2_AB, 1, nn_idx) + D_AB_closest = torch.sqrt(L2_AB_nn[:, 0]) + + Y_r = Y[None, :, :].repeat(CB.shape[0], 1, 1) + Y_t_r = Y_t[None, :].repeat(CB.shape[0], 1) + Y_m_r = Y_m[None, :].repeat(CB.shape[0], 1) + + Y_tmp = torch.gather(Y_r, 1, nn_idx[:, :, None].repeat(1, 1, 3)) + Y_t_tmp = torch.gather(Y_t_r, 1, nn_idx) + Y_m_tmp = torch.gather(Y_m_r, 1, nn_idx) + + Y = torch.zeros( + [CB.shape[0], number_of_ligand_atoms, 3], dtype=torch.float32, device=device + ) + Y_t = torch.zeros( + [CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device + ) + Y_m = torch.zeros( + [CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device + ) + + num_nn_update = Y_tmp.shape[1] + Y[:, :num_nn_update] = Y_tmp + Y_t[:, :num_nn_update] = Y_t_tmp + Y_m[:, :num_nn_update] = Y_m_tmp + + return Y, Y_t, Y_m, D_AB_closest + + +def featurize( + input_dict, + cutoff_for_score=8.0, + use_atom_context=True, + number_of_ligand_atoms=16, + model_type="protein_mpnn", +): + output_dict = {} + if model_type == "ligand_mpnn": + mask = input_dict["mask"] + Y = input_dict["Y"] + Y_t = input_dict["Y_t"] + Y_m = input_dict["Y_m"] + N = input_dict["X"][:, 0, :] + CA = input_dict["X"][:, 1, :] + C = input_dict["X"][:, 2, :] + b = CA - N + c = C - CA + a = torch.cross(b, c, axis=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + Y, Y_t, Y_m, D_XY = 
get_nearest_neighbours( + CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms + ) + mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0] + output_dict["mask_XY"] = mask_XY[None,] + if "side_chain_mask" in list(input_dict): + output_dict["side_chain_mask"] = input_dict["side_chain_mask"][None,] + output_dict["Y"] = Y[None,] + output_dict["Y_t"] = Y_t[None,] + output_dict["Y_m"] = Y_m[None,] + if not use_atom_context: + output_dict["Y_m"] = 0.0 * output_dict["Y_m"] + elif ( + model_type == "per_residue_label_membrane_mpnn" + or model_type == "global_label_membrane_mpnn" + ): + output_dict["membrane_per_residue_labels"] = input_dict[ + "membrane_per_residue_labels" + ][None,] + + R_idx_list = [] + count = 0 + R_idx_prev = -100000 + for R_idx in list(input_dict["R_idx"]): + if R_idx_prev == R_idx: + count += 1 + R_idx_list.append(R_idx + count) + R_idx_prev = R_idx + R_idx_renumbered = torch.tensor(R_idx_list, device=R_idx.device) + output_dict["R_idx"] = R_idx_renumbered[None,] + output_dict["R_idx_original"] = input_dict["R_idx"][None,] + output_dict["chain_labels"] = input_dict["chain_labels"][None,] + output_dict["S"] = input_dict["S"][None,] + output_dict["chain_mask"] = input_dict["chain_mask"][None,] + output_dict["mask"] = input_dict["mask"][None,] + + output_dict["X"] = input_dict["X"][None,] + + if "xyz_37" in list(input_dict): + output_dict["xyz_37"] = input_dict["xyz_37"][None,] + output_dict["xyz_37_m"] = input_dict["xyz_37_m"][None,] + + return output_dict diff --git a/get_model_params.sh b/get_model_params.sh new file mode 100644 index 0000000..8a508da --- /dev/null +++ b/get_model_params.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +#make new directory for model parameters +#e.g. 
bash get_model_params.sh "./model_params" + +mkdir -p $1 + +#Original ProteinMPNN weights +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_002.pt -O $1"/proteinmpnn_v_48_002.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_010.pt -O $1"/proteinmpnn_v_48_010.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt -O $1"/proteinmpnn_v_48_020.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_030.pt -O $1"/proteinmpnn_v_48_030.pt" + +#ProteinMPNN with num_edges=32 +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_002.pt -O $1"/proteinmpnn_v_32_002.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_010.pt -O $1"/proteinmpnn_v_32_010.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_020.pt -O $1"/proteinmpnn_v_32_020.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_030.pt -O $1"/proteinmpnn_v_32_030.pt" + +#LigandMPNN with num_edges=32; atom_context_num=25 +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_25.pt -O $1"/ligandmpnn_v_32_005_25.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt -O $1"/ligandmpnn_v_32_010_25.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_25.pt -O $1"/ligandmpnn_v_32_020_25.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt -O $1"/ligandmpnn_v_32_030_25.pt" + +#LigandMPNN with num_edges=32; atom_context_num=16 +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_16.pt -O $1"/ligandmpnn_v_32_005_16.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_16.pt -O $1"/ligandmpnn_v_32_010_16.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_16.pt -O $1"/ligandmpnn_v_32_020_16.pt" +# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_16.pt -O $1"/ligandmpnn_v_32_030_16.pt" + +# wget -q 
https://files.ipd.uw.edu/pub/ligandmpnn/publication_version_ligandmpnn_v_32_010_25.pt -O $1"/publication_version_ligandmpnn_v_32_010_25.pt" + +#Per residue label membrane ProteinMPNN +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/per_residue_label_membrane_mpnn_v_48_020.pt -O $1"/per_residue_label_membrane_mpnn_v_48_020.pt" + +#Global label membrane ProteinMPNN +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/global_label_membrane_mpnn_v_48_020.pt -O $1"/global_label_membrane_mpnn_v_48_020.pt" + +#SolubleMPNN +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_002.pt -O $1"/solublempnn_v_48_002.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_010.pt -O $1"/solublempnn_v_48_010.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt -O $1"/solublempnn_v_48_020.pt" +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_030.pt -O $1"/solublempnn_v_48_030.pt" + +#LigandMPNN for side-chain packing (multi-step denoising model) +wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_sc_v_32_002_16.pt -O $1"/ligandmpnn_sc_v_32_002_16.pt" diff --git a/main.nf b/main.nf new file mode 100644 index 0000000..e978d38 --- /dev/null +++ b/main.nf @@ -0,0 +1,67 @@ +#!/usr/bin/env nextflow + +nextflow.enable.dsl=2 + +params.pdb = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb' +params.outdir = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/output' +params.model_type = 'ligand_mpnn' +params.temperature = 0.1 +params.seed = 111 +params.batch_size = 1 +params.number_of_batches = 1 +params.chains_to_design = '' +params.fixed_residues = '' +params.pack_side_chains = 0 + +process LIGANDMPNN { + container 'ligandmpnn:latest' + containerOptions '--rm --gpus all -v /mnt:/mnt' + publishDir params.outdir, mode: 'copy' + stageInMode 'copy' + + input: + path pdb + + output: + path "${pdb.simpleName}/seqs/*.fa" + path "${pdb.simpleName}/backbones/*.pdb" + path "${pdb.simpleName}/packed/*.pdb", optional: true + 
path "run.log" + + script: + // Set checkpoint path based on model type + def checkpoint_arg = '' + if (params.model_type == 'ligand_mpnn') { + checkpoint_arg = '--checkpoint_ligand_mpnn /app/LigandMPNN/model_params/ligandmpnn_v_32_010_25.pt' + } else if (params.model_type == 'protein_mpnn') { + checkpoint_arg = '--checkpoint_protein_mpnn /app/LigandMPNN/model_params/proteinmpnn_v_48_020.pt' + } else if (params.model_type == 'soluble_mpnn') { + checkpoint_arg = '--checkpoint_soluble_mpnn /app/LigandMPNN/model_params/solublempnn_v_48_020.pt' + } else if (params.model_type == 'global_label_membrane_mpnn') { + checkpoint_arg = '--checkpoint_global_label_membrane_mpnn /app/LigandMPNN/model_params/global_label_membrane_mpnn_v_48_020.pt' + } else if (params.model_type == 'per_residue_label_membrane_mpnn') { + checkpoint_arg = '--checkpoint_per_residue_label_membrane_mpnn /app/LigandMPNN/model_params/per_residue_label_membrane_mpnn_v_48_020.pt' + } + + """ + mkdir -p ${pdb.simpleName}/seqs ${pdb.simpleName}/backbones ${pdb.simpleName}/packed + + ligandmpnn \\ + --pdb_path \$PWD/${pdb} \\ + --out_folder \$PWD/${pdb.simpleName} \\ + --model_type ${params.model_type} \\ + ${checkpoint_arg} \\ + --temperature ${params.temperature} \\ + --seed ${params.seed} \\ + --batch_size ${params.batch_size} \\ + --number_of_batches ${params.number_of_batches} \\ + --pack_side_chains ${params.pack_side_chains} \\ + ${params.chains_to_design ? "--chains_to_design ${params.chains_to_design}" : ''} \\ + ${params.fixed_residues ? 
"--fixed_residues \"${params.fixed_residues}\"" : ''} \\ + 2>&1 | tee run.log + """ +} + +workflow { + LIGANDMPNN(Channel.fromPath(params.pdb)) +} diff --git a/model_utils.py b/model_utils.py new file mode 100644 index 0000000..7cc9759 --- /dev/null +++ b/model_utils.py @@ -0,0 +1,1772 @@ +from __future__ import print_function + +import itertools +import sys + +import numpy as np +import torch + + +class ProteinMPNN(torch.nn.Module): + def __init__( + self, + num_letters=21, + node_features=128, + edge_features=128, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + vocab=21, + k_neighbors=48, + augment_eps=0.0, + dropout=0.0, + device=None, + atom_context_num=0, + model_type="protein_mpnn", + ligand_mpnn_use_side_chain_context=False, + ): + super(ProteinMPNN, self).__init__() + + self.model_type = model_type + self.node_features = node_features + self.edge_features = edge_features + self.hidden_dim = hidden_dim + + if self.model_type == "ligand_mpnn": + self.features = ProteinFeaturesLigand( + node_features, + edge_features, + top_k=k_neighbors, + augment_eps=augment_eps, + device=device, + atom_context_num=atom_context_num, + use_side_chains=ligand_mpnn_use_side_chain_context, + ) + self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True) + self.W_c = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.W_nodes_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) + self.W_edges_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.V_C = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) + self.V_C_norm = torch.nn.LayerNorm(hidden_dim) + + self.context_encoder_layers = torch.nn.ModuleList( + [ + DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout) + for _ in range(2) + ] + ) + + self.y_context_encoder_layers = torch.nn.ModuleList( + [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)] + ) + elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn": + self.features = ProteinFeatures( 
+ node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps + ) + elif ( + self.model_type == "per_residue_label_membrane_mpnn" + or self.model_type == "global_label_membrane_mpnn" + ): + self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True) + self.features = ProteinFeaturesMembrane( + node_features, + edge_features, + top_k=k_neighbors, + augment_eps=augment_eps, + num_classes=3, + ) + else: + print("Choose --model_type flag from currently available models") + sys.exit() + + self.W_e = torch.nn.Linear(edge_features, hidden_dim, bias=True) + self.W_s = torch.nn.Embedding(vocab, hidden_dim) + + self.dropout = torch.nn.Dropout(dropout) + + # Encoder layers + self.encoder_layers = torch.nn.ModuleList( + [ + EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout) + for _ in range(num_encoder_layers) + ] + ) + + # Decoder layers + self.decoder_layers = torch.nn.ModuleList( + [ + DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout) + for _ in range(num_decoder_layers) + ] + ) + + self.W_out = torch.nn.Linear(hidden_dim, num_letters, bias=True) + + for p in self.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + + def encode(self, feature_dict): + # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed + # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords + # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords + # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type + # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask + # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O + S_true = feature_dict[ + "S" + ] # [B,L] - integer protein sequence encoded using "restype_STRtoINT" + # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index + mask = feature_dict[ + "mask" + ] # [B,L] - mask for missing regions - should be removed! 
all ones most of the time + # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters + + B, L = S_true.shape + device = S_true.device + + if self.model_type == "ligand_mpnn": + V, E, E_idx, Y_nodes, Y_edges, Y_m = self.features(feature_dict) + h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device) + h_E = self.W_e(E) + h_E_context = self.W_v(V) + + mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) + mask_attend = mask.unsqueeze(-1) * mask_attend + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) + + h_V_C = self.W_c(h_V) + Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :] + Y_nodes = self.W_nodes_y(Y_nodes) + Y_edges = self.W_edges_y(Y_edges) + for i in range(len(self.context_encoder_layers)): + Y_nodes = self.y_context_encoder_layers[i]( + Y_nodes, Y_edges, Y_m, Y_m_edges + ) + h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1) + h_V_C = self.context_encoder_layers[i]( + h_V_C, h_E_context_cat, mask, Y_m + ) + + h_V_C = self.V_C(h_V_C) + h_V = h_V + self.V_C_norm(self.dropout(h_V_C)) + elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn": + E, E_idx = self.features(feature_dict) + h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device) + h_E = self.W_e(E) + + mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) + mask_attend = mask.unsqueeze(-1) * mask_attend + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) + elif ( + self.model_type == "per_residue_label_membrane_mpnn" + or self.model_type == "global_label_membrane_mpnn" + ): + V, E, E_idx = self.features(feature_dict) + h_V = self.W_v(V) + h_E = self.W_e(E) + + mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) + mask_attend = mask.unsqueeze(-1) * mask_attend + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) + + return h_V, h_E, E_idx + + def sample(self, 
feature_dict): + # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed + # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords + # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords + # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type + # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask + # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O + B_decoder = feature_dict["batch_size"] + S_true = feature_dict[ + "S" + ] # [B,L] - integer proitein sequence encoded using "restype_STRtoINT" + # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index + mask = feature_dict[ + "mask" + ] # [B,L] - mask for missing regions - should be removed! all ones most of the time + chain_mask = feature_dict[ + "chain_mask" + ] # [B,L] - mask for which residues need to be fixed; 0.0 - fixed; 1.0 - will be designed + bias = feature_dict["bias"] # [B,L,21] - amino acid bias per position + # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters + randn = feature_dict[ + "randn" + ] # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry + temperature = feature_dict[ + "temperature" + ] # float - sampling temperature; prob = softmax(logits/temperature) + symmetry_list_of_lists = feature_dict[ + "symmetry_residues" + ] # [[0, 1, 14], [10,11,14,15], [20, 21]] #indices to select X over length - L + symmetry_weights_list_of_lists = feature_dict[ + "symmetry_weights" + ] # [[1.0, 1.0, 1.0], [-2.0,1.1,0.2,1.1], [2.3, 1.1]] + + B, L = S_true.shape + device = S_true.device + + h_V, h_E, E_idx = self.encode(feature_dict) + + chain_mask = mask * chain_mask # update chain_M to include missing regions + decoding_order = torch.argsort( + (chain_mask + 0.0001) * (torch.abs(randn)) + ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where 
chain_M = 1.0] + if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1: + E_idx = E_idx.repeat(B_decoder, 1, 1) + permutation_matrix_reverse = torch.nn.functional.one_hot( + decoding_order, num_classes=L + ).float() + order_mask_backward = torch.einsum( + "ij, biq, bjp->bqp", + (1 - torch.triu(torch.ones(L, L, device=device))), + permutation_matrix_reverse, + permutation_matrix_reverse, + ) + mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) + mask_1D = mask.view([B, L, 1, 1]) + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1.0 - mask_attend) + + # repeat for decoding + S_true = S_true.repeat(B_decoder, 1) + h_V = h_V.repeat(B_decoder, 1, 1) + h_E = h_E.repeat(B_decoder, 1, 1, 1) + chain_mask = chain_mask.repeat(B_decoder, 1) + mask = mask.repeat(B_decoder, 1) + bias = bias.repeat(B_decoder, 1, 1) + + all_probs = torch.zeros( + (B_decoder, L, 20), device=device, dtype=torch.float32 + ) + all_log_probs = torch.zeros( + (B_decoder, L, 21), device=device, dtype=torch.float32 + ) + h_S = torch.zeros_like(h_V, device=device) + S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device) + h_V_stack = [h_V] + [ + torch.zeros_like(h_V, device=device) + for _ in range(len(self.decoder_layers)) + ] + + h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) + h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) + h_EXV_encoder_fw = mask_fw * h_EXV_encoder + + for t_ in range(L): + t = decoding_order[:, t_] # [B] + chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0] # [B] + mask_t = torch.gather(mask, 1, t[:, None])[:, 0] # [B] + bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[ + :, 0, : + ] # [B,21] + + E_idx_t = torch.gather( + E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]) + ) + h_E_t = torch.gather( + h_E, + 1, + t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]), + ) + h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t) + h_EXV_encoder_t 
= torch.gather( + h_EXV_encoder_fw, + 1, + t[:, None, None, None].repeat( + 1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1] + ), + ) + + mask_bw_t = torch.gather( + mask_bw, + 1, + t[:, None, None, None].repeat( + 1, 1, mask_bw.shape[-2], mask_bw.shape[-1] + ), + ) + + for l, layer in enumerate(self.decoder_layers): + h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t) + h_V_t = torch.gather( + h_V_stack[l], + 1, + t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]), + ) + h_ESV_t = mask_bw_t * h_ESV_decoder_t + h_EXV_encoder_t + h_V_stack[l + 1].scatter_( + 1, + t[:, None, None].repeat(1, 1, h_V.shape[-1]), + layer(h_V_t, h_ESV_t, mask_V=mask_t), + ) + + h_V_t = torch.gather( + h_V_stack[-1], + 1, + t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]), + )[:, 0] + logits = self.W_out(h_V_t) # [B,21] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # [B,21] + + probs = torch.nn.functional.softmax( + (logits + bias_t) / temperature, dim=-1 + ) # [B,21] + probs_sample = probs[:, :20] / torch.sum( + probs[:, :20], dim=-1, keepdim=True + ) # hard omit X #[B,20] + S_t = torch.multinomial(probs_sample, 1)[:, 0] # [B] + + all_probs.scatter_( + 1, + t[:, None, None].repeat(1, 1, 20), + (chain_mask_t[:, None, None] * probs_sample[:, None, :]).float(), + ) + all_log_probs.scatter_( + 1, + t[:, None, None].repeat(1, 1, 21), + (chain_mask_t[:, None, None] * log_probs[:, None, :]).float(), + ) + S_true_t = torch.gather(S_true, 1, t[:, None])[:, 0] + S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long() + h_S.scatter_( + 1, + t[:, None, None].repeat(1, 1, h_S.shape[-1]), + self.W_s(S_t)[:, None, :], + ) + S.scatter_(1, t[:, None], S_t[:, None]) + + output_dict = { + "S": S, + "sampling_probs": all_probs, + "log_probs": all_log_probs, + "decoding_order": decoding_order, + } + else: + # weights for symmetric design + symmetry_weights = torch.ones([L], device=device, dtype=torch.float32) + for i1, item_list in 
enumerate(symmetry_list_of_lists): + for i2, item in enumerate(item_list): + symmetry_weights[item] = symmetry_weights_list_of_lists[i1][i2] + + new_decoding_order = [] + for t_dec in list(decoding_order[0,].cpu().data.numpy()): + if t_dec not in list(itertools.chain(*new_decoding_order)): + list_a = [item for item in symmetry_list_of_lists if t_dec in item] + if list_a: + new_decoding_order.append(list_a[0]) + else: + new_decoding_order.append([t_dec]) + + decoding_order = torch.tensor( + list(itertools.chain(*new_decoding_order)), device=device + )[None,].repeat(B, 1) + + permutation_matrix_reverse = torch.nn.functional.one_hot( + decoding_order, num_classes=L + ).float() + order_mask_backward = torch.einsum( + "ij, biq, bjp->bqp", + (1 - torch.triu(torch.ones(L, L, device=device))), + permutation_matrix_reverse, + permutation_matrix_reverse, + ) + mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) + mask_1D = mask.view([B, L, 1, 1]) + mask_bw = mask_1D * mask_attend + mask_fw = mask_1D * (1.0 - mask_attend) + + # repeat for decoding + S_true = S_true.repeat(B_decoder, 1) + h_V = h_V.repeat(B_decoder, 1, 1) + h_E = h_E.repeat(B_decoder, 1, 1, 1) + E_idx = E_idx.repeat(B_decoder, 1, 1) + mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1) + mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1) + chain_mask = chain_mask.repeat(B_decoder, 1) + mask = mask.repeat(B_decoder, 1) + bias = bias.repeat(B_decoder, 1, 1) + + all_probs = torch.zeros( + (B_decoder, L, 20), device=device, dtype=torch.float32 + ) + all_log_probs = torch.zeros( + (B_decoder, L, 21), device=device, dtype=torch.float32 + ) + h_S = torch.zeros_like(h_V, device=device) + S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device) + h_V_stack = [h_V] + [ + torch.zeros_like(h_V, device=device) + for _ in range(len(self.decoder_layers)) + ] + + h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) + h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) + 
    def single_aa_score(self, feature_dict, use_sequence: bool):
        """Score every residue position independently, one decoder pass per position.

        feature_dict - input features; reads "batch_size", "S", "mask",
            "chain_mask", "randn" (plus whatever self.encode consumes).
        use_sequence - if False, position idx is forced to the front of the
            decoding order so it sees backbone context only; if True, it is
            forced to the back so it sees the full remaining sequence.

        Returns a dict with "S" (true sequence tiled to B_decoder),
        "log_probs"/"logits" of shape [B_decoder, L, 21], and
        "decoding_order" of shape [B_decoder, L, L].
        """
        B_decoder = feature_dict["batch_size"]
        S_true_enc = feature_dict["S"]
        mask_enc = feature_dict["mask"]
        chain_mask_enc = feature_dict["chain_mask"]
        randn = feature_dict["randn"]
        B, L = S_true_enc.shape
        device = S_true_enc.device

        # Encode once; per-position loop below reuses these encoder outputs.
        h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict)
        log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float()
        logits_out = torch.zeros([B_decoder, L, 21], device=device).float()
        decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float()

        for idx in range(L):
            # Fresh copies each iteration: the decode loop mutates h_V in place.
            h_V = torch.clone(h_V_enc)
            E_idx = torch.clone(E_idx_enc)
            mask = torch.clone(mask_enc)
            S_true = torch.clone(S_true_enc)
            if not use_sequence:
                # Only idx gets a large sort key -> decoded first, no sequence context.
                order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float()
                order_mask[idx] = 1.
            else:
                # idx gets the smallest key -> decoded last, full sequence context.
                order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float()
                order_mask[idx] = 0.
            decoding_order = torch.argsort(
                (order_mask + 0.0001) * (torch.abs(randn))
            )  # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            # Build the causal (autoregressive) neighbor mask implied by decoding_order.
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend  # neighbors already decoded
            mask_fw = mask_1D * (1.0 - mask_attend)  # neighbors not yet decoded
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E_enc.repeat(B_decoder, 1, 1, 1)
            mask = mask.repeat(B_decoder, 1)

            h_S = self.W_s(S_true)
            h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

            # Build encoder embeddings (sequence channel zeroed for the fw path).
            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Not-yet-decoded neighbors contribute encoder info only;
                # already-decoded neighbors contribute sequence info too.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # Keep only the row for the position this iteration targets.
            log_probs_out[:, idx, :] = log_probs[:, idx, :]
            logits_out[:, idx, :] = logits[:, idx, :]
            # NOTE(review): decoding_order here is [B, L] but the slice is
            # [B_decoder, L]; this relies on broadcasting, i.e. assumes B == 1
            # or B == B_decoder — TODO confirm with callers.
            decoding_order_out[:, idx, :] = decoding_order

        output_dict = {
            "S": S_true,
            "log_probs": log_probs_out,
            "logits": logits_out,
            "decoding_order": decoding_order_out,
        }
        return output_dict
    def score(self, feature_dict, use_sequence: bool):
        """Single-pass scoring of the whole sequence under a random decoding order.

        feature_dict - reads "batch_size", "S", "mask", "chain_mask", "randn"
            and "symmetry_residues" (list of lists of tied positions).
        use_sequence - if False, the decoder sees backbone/encoder context only;
            if True, already-decoded sequence context is included via mask_bw.

        Returns dict with "S" (tiled true sequence), "log_probs"/"logits"
        over 21 classes, and the "decoding_order" used.
        """
        B_decoder = feature_dict["batch_size"]
        S_true = feature_dict["S"]
        mask = feature_dict["mask"]
        chain_mask = feature_dict["chain_mask"]
        randn = feature_dict["randn"]
        symmetry_list_of_lists = feature_dict["symmetry_residues"]
        B, L = S_true.shape
        device = S_true.device

        h_V, h_E, E_idx = self.encode(feature_dict)

        chain_mask = mask * chain_mask  # update chain_M to include missing regions
        decoding_order = torch.argsort(
            (chain_mask + 0.0001) * (torch.abs(randn))
        )  # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
        if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
            # No symmetry ties: use the random decoding order directly.
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)
        else:
            # Regroup the order so tied (symmetric) positions decode together.
            new_decoding_order = []
            for t_dec in list(decoding_order[0,].cpu().data.numpy()):
                if t_dec not in list(itertools.chain(*new_decoding_order)):
                    list_a = [item for item in symmetry_list_of_lists if t_dec in item]
                    if list_a:
                        new_decoding_order.append(list_a[0])
                    else:
                        new_decoding_order.append([t_dec])

            decoding_order = torch.tensor(
                list(itertools.chain(*new_decoding_order)), device=device
            )[None,].repeat(B, 1)

            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)

            # NOTE(review): only this symmetric branch tiles E_idx/mask_fw/
            # mask_bw/decoding_order to B_decoder; the branch above relies on
            # broadcasting from B (presumably B == 1) — TODO confirm.
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
            mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
            decoding_order = decoding_order.repeat(B_decoder, 1)

        S_true = S_true.repeat(B_decoder, 1)
        h_V = h_V.repeat(B_decoder, 1, 1)
        h_E = h_E.repeat(B_decoder, 1, 1, 1)
        mask = mask.repeat(B_decoder, 1)

        h_S = self.W_s(S_true)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        if not use_sequence:
            # Backbone-only scoring: decoder never sees sequence embeddings.
            for layer in self.decoder_layers:
                h_V = layer(h_V, h_EXV_encoder_fw, mask)
        else:
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information, unmasked see
                # the already-decoded sequence as well.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

        logits = self.W_out(h_V)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        output_dict = {
            "S": S_true,
            "log_probs": log_probs,
            "logits": logits,
            "decoding_order": decoding_order,
        }
        return output_dict
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) + h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw + h_V = layer(h_V, h_ESV, mask) + + logits = self.W_out(h_V) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + output_dict = { + "S": S_true, + "log_probs": log_probs, + "logits": logits, + "decoding_order": decoding_order, + } + return output_dict + + +class ProteinFeaturesLigand(torch.nn.Module): + def __init__( + self, + edge_features, + node_features, + num_positional_embeddings=16, + num_rbf=16, + top_k=30, + augment_eps=0.0, + device=None, + atom_context_num=16, + use_side_chains=False, + ): + """Extract protein features""" + super(ProteinFeaturesLigand, self).__init__() + + self.use_side_chains = use_side_chains + + self.edge_features = edge_features + self.node_features = node_features + self.top_k = top_k + self.augment_eps = augment_eps + self.num_rbf = num_rbf + self.num_positional_embeddings = num_positional_embeddings + + self.embeddings = PositionalEncodings(num_positional_embeddings) + edge_in = num_positional_embeddings + num_rbf * 25 + self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False) + self.norm_edges = torch.nn.LayerNorm(edge_features) + + self.node_project_down = torch.nn.Linear( + 5 * num_rbf + 64 + 4, node_features, bias=True + ) + self.norm_nodes = torch.nn.LayerNorm(node_features) + + self.type_linear = torch.nn.Linear(147, 64) + + self.y_nodes = torch.nn.Linear(147, node_features, bias=False) + self.y_edges = torch.nn.Linear(num_rbf, node_features, bias=False) + + self.norm_y_edges = torch.nn.LayerNorm(node_features) + self.norm_y_nodes = torch.nn.LayerNorm(node_features) + + self.atom_context_num = atom_context_num + + # the last 32 atoms in the 37 atom representation + self.side_chain_atom_types = torch.tensor( + [ + 6, + 6, + 6, + 8, + 8, + 16, + 6, + 6, + 6, + 7, + 7, + 8, + 8, + 16, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 8, + 8, + 6, + 7, + 7, + 8, + 6, + 6, + 6, + 7, + 8, + ], + device=device, + ) + + 
self.periodic_table_features = torch.tensor( + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + ], + [ + 0, + 1, + 18, + 1, + 2, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + ], + [ + 0, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 
    def _make_angle_features(self, A, B, C, Y):
        """Directional features of context atoms Y in the local frame at B.

        A, B, C - [batch, L, 3] backbone atoms (callers pass N, Ca, C, so the
            frame sits on Ca).
        Y - [batch, L, M, 3] context-atom coordinates.

        Returns [batch, L, M, 4]: (cos, sin) of the in-plane angle and
        (sin, cos)-like components of the out-of-plane angle.
        """
        # Gram-Schmidt: e1 along A-B, e2 the orthogonalized B->C direction,
        # e3 completes the right-handed frame.
        v1 = A - B
        v2 = C - B
        e1 = torch.nn.functional.normalize(v1, dim=-1)
        e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None]
        u2 = v2 - e1 * e1_v2_dot
        e2 = torch.nn.functional.normalize(u2, dim=-1)
        e3 = torch.cross(e1, e2, dim=-1)
        R_residue = torch.cat(
            (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1
        )

        # Express each context atom in the per-residue frame.
        local_vectors = torch.einsum(
            "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :]
        )

        # f1/f2: unit in-plane direction; f3/f4: elevation out of the e1-e2 plane.
        rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8)
        f1 = local_vectors[..., 0] / rxy
        f2 = local_vectors[..., 1] / rxy
        rxyz = torch.norm(local_vectors, dim=-1) + 1e-8
        f3 = rxy / rxyz
        f4 = local_vectors[..., 2] / rxyz

        f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1)
        return f

    def _dist(self, X, mask, eps=1e-6):
        """K-nearest neighbors by masked pairwise distance.

        X - [batch, L, 3] coordinates; mask - [batch, L] validity (1.0/0.0).
        Returns (D_neighbors, E_idx), each [batch, L, K] with
        K = min(top_k, L).
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        # Push invalid pairs to at least the per-row maximum so topk (smallest)
        # never selects them.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1.0 - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(
            D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
        )
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances into num_rbf Gaussian radial basis bins on [2, 22]."""
        device = D.device
        D_min, D_max, D_count = 2.0, 22.0, self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
        return RBF

    def _get_rbf(self, A, B, E_idx):
        """RBF features of A-B atom distances gathered at the K neighbors E_idx."""
        D_A_B = torch.sqrt(
            torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[
            :, :, :, 0
        ]  # [B,L,K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B
forward(self, input_features): + Y = input_features["Y"] + Y_m = input_features["Y_m"] + Y_t = input_features["Y_t"] + X = input_features["X"] + mask = input_features["mask"] + R_idx = input_features["R_idx"] + chain_labels = input_features["chain_labels"] + + if self.augment_eps > 0: + X = X + self.augment_eps * torch.randn_like(X) + Y = Y + self.augment_eps * torch.randn_like(Y) + + B, L, _, _ = X.shape + + Ca = X[:, :, 1, :] + N = X[:, :, 0, :] + C = X[:, :, 2, :] + O = X[:, :, 3, :] + + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca # shift from CA + + D_neighbors, E_idx = self._dist(Ca, mask) + + RBF_all = [] + RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca + RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N + RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C + RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O + RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb + RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N + RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C + RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O + RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb + RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C + RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O + RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb + RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C + RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O + RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C + RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca + RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca + RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca + RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca + RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N + RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N + RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N + RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb + RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb + RBF_all.append(self._get_rbf(C, O, E_idx)) # 
C-O + RBF_all = torch.cat(tuple(RBF_all), dim=-1) + + offset = R_idx[:, :, None] - R_idx[:, None, :] + offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] + + d_chains = ( + (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 + ).long() # find self vs non-self interaction + E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] + E_positional = self.embeddings(offset.long(), E_chains) + E = torch.cat((E_positional, RBF_all), -1) + E = self.edge_embedding(E) + E = self.norm_edges(E) + + if self.use_side_chains: + xyz_37 = input_features["xyz_37"] + xyz_37_m = input_features["xyz_37_m"] + E_idx_sub = E_idx[:, :, :16] # [B, L, 15] + mask_residues = input_features["chain_mask"] + xyz_37_m = xyz_37_m * (1 - mask_residues[:, :, None]) + R_m = gather_nodes(xyz_37_m[:, :, 5:], E_idx_sub) + + X_sidechain = xyz_37[:, :, 5:, :].view(B, L, -1) + R = gather_nodes(X_sidechain, E_idx_sub).view( + B, L, E_idx_sub.shape[2], -1, 3 + ) + R_t = self.side_chain_atom_types[None, None, None, :].repeat( + B, L, E_idx_sub.shape[2], 1 + ) + + # Side chain atom context + R = R.view(B, L, -1, 3) # coordinates + R_m = R_m.view(B, L, -1) # mask + R_t = R_t.view(B, L, -1) # atom types + + # Ligand atom context + Y = torch.cat((R, Y), 2) # [B, L, atoms, 3] + Y_m = torch.cat((R_m, Y_m), 2) # [B, L, atoms] + Y_t = torch.cat((R_t, Y_t), 2) # [B, L, atoms] + + Cb_Y_distances = torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + mask_Y = mask[:, :, None] * Y_m + Cb_Y_distances_adjusted = Cb_Y_distances * mask_Y + (1.0 - mask_Y) * 10000.0 + _, E_idx_Y = torch.topk( + Cb_Y_distances_adjusted, self.atom_context_num, dim=-1, largest=False + ) + + Y = torch.gather(Y, 2, E_idx_Y[:, :, :, None].repeat(1, 1, 1, 3)) + Y_t = torch.gather(Y_t, 2, E_idx_Y) + Y_m = torch.gather(Y_m, 2, E_idx_Y) + + Y_t = Y_t.long() + Y_t_g = self.periodic_table_features[1][Y_t] # group; 19 categories including 0 + Y_t_p = self.periodic_table_features[2][Y_t] # period; 8 categories including 0 
+ + Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19) # [B, L, M, 19] + Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8) # [B, L, M, 8] + Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120) # [B, L, M, 120] + + Y_t_1hot_ = torch.cat( + [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1 + ) # [B, L, M, 147] + Y_t_1hot = self.type_linear(Y_t_1hot_.float()) + + D_N_Y = self._rbf( + torch.sqrt(torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6) + ) # [B, L, M, num_bins] + D_Ca_Y = self._rbf( + torch.sqrt(torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6) + ) + D_C_Y = self._rbf(torch.sqrt(torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6)) + D_O_Y = self._rbf(torch.sqrt(torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6)) + D_Cb_Y = self._rbf( + torch.sqrt(torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6) + ) + + f_angles = self._make_angle_features(N, Ca, C, Y) # [B, L, M, 4] + + D_all = torch.cat( + (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1 + ) # [B,L,M,5*num_bins+5] + V = self.node_project_down(D_all) # [B, L, M, node_features] + V = self.norm_nodes(V) + + Y_edges = self._rbf( + torch.sqrt( + torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6 + ) + ) # [B, L, M, M, num_bins] + + Y_edges = self.y_edges(Y_edges) + Y_nodes = self.y_nodes(Y_t_1hot_.float()) + + Y_edges = self.norm_y_edges(Y_edges) + Y_nodes = self.norm_y_nodes(Y_nodes) + + return V, E, E_idx, Y_nodes, Y_edges, Y_m + + +class ProteinFeatures(torch.nn.Module): + def __init__( + self, + edge_features, + node_features, + num_positional_embeddings=16, + num_rbf=16, + top_k=48, + augment_eps=0.0, + ): + """Extract protein features""" + super(ProteinFeatures, self).__init__() + self.edge_features = edge_features + self.node_features = node_features + self.top_k = top_k + self.augment_eps = augment_eps + self.num_rbf = num_rbf + self.num_positional_embeddings = num_positional_embeddings + + self.embeddings = PositionalEncodings(num_positional_embeddings) + edge_in = 
num_positional_embeddings + num_rbf * 25 + self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False) + self.norm_edges = torch.nn.LayerNorm(edge_features) + + def _dist(self, X, mask, eps=1e-6): + mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) + dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2) + D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps) + D_max, _ = torch.max(D, -1, keepdim=True) + D_adjust = D + (1.0 - mask_2D) * D_max + D_neighbors, E_idx = torch.topk( + D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False + ) + return D_neighbors, E_idx + + def _rbf(self, D): + device = D.device + D_min, D_max, D_count = 2.0, 22.0, self.num_rbf + D_mu = torch.linspace(D_min, D_max, D_count, device=device) + D_mu = D_mu.view([1, 1, 1, -1]) + D_sigma = (D_max - D_min) / D_count + D_expand = torch.unsqueeze(D, -1) + RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2)) + return RBF + + def _get_rbf(self, A, B, E_idx): + D_A_B = torch.sqrt( + torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6 + ) # [B, L, L] + D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[ + :, :, :, 0 + ] # [B,L,K] + RBF_A_B = self._rbf(D_A_B_neighbors) + return RBF_A_B + + def forward(self, input_features): + X = input_features["X"] + mask = input_features["mask"] + R_idx = input_features["R_idx"] + chain_labels = input_features["chain_labels"] + + if self.augment_eps > 0: + X = X + self.augment_eps * torch.randn_like(X) + + b = X[:, :, 1, :] - X[:, :, 0, :] + c = X[:, :, 2, :] - X[:, :, 1, :] + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :] + Ca = X[:, :, 1, :] + N = X[:, :, 0, :] + C = X[:, :, 2, :] + O = X[:, :, 3, :] + + D_neighbors, E_idx = self._dist(Ca, mask) + + RBF_all = [] + RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca + RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N + RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C + RBF_all.append(self._get_rbf(O, O, 
E_idx)) # O-O + RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb + RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N + RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C + RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O + RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb + RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C + RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O + RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb + RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C + RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O + RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C + RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca + RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca + RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca + RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca + RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N + RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N + RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N + RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb + RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb + RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O + RBF_all = torch.cat(tuple(RBF_all), dim=-1) + + offset = R_idx[:, :, None] - R_idx[:, None, :] + offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] + + d_chains = ( + (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 + ).long() # find self vs non-self interaction + E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] + E_positional = self.embeddings(offset.long(), E_chains) + E = torch.cat((E_positional, RBF_all), -1) + E = self.edge_embedding(E) + E = self.norm_edges(E) + + return E, E_idx + + +class ProteinFeaturesMembrane(torch.nn.Module): + def __init__( + self, + edge_features, + node_features, + num_positional_embeddings=16, + num_rbf=16, + top_k=48, + augment_eps=0.0, + num_classes=3, + ): + """Extract protein features""" + super(ProteinFeaturesMembrane, self).__init__() + self.edge_features = edge_features + 
    def _dist(self, X, mask, eps=1e-6):
        """K-nearest-neighbor graph from masked pairwise distances.

        X - [batch, L, 3] coordinates; mask - [batch, L] validity (1.0/0.0).
        Returns (D_neighbors, E_idx), each [batch, L, K], K = min(top_k, L).
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        # Shift invalid pairs past the per-row maximum so they are never chosen.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1.0 - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(
            D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
        )
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances into num_rbf Gaussian radial basis bins on [2, 22]."""
        device = D.device
        D_min, D_max, D_count = 2.0, 22.0, self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
        return RBF

    def _get_rbf(self, A, B, E_idx):
        """RBF features of A-B atom-pair distances gathered at neighbors E_idx."""
        D_A_B = torch.sqrt(
            torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[
            :, :, :, 0
        ]  # [B,L,K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B
b = X[:, :, 1, :] - X[:, :, 0, :] + c = X[:, :, 2, :] - X[:, :, 1, :] + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :] + Ca = X[:, :, 1, :] + N = X[:, :, 0, :] + C = X[:, :, 2, :] + O = X[:, :, 3, :] + + D_neighbors, E_idx = self._dist(Ca, mask) + + RBF_all = [] + RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca + RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N + RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C + RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O + RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb + RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N + RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C + RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O + RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb + RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C + RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O + RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb + RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C + RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O + RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C + RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca + RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca + RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca + RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca + RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N + RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N + RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N + RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb + RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb + RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O + RBF_all = torch.cat(tuple(RBF_all), dim=-1) + + offset = R_idx[:, :, None] - R_idx[:, None, :] + offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] + + d_chains = ( + (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 + ).long() # find self vs non-self interaction + E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] + 
class DecLayerJ(torch.nn.Module):
    """Decoder message-passing layer for node tensors with an extra context axis.

    Same update rule as DecLayer, except the node state is broadcast over the
    second-to-last edge dimension when forming messages (node tensors here
    carry one additional axis).
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayerJ, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)

        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """One transformer-style update: neighbor messages, then feed-forward."""
        # Broadcast each node across its neighbor axis and join with the edges;
        # the extra expanded axis is the only difference from DecLayer.
        expanded_nodes = h_V.unsqueeze(-2).expand(-1, -1, -1, h_E.size(-2), -1)
        joined = torch.cat([expanded_nodes, h_E], -1)

        # Three-layer MLP produces one message per neighbor.
        messages = self.W3(self.act(self.W2(self.act(self.W1(joined)))))
        if mask_attend is not None:
            messages = mask_attend.unsqueeze(-1) * messages

        # Scaled-sum aggregation followed by the first residual + norm.
        aggregated = torch.sum(messages, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(aggregated))

        # Position-wise feed-forward with the second residual + norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        # Zero out padded nodes if a node mask is supplied.
        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V
        return h_V
torch.nn.Linear(num_ff, num_hidden, bias=True) + self.act = torch.nn.GELU() + + def forward(self, h_V): + h = self.act(self.W_in(h_V)) + h = self.W_out(h) + return h + + +class PositionalEncodings(torch.nn.Module): + def __init__(self, num_embeddings, max_relative_feature=32): + super(PositionalEncodings, self).__init__() + self.num_embeddings = num_embeddings + self.max_relative_feature = max_relative_feature + self.linear = torch.nn.Linear(2 * max_relative_feature + 1 + 1, num_embeddings) + + def forward(self, offset, mask): + d = torch.clip( + offset + self.max_relative_feature, 0, 2 * self.max_relative_feature + ) * mask + (1 - mask) * (2 * self.max_relative_feature + 1) + d_onehot = torch.nn.functional.one_hot(d, 2 * self.max_relative_feature + 1 + 1) + E = self.linear(d_onehot.float()) + return E + + +class DecLayer(torch.nn.Module): + def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): + super(DecLayer, self).__init__() + self.num_hidden = num_hidden + self.num_in = num_in + self.scale = scale + self.dropout1 = torch.nn.Dropout(dropout) + self.dropout2 = torch.nn.Dropout(dropout) + self.norm1 = torch.nn.LayerNorm(num_hidden) + self.norm2 = torch.nn.LayerNorm(num_hidden) + + self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) + self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.act = torch.nn.GELU() + self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4) + + def forward(self, h_V, h_E, mask_V=None, mask_attend=None): + """Parallel computation of full transformer layer""" + + # Concatenate h_V_i to h_E_ij + h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1) + h_EV = torch.cat([h_V_expand, h_E], -1) + + h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) + if mask_attend is not None: + h_message = mask_attend.unsqueeze(-1) * h_message + dh = torch.sum(h_message, -2) / self.scale + + h_V = self.norm1(h_V 
+ self.dropout1(dh)) + + # Position-wise feedforward + dh = self.dense(h_V) + h_V = self.norm2(h_V + self.dropout2(dh)) + + if mask_V is not None: + mask_V = mask_V.unsqueeze(-1) + h_V = mask_V * h_V + return h_V + + +class EncLayer(torch.nn.Module): + def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): + super(EncLayer, self).__init__() + self.num_hidden = num_hidden + self.num_in = num_in + self.scale = scale + self.dropout1 = torch.nn.Dropout(dropout) + self.dropout2 = torch.nn.Dropout(dropout) + self.dropout3 = torch.nn.Dropout(dropout) + self.norm1 = torch.nn.LayerNorm(num_hidden) + self.norm2 = torch.nn.LayerNorm(num_hidden) + self.norm3 = torch.nn.LayerNorm(num_hidden) + + self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) + self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.W11 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) + self.W12 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.W13 = torch.nn.Linear(num_hidden, num_hidden, bias=True) + self.act = torch.nn.GELU() + self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4) + + def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None): + """Parallel computation of full transformer layer""" + + h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) + h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1) + h_EV = torch.cat([h_V_expand, h_EV], -1) + h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) + if mask_attend is not None: + h_message = mask_attend.unsqueeze(-1) * h_message + dh = torch.sum(h_message, -2) / self.scale + h_V = self.norm1(h_V + self.dropout1(dh)) + + dh = self.dense(h_V) + h_V = self.norm2(h_V + self.dropout2(dh)) + if mask_V is not None: + mask_V = mask_V.unsqueeze(-1) + h_V = mask_V * h_V + + h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) + h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1) + h_EV 
= torch.cat([h_V_expand, h_EV], -1) + h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV))))) + h_E = self.norm3(h_E + self.dropout3(h_message)) + return h_V, h_E + + +# The following gather functions +def gather_edges(edges, neighbor_idx): + # Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C] + neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1)) + edge_features = torch.gather(edges, 2, neighbors) + return edge_features + + +def gather_nodes(nodes, neighbor_idx): + # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C] + # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C] + neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1)) + neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2)) + # Gather and re-pack + neighbor_features = torch.gather(nodes, 1, neighbors_flat) + neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1]) + return neighbor_features + + +def gather_nodes_t(nodes, neighbor_idx): + # Features [B,N,C] at Neighbor index [B,K] => Neighbor features[B,K,C] + idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2)) + neighbor_features = torch.gather(nodes, 1, idx_flat) + return neighbor_features + + +def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx): + h_nodes = gather_nodes(h_nodes, E_idx) + h_nn = torch.cat([h_neighbors, h_nodes], -1) + return h_nn diff --git a/nextflow.config b/nextflow.config new file mode 100644 index 0000000..575a9ab --- /dev/null +++ b/nextflow.config @@ -0,0 +1,42 @@ +// Manifest for Nextflow metadata +manifest { + name = 'LigandMPNN-Nextflow' + author = 'Generated from LigandMPNN repository' + homePage = 'https://github.com/dauparas/LigandMPNN' + description = 'Nextflow pipeline for LigandMPNN - Protein sequence design with ligand context' + mainScript = 'main.nf' + version = '1.0.0' +} + +// Global default parameters +params { + pdb = 
"/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb" + outdir = "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output" + model_type = "ligand_mpnn" + temperature = 0.1 + seed = 111 + batch_size = 1 + number_of_batches = 1 + chains_to_design = "" + fixed_residues = "" +} + +// Container configurations +docker { + enabled = true + runOptions = '--gpus all' +} + +// Process configurations +process { + cpus = 4 + memory = '16 GB' +} + +// Execution configurations +executor { + $local { + cpus = 8 + memory = '32 GB' + } +} diff --git a/params.json b/params.json new file mode 100644 index 0000000..f75854e --- /dev/null +++ b/params.json @@ -0,0 +1,131 @@ +{ + "params": { + "pdb": { + "type": "file", + "description": "Path to input PDB file for protein sequence design", + "default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb", + "required": true, + "pipeline_io": "input", + "var_name": "params.pdb", + "examples": [ + "/mnt/workflow/input/protein.pdb", + "/mnt/workflow/input/*.pdb" + ], + "pattern": ".*\\.pdb$", + "enum": [], + "validation": {}, + "notes": "Input PDB file containing the protein structure for sequence design." + }, + "outdir": { + "type": "folder", + "description": "Directory for LigandMPNN output results", + "default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output", + "required": true, + "pipeline_io": "output", + "var_name": "params.outdir", + "examples": [ + "/mnt/workflow/output", + "/path/to/results" + ], + "pattern": ".*", + "enum": [], + "validation": {}, + "notes": "Directory where designed sequences and backbone PDBs will be saved." 
+ }, + "model_type": { + "type": "string", + "description": "Type of MPNN model to use for sequence design", + "default": "ligand_mpnn", + "required": false, + "pipeline_io": "parameter", + "var_name": "params.model_type", + "examples": [ + "protein_mpnn", + "ligand_mpnn", + "soluble_mpnn" + ], + "pattern": "^(protein_mpnn|ligand_mpnn|soluble_mpnn|global_label_membrane_mpnn|per_residue_label_membrane_mpnn)$", + "enum": ["protein_mpnn", "ligand_mpnn", "soluble_mpnn", "global_label_membrane_mpnn", "per_residue_label_membrane_mpnn"], + "validation": {}, + "notes": "protein_mpnn: Original ProteinMPNN. ligand_mpnn: Context-aware with ligands. soluble_mpnn: Trained on soluble proteins." + }, + "temperature": { + "type": "number", + "description": "Sampling temperature for sequence generation", + "default": 0.1, + "required": false, + "pipeline_io": "parameter", + "var_name": "params.temperature", + "examples": [0.05, 0.1, 0.2], + "pattern": null, + "enum": [], + "validation": {}, + "notes": "Higher temperature gives more sequence diversity. Recommended range: 0.05-0.5" + }, + "seed": { + "type": "integer", + "description": "Random seed for reproducibility", + "default": 111, + "required": false, + "pipeline_io": "parameter", + "var_name": "params.seed", + "examples": [111, 42, 12345], + "pattern": null, + "enum": [], + "validation": {}, + "notes": "Set for reproducible results." + }, + "batch_size": { + "type": "integer", + "description": "Number of sequences to generate per batch", + "default": 1, + "required": false, + "pipeline_io": "parameter", + "var_name": "params.batch_size", + "examples": [1, 3, 5], + "pattern": null, + "enum": [], + "validation": {}, + "notes": "Higher batch sizes require more GPU memory." 
+ }, + "number_of_batches": { + "type": "integer", + "description": "Number of batches to run", + "default": 1, + "required": false, + "pipeline_io": "parameter", + "var_name": "params.number_of_batches", + "examples": [1, 5, 10], + "pattern": null, + "enum": [], + "validation": {}, + "notes": "Total sequences = batch_size × number_of_batches" + }, + "chains_to_design": { + "type": "string", + "description": "Comma-separated chain IDs to redesign", + "default": "", + "required": false, + "pipeline_io": "parameter", + "var_name": "params.chains_to_design", + "examples": ["A", "A,B", "A,B,C"], + "pattern": "^([A-Z],?)*$", + "enum": [], + "validation": {}, + "notes": "Leave empty to design all chains." + }, + "fixed_residues": { + "type": "string", + "description": "Space-separated list of residues to keep fixed", + "default": "", + "required": false, + "pipeline_io": "parameter", + "var_name": "params.fixed_residues", + "examples": ["A1 A2 A3", "A12 B25 B26"], + "pattern": null, + "enum": [], + "validation": {}, + "notes": "Format: ChainResidue (e.g., A12). Leave empty to design all residues." 
+ } + } +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2c1c63f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +biopython==1.79 +filelock==3.13.1 +fsspec==2024.3.1 +Jinja2==3.1.3 +MarkupSafe==2.1.5 +mpmath==1.3.0 +networkx==3.2.1 +numpy==1.23.5 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.4.99 +nvidia-nvtx-cu12==12.1.105 +ProDy==2.4.1 +pyparsing==3.1.1 +scipy==1.12.0 +sympy==1.12 +torch==2.2.1 +triton==2.2.0 +typing_extensions==4.10.0 +ml-collections==0.1.1 +dm-tree==0.1.8 diff --git a/run.py b/run.py new file mode 100644 index 0000000..b26720f --- /dev/null +++ b/run.py @@ -0,0 +1,990 @@ +import argparse +import copy +import json +import os.path +import random +import sys + +import numpy as np +import torch +from data_utils import ( + alphabet, + element_dict_rev, + featurize, + get_score, + get_seq_rec, + parse_PDB, + restype_1to3, + restype_int_to_str, + restype_str_to_int, + write_full_PDB, +) +from model_utils import ProteinMPNN +from prody import writePDB +from sc_utils import Packer, pack_side_chains + + +def main(args) -> None: + """ + Inference function + """ + if args.seed: + seed = args.seed + else: + seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0]) + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") + folder_for_outputs = args.out_folder + base_folder = folder_for_outputs + if base_folder[-1] != "/": + base_folder = base_folder + "/" + if not os.path.exists(base_folder): + os.makedirs(base_folder, exist_ok=True) + if not os.path.exists(base_folder + "seqs"): + os.makedirs(base_folder + "seqs", 
exist_ok=True) + if not os.path.exists(base_folder + "backbones"): + os.makedirs(base_folder + "backbones", exist_ok=True) + if not os.path.exists(base_folder + "packed"): + os.makedirs(base_folder + "packed", exist_ok=True) + if args.save_stats: + if not os.path.exists(base_folder + "stats"): + os.makedirs(base_folder + "stats", exist_ok=True) + if args.model_type == "protein_mpnn": + checkpoint_path = args.checkpoint_protein_mpnn + elif args.model_type == "ligand_mpnn": + checkpoint_path = args.checkpoint_ligand_mpnn + elif args.model_type == "per_residue_label_membrane_mpnn": + checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn + elif args.model_type == "global_label_membrane_mpnn": + checkpoint_path = args.checkpoint_global_label_membrane_mpnn + elif args.model_type == "soluble_mpnn": + checkpoint_path = args.checkpoint_soluble_mpnn + else: + print("Choose one of the available models") + sys.exit() + checkpoint = torch.load(checkpoint_path, map_location=device) + if args.model_type == "ligand_mpnn": + atom_context_num = checkpoint["atom_context_num"] + ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context + k_neighbors = checkpoint["num_edges"] + else: + atom_context_num = 1 + ligand_mpnn_use_side_chain_context = 0 + k_neighbors = checkpoint["num_edges"] + + model = ProteinMPNN( + node_features=128, + edge_features=128, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + k_neighbors=k_neighbors, + device=device, + atom_context_num=atom_context_num, + model_type=args.model_type, + ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context, + ) + + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + if args.pack_side_chains: + model_sc = Packer( + node_features=128, + edge_features=128, + num_positional_embeddings=16, + num_chain_embeddings=16, + num_rbf=16, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + atom_context_num=16, + 
lower_bound=0.0, + upper_bound=20.0, + top_k=32, + dropout=0.0, + augment_eps=0.0, + atom37_order=False, + device=device, + num_mix=3, + ) + + checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device) + model_sc.load_state_dict(checkpoint_sc["model_state_dict"]) + model_sc.to(device) + model_sc.eval() + + if args.pdb_path_multi: + with open(args.pdb_path_multi, "r") as fh: + pdb_paths = list(json.load(fh)) + else: + pdb_paths = [args.pdb_path] + + if args.fixed_residues_multi: + with open(args.fixed_residues_multi, "r") as fh: + fixed_residues_multi = json.load(fh) + fixed_residues_multi = {key:value.split() for key,value in fixed_residues_multi.items()} + else: + fixed_residues = [item for item in args.fixed_residues.split()] + fixed_residues_multi = {} + for pdb in pdb_paths: + fixed_residues_multi[pdb] = fixed_residues + + if args.redesigned_residues_multi: + with open(args.redesigned_residues_multi, "r") as fh: + redesigned_residues_multi = json.load(fh) + redesigned_residues_multi = {key:value.split() for key,value in redesigned_residues_multi.items()} + else: + redesigned_residues = [item for item in args.redesigned_residues.split()] + redesigned_residues_multi = {} + for pdb in pdb_paths: + redesigned_residues_multi[pdb] = redesigned_residues + + bias_AA = torch.zeros([21], device=device, dtype=torch.float32) + if args.bias_AA: + tmp = [item.split(":") for item in args.bias_AA.split(",")] + a1 = [b[0] for b in tmp] + a2 = [float(b[1]) for b in tmp] + for i, AA in enumerate(a1): + bias_AA[restype_str_to_int[AA]] = a2[i] + + if args.bias_AA_per_residue_multi: + with open(args.bias_AA_per_residue_multi, "r") as fh: + bias_AA_per_residue_multi = json.load( + fh + ) # {"pdb_path" : {"A12": {"G": 1.1}}} + else: + if args.bias_AA_per_residue: + with open(args.bias_AA_per_residue, "r") as fh: + bias_AA_per_residue = json.load(fh) # {"A12": {"G": 1.1}} + bias_AA_per_residue_multi = {} + for pdb in pdb_paths: + bias_AA_per_residue_multi[pdb] = 
bias_AA_per_residue + + if args.omit_AA_per_residue_multi: + with open(args.omit_AA_per_residue_multi, "r") as fh: + omit_AA_per_residue_multi = json.load( + fh + ) # {"pdb_path" : {"A12": "PQR", "A13": "QS"}} + else: + if args.omit_AA_per_residue: + with open(args.omit_AA_per_residue, "r") as fh: + omit_AA_per_residue = json.load(fh) # {"A12": "PG"} + omit_AA_per_residue_multi = {} + for pdb in pdb_paths: + omit_AA_per_residue_multi[pdb] = omit_AA_per_residue + omit_AA_list = args.omit_AA + omit_AA = torch.tensor( + np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32), + device=device, + ) + + if len(args.parse_these_chains_only) != 0: + parse_these_chains_only_list = args.parse_these_chains_only.split(",") + else: + parse_these_chains_only_list = [] + + + # loop over PDB paths + for pdb in pdb_paths: + if args.verbose: + print("Designing protein from this path:", pdb) + fixed_residues = fixed_residues_multi[pdb] + redesigned_residues = redesigned_residues_multi[pdb] + parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or ( + args.pack_side_chains and not args.repack_everything + ) + protein_dict, backbone, other_atoms, icodes, _ = parse_PDB( + pdb, + device=device, + chains=parse_these_chains_only_list, + parse_all_atoms=parse_all_atoms_flag, + parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy, + ) + # make chain_letter + residue_idx + insertion_code mapping to integers + R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices + chain_letters_list = list(protein_dict["chain_letters"]) # chain letters + encoded_residues = [] + for i, R_idx_item in enumerate(R_idx_list): + tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i] + encoded_residues.append(tmp) + encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues)))) + encoded_residue_dict_rev = dict( + zip(list(range(len(encoded_residues))), encoded_residues) + ) + + bias_AA_per_residue = torch.zeros( + 
[len(encoded_residues), 21], device=device, dtype=torch.float32 + ) + if args.bias_AA_per_residue_multi or args.bias_AA_per_residue: + bias_dict = bias_AA_per_residue_multi[pdb] + for residue_name, v1 in bias_dict.items(): + if residue_name in encoded_residues: + i1 = encoded_residue_dict[residue_name] + for amino_acid, v2 in v1.items(): + if amino_acid in alphabet: + j1 = restype_str_to_int[amino_acid] + bias_AA_per_residue[i1, j1] = v2 + + omit_AA_per_residue = torch.zeros( + [len(encoded_residues), 21], device=device, dtype=torch.float32 + ) + if args.omit_AA_per_residue_multi or args.omit_AA_per_residue: + omit_dict = omit_AA_per_residue_multi[pdb] + for residue_name, v1 in omit_dict.items(): + if residue_name in encoded_residues: + i1 = encoded_residue_dict[residue_name] + for amino_acid in v1: + if amino_acid in alphabet: + j1 = restype_str_to_int[amino_acid] + omit_AA_per_residue[i1, j1] = 1.0 + + fixed_positions = torch.tensor( + [int(item not in fixed_residues) for item in encoded_residues], + device=device, + ) + redesigned_positions = torch.tensor( + [int(item not in redesigned_residues) for item in encoded_residues], + device=device, + ) + + # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model + if args.transmembrane_buried: + buried_residues = [item for item in args.transmembrane_buried.split()] + buried_positions = torch.tensor( + [int(item in buried_residues) for item in encoded_residues], + device=device, + ) + else: + buried_positions = torch.zeros_like(fixed_positions) + + if args.transmembrane_interface: + interface_residues = [item for item in args.transmembrane_interface.split()] + interface_positions = torch.tensor( + [int(item in interface_residues) for item in encoded_residues], + device=device, + ) + else: + interface_positions = torch.zeros_like(fixed_positions) + protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * ( + 1 - interface_positions + ) + 1 * interface_positions * (1 - 
buried_positions) + + if args.model_type == "global_label_membrane_mpnn": + protein_dict["membrane_per_residue_labels"] = ( + args.global_transmembrane_label + 0 * fixed_positions + ) + if len(args.chains_to_design) != 0: + chains_to_design_list = args.chains_to_design.split(",") + else: + chains_to_design_list = protein_dict["chain_letters"] + + chain_mask = torch.tensor( + np.array( + [ + item in chains_to_design_list + for item in protein_dict["chain_letters"] + ], + dtype=np.int32, + ), + device=device, + ) + + # create chain_mask to notify which residues are fixed (0) and which need to be designed (1) + if redesigned_residues: + protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions) + elif fixed_residues: + protein_dict["chain_mask"] = chain_mask * fixed_positions + else: + protein_dict["chain_mask"] = chain_mask + + if args.verbose: + PDB_residues_to_be_redesigned = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 1 + ] + PDB_residues_to_be_fixed = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 0 + ] + print("These residues will be redesigned: ", PDB_residues_to_be_redesigned) + print("These residues will be fixed: ", PDB_residues_to_be_fixed) + + # specify which residues are linked + if args.symmetry_residues: + symmetry_residues_list_of_lists = [ + x.split(",") for x in args.symmetry_residues.split("|") + ] + remapped_symmetry_residues = [] + for t_list in symmetry_residues_list_of_lists: + tmp_list = [] + for t in t_list: + tmp_list.append(encoded_residue_dict[t]) + remapped_symmetry_residues.append(tmp_list) + else: + remapped_symmetry_residues = [[]] + + # specify linking weights + if args.symmetry_weights: + symmetry_weights = [ + [float(item) for item in x.split(",")] + for x in args.symmetry_weights.split("|") + ] + else: + symmetry_weights = [[]] + + if 
args.homo_oligomer: + if args.verbose: + print("Designing HOMO-OLIGOMER") + chain_letters_set = list(set(chain_letters_list)) + reference_chain = chain_letters_set[0] + lc = len(reference_chain) + residue_indices = [ + item[lc:] for item in encoded_residues if item[:lc] == reference_chain + ] + remapped_symmetry_residues = [] + symmetry_weights = [] + for res in residue_indices: + tmp_list = [] + tmp_w_list = [] + for chain in chain_letters_set: + name = chain + res + tmp_list.append(encoded_residue_dict[name]) + tmp_w_list.append(1 / len(chain_letters_set)) + remapped_symmetry_residues.append(tmp_list) + symmetry_weights.append(tmp_w_list) + + # set other atom bfactors to 0.0 + if other_atoms: + other_bfactors = other_atoms.getBetas() + other_atoms.setBetas(other_bfactors * 0.0) + + # adjust input PDB name by dropping .pdb if it does exist + name = pdb[pdb.rfind("/") + 1 :] + if name[-4:] == ".pdb": + name = name[:-4] + + with torch.no_grad(): + # run featurize to remap R_idx and add batch dimension + if args.verbose: + if "Y" in list(protein_dict): + atom_coords = protein_dict["Y"].cpu().numpy() + atom_types = list(protein_dict["Y_t"].cpu().numpy()) + atom_mask = list(protein_dict["Y_m"].cpu().numpy()) + number_of_atoms_parsed = np.sum(atom_mask) + else: + print("No ligand atoms parsed") + number_of_atoms_parsed = 0 + atom_types = "" + atom_coords = [] + if number_of_atoms_parsed == 0: + print("No ligand atoms parsed") + elif args.model_type == "ligand_mpnn": + print( + f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}" + ) + for i, atom_type in enumerate(atom_types): + print( + f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}" + ) + feature_dict = featurize( + protein_dict, + cutoff_for_score=args.ligand_mpnn_cutoff_for_score, + use_atom_context=args.ligand_mpnn_use_atom_context, + number_of_ligand_atoms=atom_context_num, + model_type=args.model_type, + ) + feature_dict["batch_size"] = args.batch_size + 
B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now. + # add additional keys to the feature dictionary + feature_dict["temperature"] = args.temperature + feature_dict["bias"] = ( + (-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1]) + + bias_AA_per_residue[None] + - 1e8 * omit_AA_per_residue[None] + ) + feature_dict["symmetry_residues"] = remapped_symmetry_residues + feature_dict["symmetry_weights"] = symmetry_weights + + sampling_probs_list = [] + log_probs_list = [] + decoding_order_list = [] + S_list = [] + loss_list = [] + loss_per_residue_list = [] + loss_XY_list = [] + for _ in range(args.number_of_batches): + feature_dict["randn"] = torch.randn( + [feature_dict["batch_size"], feature_dict["mask"].shape[1]], + device=device, + ) + output_dict = model.sample(feature_dict) + + # compute confidence scores + loss, loss_per_residue = get_score( + output_dict["S"], + output_dict["log_probs"], + feature_dict["mask"] * feature_dict["chain_mask"], + ) + if args.model_type == "ligand_mpnn": + combined_mask = ( + feature_dict["mask"] + * feature_dict["mask_XY"] + * feature_dict["chain_mask"] + ) + else: + combined_mask = feature_dict["mask"] * feature_dict["chain_mask"] + loss_XY, _ = get_score( + output_dict["S"], output_dict["log_probs"], combined_mask + ) + # ----- + S_list.append(output_dict["S"]) + log_probs_list.append(output_dict["log_probs"]) + sampling_probs_list.append(output_dict["sampling_probs"]) + decoding_order_list.append(output_dict["decoding_order"]) + loss_list.append(loss) + loss_per_residue_list.append(loss_per_residue) + loss_XY_list.append(loss_XY) + S_stack = torch.cat(S_list, 0) + log_probs_stack = torch.cat(log_probs_list, 0) + sampling_probs_stack = torch.cat(sampling_probs_list, 0) + decoding_order_stack = torch.cat(decoding_order_list, 0) + loss_stack = torch.cat(loss_list, 0) + loss_per_residue_stack = torch.cat(loss_per_residue_list, 0) + loss_XY_stack = torch.cat(loss_XY_list, 0) + rec_mask = 
feature_dict["mask"][:1] * feature_dict["chain_mask"][:1] + rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask) + + native_seq = "".join( + [restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()] + ) + seq_np = np.array(list(native_seq)) + seq_out_str = [] + for mask in protein_dict["mask_c"]: + seq_out_str += list(seq_np[mask.cpu().numpy()]) + seq_out_str += [args.fasta_seq_separation] + seq_out_str = "".join(seq_out_str)[:-1] + + output_fasta = base_folder + "/seqs/" + name + args.file_ending + ".fa" + output_backbones = base_folder + "/backbones/" + output_packed = base_folder + "/packed/" + output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt" + + out_dict = {} + out_dict["generated_sequences"] = S_stack.cpu() + out_dict["sampling_probs"] = sampling_probs_stack.cpu() + out_dict["log_probs"] = log_probs_stack.cpu() + out_dict["decoding_order"] = decoding_order_stack.cpu() + out_dict["native_sequence"] = feature_dict["S"][0].cpu() + out_dict["mask"] = feature_dict["mask"][0].cpu() + out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu() + out_dict["seed"] = seed + out_dict["temperature"] = args.temperature + if args.save_stats: + torch.save(out_dict, output_stats_path) + + if args.pack_side_chains: + if args.verbose: + print("Packing side chains...") + feature_dict_ = featurize( + protein_dict, + cutoff_for_score=8.0, + use_atom_context=args.pack_with_ligand_context, + number_of_ligand_atoms=16, + model_type="ligand_mpnn", + ) + sc_feature_dict = copy.deepcopy(feature_dict_) + B = args.batch_size + for k, v in sc_feature_dict.items(): + if k != "S": + try: + num_dim = len(v.shape) + if num_dim == 2: + sc_feature_dict[k] = v.repeat(B, 1) + elif num_dim == 3: + sc_feature_dict[k] = v.repeat(B, 1, 1) + elif num_dim == 4: + sc_feature_dict[k] = v.repeat(B, 1, 1, 1) + elif num_dim == 5: + sc_feature_dict[k] = v.repeat(B, 1, 1, 1, 1) + except: + pass + X_stack_list = [] + X_m_stack_list = [] + b_factor_stack_list 
= [] + for _ in range(args.number_of_packs_per_design): + X_list = [] + X_m_list = [] + b_factor_list = [] + for c in range(args.number_of_batches): + sc_feature_dict["S"] = S_list[c] + sc_dict = pack_side_chains( + sc_feature_dict, + model_sc, + args.sc_num_denoising_steps, + args.sc_num_samples, + args.repack_everything, + ) + X_list.append(sc_dict["X"]) + X_m_list.append(sc_dict["X_m"]) + b_factor_list.append(sc_dict["b_factors"]) + + X_stack = torch.cat(X_list, 0) + X_m_stack = torch.cat(X_m_list, 0) + b_factor_stack = torch.cat(b_factor_list, 0) + + X_stack_list.append(X_stack) + X_m_stack_list.append(X_m_stack) + b_factor_stack_list.append(b_factor_stack) + + with open(output_fasta, "w") as f: + f.write( + ">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format( + name, + args.temperature, + seed, + torch.sum(rec_mask).cpu().numpy(), + torch.sum(combined_mask[:1]).cpu().numpy(), + bool(args.ligand_mpnn_use_atom_context), + float(args.ligand_mpnn_cutoff_for_score), + args.batch_size, + args.number_of_batches, + checkpoint_path, + seq_out_str, + ) + ) + for ix in range(S_stack.shape[0]): + ix_suffix = ix + if not args.zero_indexed: + ix_suffix += 1 + seq_rec_print = np.format_float_positional( + rec_stack[ix].cpu().numpy(), unique=False, precision=4 + ) + loss_np = np.format_float_positional( + np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4 + ) + loss_XY_np = np.format_float_positional( + np.exp(-loss_XY_stack[ix].cpu().numpy()), + unique=False, + precision=4, + ) + seq = "".join( + [restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()] + ) + + # write new sequences into PDB with backbone coordinates + seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[ + None, + ].repeat(4, 1) + bfactor_prody = ( + loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1) + ) + backbone.setResnames(seq_prody) + backbone.setBetas( 
+ np.exp(-bfactor_prody) + * (bfactor_prody > 0.01).astype(np.float32) + ) + if other_atoms: + writePDB( + output_backbones + + name + + "_" + + str(ix_suffix) + + args.file_ending + + ".pdb", + backbone + other_atoms, + ) + else: + writePDB( + output_backbones + + name + + "_" + + str(ix_suffix) + + args.file_ending + + ".pdb", + backbone, + ) + + # write full PDB files + if args.pack_side_chains: + for c_pack in range(args.number_of_packs_per_design): + X_stack = X_stack_list[c_pack] + X_m_stack = X_m_stack_list[c_pack] + b_factor_stack = b_factor_stack_list[c_pack] + write_full_PDB( + output_packed + + name + + args.packed_suffix + + "_" + + str(ix_suffix) + + "_" + + str(c_pack + 1) + + args.file_ending + + ".pdb", + X_stack[ix].cpu().numpy(), + X_m_stack[ix].cpu().numpy(), + b_factor_stack[ix].cpu().numpy(), + feature_dict["R_idx_original"][0].cpu().numpy(), + protein_dict["chain_letters"], + S_stack[ix].cpu().numpy(), + other_atoms=other_atoms, + icodes=icodes, + force_hetatm=args.force_hetatm, + ) + # ----- + + # write fasta lines + seq_np = np.array(list(seq)) + seq_out_str = [] + for mask in protein_dict["mask_c"]: + seq_out_str += list(seq_np[mask.cpu().numpy()]) + seq_out_str += [args.fasta_seq_separation] + seq_out_str = "".join(seq_out_str)[:-1] + if ix == S_stack.shape[0] - 1: + # final 2 lines + f.write( + ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format( + name, + ix_suffix, + args.temperature, + seed, + loss_np, + loss_XY_np, + seq_rec_print, + seq_out_str, + ) + ) + else: + f.write( + ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}\n".format( + name, + ix_suffix, + args.temperature, + seed, + loss_np, + loss_XY_np, + seq_rec_print, + seq_out_str, + ) + ) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + argparser.add_argument( + "--model_type", + type=str, + 
default="protein_mpnn", + help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn", + ) + # protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms + # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB + # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed + # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane + # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids + argparser.add_argument( + "--checkpoint_protein_mpnn", + type=str, + default="./model_params/proteinmpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_ligand_mpnn", + type=str, + default="./model_params/ligandmpnn_v_32_010_25.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_per_residue_label_membrane_mpnn", + type=str, + default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_global_label_membrane_mpnn", + type=str, + default="./model_params/global_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_soluble_mpnn", + type=str, + default="./model_params/solublempnn_v_48_020.pt", + help="Path to model weights.", + ) + + argparser.add_argument( + "--fasta_seq_separation", + type=str, + default=":", + help="Symbol to use between sequences from different chains", + ) + argparser.add_argument("--verbose", type=int, default=1, help="Print stuff") + + argparser.add_argument( + "--pdb_path", type=str, default="", help="Path to the input PDB." 
+ ) + argparser.add_argument( + "--pdb_path_multi", + type=str, + default="", + help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.", + ) + + argparser.add_argument( + "--fixed_residues", + type=str, + default="", + help="Provide fixed residues, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--fixed_residues_multi", + type=str, + default="", + help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--redesigned_residues", + type=str, + default="", + help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--redesigned_residues_multi", + type=str, + default="", + help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--bias_AA", + type=str, + default="", + help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'", + ) + argparser.add_argument( + "--bias_AA_per_residue", + type=str, + default="", + help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}", + ) + argparser.add_argument( + "--bias_AA_per_residue_multi", + type=str, + default="", + help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}", + ) + + argparser.add_argument( + "--omit_AA", + type=str, + default="", + help="Bias generation of amino acids, e.g. 
'ACG'", + ) + argparser.add_argument( + "--omit_AA_per_residue", + type=str, + default="", + help="Path to json mapping of bias {'A12': 'APQ', 'A13': 'QST'}", + ) + argparser.add_argument( + "--omit_AA_per_residue_multi", + type=str, + default="", + help="Path to json mapping of bias {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}", + ) + + argparser.add_argument( + "--symmetry_residues", + type=str, + default="", + help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'", + ) + argparser.add_argument( + "--symmetry_weights", + type=str, + default="", + help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'", + ) + argparser.add_argument( + "--homo_oligomer", + type=int, + default=0, + help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.", + ) + + argparser.add_argument( + "--out_folder", + type=str, + help="Path to a folder to output sequences, e.g. 
/home/out/", + ) + argparser.add_argument( + "--file_ending", type=str, default="", help="adding_string_to_the_end" + ) + argparser.add_argument( + "--zero_indexed", + type=str, + default=0, + help="1 - to start output PDB numbering with 0", + ) + argparser.add_argument( + "--seed", + type=int, + default=0, + help="Set seed for torch, numpy, and python random.", + ) + argparser.add_argument( + "--batch_size", + type=int, + default=1, + help="Number of sequence to generate per one pass.", + ) + argparser.add_argument( + "--number_of_batches", + type=int, + default=1, + help="Number of times to design sequence using a chosen batch size.", + ) + argparser.add_argument( + "--temperature", + type=float, + default=0.1, + help="Temperature to sample sequences.", + ) + argparser.add_argument( + "--save_stats", type=int, default=0, help="Save output statistics" + ) + + argparser.add_argument( + "--ligand_mpnn_use_atom_context", + type=int, + default=1, + help="1 - use atom context, 0 - do not use atom context.", + ) + argparser.add_argument( + "--ligand_mpnn_cutoff_for_score", + type=float, + default=8.0, + help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.", + ) + argparser.add_argument( + "--ligand_mpnn_use_side_chain_context", + type=int, + default=0, + help="Flag to use side chain atoms as ligand context for the fixed residues", + ) + argparser.add_argument( + "--chains_to_design", + type=str, + default="", + help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'", + ) + + argparser.add_argument( + "--parse_these_chains_only", + type=str, + default="", + help="Provide chains letters for parsing backbones, 'A,B,C,F'", + ) + + argparser.add_argument( + "--transmembrane_buried", + type=str, + default="", + help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--transmembrane_interface", + type=str, + 
default="", + help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + + argparser.add_argument( + "--global_transmembrane_label", + type=int, + default=0, + help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble", + ) + + argparser.add_argument( + "--parse_atoms_with_zero_occupancy", + type=int, + default=0, + help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy", + ) + + argparser.add_argument( + "--pack_side_chains", + type=int, + default=0, + help="1 - to run side chain packer, 0 - do not run it", + ) + + argparser.add_argument( + "--checkpoint_path_sc", + type=str, + default="./model_params/ligandmpnn_sc_v_32_002_16.pt", + help="Path to model weights.", + ) + + argparser.add_argument( + "--number_of_packs_per_design", + type=int, + default=4, + help="Number of independent side chain packing samples to return per design", + ) + + argparser.add_argument( + "--sc_num_denoising_steps", + type=int, + default=3, + help="Number of denoising/recycling steps to make for side chain packing", + ) + + argparser.add_argument( + "--sc_num_samples", + type=int, + default=16, + help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.", + ) + + argparser.add_argument( + "--repack_everything", + type=int, + default=0, + help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues", + ) + + argparser.add_argument( + "--force_hetatm", + type=int, + default=0, + help="To force ligand atoms to be written as HETATM to PDB file after packing.", + ) + + argparser.add_argument( + "--packed_suffix", + type=str, + default="_packed", + help="Suffix for packed PDB paths", + ) + + argparser.add_argument( + "--pack_with_ligand_context", + type=int, + default=1, + help="1-pack side chains using ligand context, 
0 - do not use it.", + ) + + args = argparser.parse_args() + main(args) diff --git a/run_examples.sh b/run_examples.sh new file mode 100644 index 0000000..b507906 --- /dev/null +++ b/run_examples.sh @@ -0,0 +1,244 @@ +#!/bin/bash + +#1 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/default" +#2 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --temperature 0.05 \ + --out_folder "./outputs/temperature" + +#3 +python run.py \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/random_seed" + +#4 +python run.py \ + --seed 111 \ + --verbose 0 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/verbose" + +#5 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/save_stats" \ + --save_stats 1 + +#6 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/fix_residues" \ + --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \ + --bias_AA "A:10.0" + +#7 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/redesign_residues" \ + --redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \ + --bias_AA "A:10.0" + +#8 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/batch_size" \ + --batch_size 3 \ + --number_of_batches 5 + +#9 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \ + --out_folder "./outputs/global_bias" + +#10 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \ + --out_folder "./outputs/per_residue_bias" + +#11 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --omit_AA "CDFGHILMNPQRSTVWY" \ + --out_folder "./outputs/global_omit" + +#12 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \ + --out_folder 
"./outputs/per_residue_omit" + +#13 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/symmetry" \ + --symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \ + --symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5" + +#14 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/homooligomer" \ + --homo_oligomer 1 \ + --number_of_batches 2 + +#15 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/file_ending" \ + --file_ending "_xyz" + +#16 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/zero_indexed" \ + --zero_indexed 1 \ + --number_of_batches 2 + +#17 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/chains_to_design" \ + --chains_to_design "A,B" + +#18 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/4GYT.pdb" \ + --out_folder "./outputs/parse_these_chains_only" \ + --parse_these_chains_only "A,B" + +#19 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_default" + +#20 +python run.py \ + --checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_v_32_005_25" + +#21 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_no_context" \ + --ligand_mpnn_use_atom_context 0 + +#22 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \ + --ligand_mpnn_use_side_chain_context 1 \ + --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" + +#23 +python run.py \ + --model_type "soluble_mpnn" \ + --seed 
111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/soluble_mpnn_default" + +#24 +python run.py \ + --model_type "global_label_membrane_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/global_label_membrane_mpnn_0" \ + --global_transmembrane_label 0 + +#25 +python run.py \ + --model_type "per_residue_label_membrane_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/per_residue_label_membrane_mpnn_default" \ + --transmembrane_buried "C1 C2 C3 C11" \ + --transmembrane_interface "C4 C5 C6 C22" + +#26 +python run.py \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/fasta_seq_separation" \ + --fasta_seq_separation ":" + +#27 +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --out_folder "./outputs/pdb_path_multi" \ + --seed 111 + +#28 +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --fixed_residues_multi "./inputs/fix_residues_multi.json" \ + --out_folder "./outputs/fixed_residues_multi" \ + --seed 111 + +#29 +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \ + --out_folder "./outputs/redesigned_residues_multi" \ + --seed 111 + +#30 +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \ + --out_folder "./outputs/omit_AA_per_residue_multi" \ + --seed 111 + +#31 +python run.py \ + --pdb_path_multi "./inputs/pdb_ids.json" \ + --bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \ + --out_folder "./outputs/bias_AA_per_residue_multi" \ + --seed 111 + +#32 +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --ligand_mpnn_cutoff_for_score "6.0" \ + --out_folder "./outputs/ligand_mpnn_cutoff_for_score" + +#33 +python run.py \ + --seed 111 \ + --pdb_path "./inputs/2GFB.pdb" \ + --out_folder "./outputs/insertion_code" \ + 
--redesigned_residues "B82 B82A B82B B82C" \ + --parse_these_chains_only "B" diff --git a/sc_examples.sh b/sc_examples.sh new file mode 100644 index 0000000..bba50d7 --- /dev/null +++ b/sc_examples.sh @@ -0,0 +1,55 @@ +#1 design a new sequence and pack side chains (return 1 side chain packing sample - fast) +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_default_fast" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 0 \ + --pack_with_ligand_context 1 + +#2 design a new sequence and pack side chains (return 4 side chain packing samples) +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_default" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 + + +#3 fix specific residues for design and packing +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_fixed_residues" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 \ + --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \ + --repack_everything 0 + +#4 fix specific residues for sequence design but repack everything +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_fixed_residues_full_repack" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 1 \ + --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \ + --repack_everything 1 + + +#5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms +python run.py \ + --model_type "ligand_mpnn" \ + --seed 111 \ + --pdb_path "./inputs/1BC8.pdb" \ + --out_folder "./outputs/sc_no_context" \ + --pack_side_chains 1 \ + --number_of_packs_per_design 4 \ + --pack_with_ligand_context 0 diff --git 
a/sc_utils.py b/sc_utils.py new file mode 100644 index 0000000..5f30779 --- /dev/null +++ b/sc_utils.py @@ -0,0 +1,1158 @@ +import sys + +import numpy as np +import torch +import torch.distributions as D +import torch.nn as nn +from model_utils import ( + DecLayer, + DecLayerJ, + EncLayer, + PositionalEncodings, + cat_neighbors_nodes, + gather_edges, + gather_nodes, +) + +from openfold.data.data_transforms import atom37_to_torsion_angles, make_atom14_masks +from openfold.np.residue_constants import ( + restype_atom14_mask, + restype_atom14_rigid_group_positions, + restype_atom14_to_rigid_group, + restype_rigid_group_default_frame, +) +from openfold.utils import feats +from openfold.utils.rigid_utils import Rigid + +torch_pi = torch.tensor(np.pi, device="cpu") + + +map_mpnn_to_af2_seq = torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + ], + device="cpu", +) + + +def pack_side_chains( + feature_dict, + model_sc, + num_denoising_steps, + num_samples=10, + repack_everything=True, + num_context_atoms=16, +): + device = feature_dict["X"].device + torsion_dict = make_torsion_features(feature_dict, repack_everything) + feature_dict["X"] = torsion_dict["xyz14_noised"] + feature_dict["X_m"] = torsion_dict["xyz14_m"] + if "Y" not in list(feature_dict): + feature_dict["Y"] = torch.zeros( + [ + feature_dict["X"].shape[0], + feature_dict["X"].shape[1], + num_context_atoms, + 3, + ], + device=device, + ) + feature_dict["Y_t"] = torch.zeros( + [feature_dict["X"].shape[0], feature_dict["X"].shape[1], num_context_atoms], + device=device, + ) + feature_dict["Y_m"] = torch.zeros( + [feature_dict["X"].shape[0], feature_dict["X"].shape[1], num_context_atoms], + device=device, + ) + h_V, h_E, E_idx = model_sc.encode(feature_dict) + feature_dict["h_V"] = h_V + feature_dict["h_E"] = h_E + feature_dict["E_idx"] = E_idx + for step in range(num_denoising_steps): + mean, concentration, mix_logits = model_sc.decode(feature_dict) + mix = D.Categorical(logits=mix_logits) + comp = D.VonMises(mean, concentration) + pred_dist = D.MixtureSameFamily(mix, comp) + predicted_samples = pred_dist.sample([num_samples]) + log_probs_of_samples = pred_dist.log_prob(predicted_samples) + sample = torch.gather( + predicted_samples, dim=0, index=torch.argmax(log_probs_of_samples, 0)[None,] + )[0,] + torsions_pred_unit = torch.cat( + [torch.sin(sample[:, :, :, None]), torch.cos(sample[:, :, :, None])], -1 + ) + torsion_dict["torsions_noised"][:, :, 3:] = torsions_pred_unit * torsion_dict[ + "mask_fix_sc" + ] + torsion_dict["torsions_true"] * (1 - torsion_dict["mask_fix_sc"]) + pred_frames = feats.torsion_angles_to_frames( + 
torsion_dict["rigids"], + torsion_dict["torsions_noised"], + torsion_dict["aatype"], + torch.tensor(restype_rigid_group_default_frame, device=device), + ) + xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos( + pred_frames, + torsion_dict["aatype"], + torch.tensor(restype_rigid_group_default_frame, device=device), + torch.tensor(restype_atom14_to_rigid_group, device=device), + torch.tensor(restype_atom14_mask, device=device), + torch.tensor(restype_atom14_rigid_group_positions, device=device), + ) + xyz14_noised = xyz14_noised * feature_dict["X_m"][:, :, :, None] + feature_dict["X"] = xyz14_noised + S_af2 = torsion_dict["S_af2"] + + feature_dict["X"] = xyz14_noised + + log_prob = pred_dist.log_prob(sample) * torsion_dict["mask_fix_sc"][ + ..., 0 + ] + 2.0 * (1 - torsion_dict["mask_fix_sc"][..., 0]) + + tmp_types = torch.tensor(restype_atom14_to_rigid_group, device=device)[S_af2] + tmp_types[tmp_types < 4] = 4 + tmp_types -= 4 + atom_types_for_b_factor = torch.nn.functional.one_hot(tmp_types, 4) # [B, L, 14, 4] + + uncertainty = log_prob[:, :, None, :] * atom_types_for_b_factor # [B,L,14,4] + b_factor_pred = uncertainty.sum(-1) # [B, L, 14] + feature_dict["b_factors"] = b_factor_pred + feature_dict["mean"] = mean + feature_dict["concentration"] = concentration + feature_dict["mix_logits"] = mix_logits + feature_dict["log_prob"] = log_prob + feature_dict["sample"] = sample + feature_dict["true_torsion_sin_cos"] = torsion_dict["torsions_true"] + return feature_dict + + +def make_torsion_features(feature_dict, repack_everything=True): + device = feature_dict["mask"].device + + mask = feature_dict["mask"] + B, L = mask.shape + + xyz37 = torch.zeros([B, L, 37, 3], device=device, dtype=torch.float32) + xyz37[:, :, :3] = feature_dict["X"][:, :, :3] + xyz37[:, :, 4] = feature_dict["X"][:, :, 3] + + S_af2 = torch.argmax( + torch.nn.functional.one_hot(feature_dict["S"], 21).float() + @ map_mpnn_to_af2_seq.to(device).float(), + -1, + ) + masks14_37 = 
make_atom14_masks({"aatype": S_af2}) + temp_dict = { + "aatype": S_af2, + "all_atom_positions": xyz37, + "all_atom_mask": masks14_37["atom37_atom_exists"], + } + torsion_dict = atom37_to_torsion_angles("")(temp_dict) + + rigids = Rigid.make_transform_from_reference( + n_xyz=xyz37[:, :, 0, :], + ca_xyz=xyz37[:, :, 1, :], + c_xyz=xyz37[:, :, 2, :], + eps=1e-9, + ) + + if not repack_everything: + xyz37_true = feature_dict["xyz_37"] + temp_dict_true = { + "aatype": S_af2, + "all_atom_positions": xyz37_true, + "all_atom_mask": masks14_37["atom37_atom_exists"], + } + torsion_dict_true = atom37_to_torsion_angles("")(temp_dict_true) + torsions_true = torch.clone(torsion_dict_true["torsion_angles_sin_cos"])[ + :, :, 3: + ] + mask_fix_sc = feature_dict["chain_mask"][:, :, None, None] + else: + torsions_true = torch.zeros([B, L, 4, 2], device=device) + mask_fix_sc = torch.ones([B, L, 1, 1], device=device) + + random_angle = ( + 2 * torch_pi * torch.rand([S_af2.shape[0], S_af2.shape[1], 4], device=device) + ) + random_sin_cos = torch.cat( + [torch.sin(random_angle)[..., None], torch.cos(random_angle)[..., None]], -1 + ) + torsions_noised = torch.clone(torsion_dict["torsion_angles_sin_cos"]) + torsions_noised[:, :, 3:] = random_sin_cos * mask_fix_sc + torsions_true * ( + 1 - mask_fix_sc + ) + pred_frames = feats.torsion_angles_to_frames( + rigids, + torsions_noised, + S_af2, + torch.tensor(restype_rigid_group_default_frame, device=device), + ) + + xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos( + pred_frames, + S_af2, + torch.tensor(restype_rigid_group_default_frame, device=device), + torch.tensor(restype_atom14_to_rigid_group, device=device).long(), + torch.tensor(restype_atom14_mask, device=device), + torch.tensor(restype_atom14_rigid_group_positions, device=device), + ) + + xyz14_m = masks14_37["atom14_atom_exists"] * mask[:, :, None] + xyz14_noised = xyz14_noised * xyz14_m[:, :, :, None] + torsion_dict["xyz14_m"] = xyz14_m + torsion_dict["xyz14_noised"] 
= xyz14_noised + torsion_dict["mask_for_loss"] = mask + torsion_dict["rigids"] = rigids + torsion_dict["torsions_noised"] = torsions_noised + torsion_dict["mask_fix_sc"] = mask_fix_sc + torsion_dict["torsions_true"] = torsions_true + torsion_dict["S_af2"] = S_af2 + return torsion_dict + + +class Packer(nn.Module): + def __init__( + self, + edge_features=128, + node_features=128, + num_positional_embeddings=16, + num_chain_embeddings=16, + num_rbf=16, + top_k=30, + augment_eps=0.0, + atom37_order=False, + device=None, + atom_context_num=16, + lower_bound=0.0, + upper_bound=20.0, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + dropout=0.1, + num_mix=3, + ): + super(Packer, self).__init__() + self.edge_features = edge_features + self.node_features = node_features + self.num_positional_embeddings = num_positional_embeddings + self.num_chain_embeddings = num_chain_embeddings + self.num_rbf = num_rbf + self.top_k = top_k + self.augment_eps = augment_eps + self.atom37_order = atom37_order + self.device = device + self.atom_context_num = atom_context_num + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + self.hidden_dim = hidden_dim + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.dropout = dropout + self.softplus = nn.Softplus(beta=1, threshold=20) + + self.features = ProteinFeatures( + edge_features=edge_features, + node_features=node_features, + num_positional_embeddings=num_positional_embeddings, + num_chain_embeddings=num_chain_embeddings, + num_rbf=num_rbf, + top_k=top_k, + augment_eps=augment_eps, + atom37_order=atom37_order, + device=device, + atom_context_num=atom_context_num, + lower_bound=lower_bound, + upper_bound=upper_bound, + ) + + self.W_e = nn.Linear(edge_features, hidden_dim, bias=True) + self.W_v = nn.Linear(node_features, hidden_dim, bias=True) + self.W_f = nn.Linear(edge_features, hidden_dim, bias=True) + self.W_v_sc = nn.Linear(node_features, hidden_dim, 
bias=True) + self.linear_down = nn.Linear(2 * hidden_dim, hidden_dim, bias=True) + self.W_torsions = nn.Linear(hidden_dim, 4 * 3 * num_mix, bias=True) + self.num_mix = num_mix + + self.dropout = nn.Dropout(dropout) + + # Encoder layers + self.encoder_layers = nn.ModuleList( + [ + EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout) + for _ in range(num_encoder_layers) + ] + ) + + self.W_c = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.W_e_context = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.W_nodes_y = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.W_edges_y = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.context_encoder_layers = nn.ModuleList( + [DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout) for _ in range(2)] + ) + + self.V_C = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.V_C_norm = nn.LayerNorm(hidden_dim) + self.y_context_encoder_layers = nn.ModuleList( + [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)] + ) + + self.h_V_C_dropout = nn.Dropout(dropout) + + # Decoder layers + self.decoder_layers = nn.ModuleList( + [ + DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout) + for _ in range(num_decoder_layers) + ] + ) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def encode(self, feature_dict): + mask = feature_dict["mask"] + V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m = self.features.features_encode( + feature_dict + ) + + h_E_context = self.W_e_context(E_context) + h_V = self.W_v(V) + h_E = self.W_e(E) + mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) + mask_attend = mask.unsqueeze(-1) * mask_attend + for layer in self.encoder_layers: + h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) + + h_V_C = self.W_c(h_V) + Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :] + Y_nodes = self.W_nodes_y(Y_nodes) + Y_edges = self.W_edges_y(Y_edges) + for i in range(len(self.context_encoder_layers)): + Y_nodes = 
self.y_context_encoder_layers[i](Y_nodes, Y_edges, Y_m, Y_m_edges) + h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1) + h_V_C = self.context_encoder_layers[i](h_V_C, h_E_context_cat, mask, Y_m) + + h_V_C = self.V_C(h_V_C) + h_V = h_V + self.V_C_norm(self.h_V_C_dropout(h_V_C)) + + return h_V, h_E, E_idx + + def decode(self, feature_dict): + h_V = feature_dict["h_V"] + h_E = feature_dict["h_E"] + E_idx = feature_dict["E_idx"] + mask = feature_dict["mask"] + device = h_V.device + V, F = self.features.features_decode(feature_dict) + + h_F = self.W_f(F) + h_EF = torch.cat([h_E, h_F], -1) + + h_V_sc = self.W_v_sc(V) + h_V_combined = torch.cat([h_V, h_V_sc], -1) + h_V = self.linear_down(h_V_combined) + + for layer in self.decoder_layers: + h_EV = cat_neighbors_nodes(h_V, h_EF, E_idx) + h_V = layer(h_V, h_EV, mask) + + torsions = self.W_torsions(h_V) + torsions = torsions.reshape(h_V.shape[0], h_V.shape[1], 4, self.num_mix, 3) + mean = torsions[:, :, :, :, 0].float() + concentration = 0.1 + self.softplus(torsions[:, :, :, :, 1]).float() + mix_logits = torsions[:, :, :, :, 2].float() + return mean, concentration, mix_logits + + +class ProteinFeatures(nn.Module): + def __init__( + self, + edge_features=128, + node_features=128, + num_positional_embeddings=16, + num_chain_embeddings=16, + num_rbf=16, + top_k=30, + augment_eps=0.0, + atom37_order=False, + device=None, + atom_context_num=16, + lower_bound=0.0, + upper_bound=20.0, + ): + """Extract protein features""" + super(ProteinFeatures, self).__init__() + self.edge_features = edge_features + self.node_features = node_features + self.num_positional_embeddings = num_positional_embeddings + self.num_chain_embeddings = num_chain_embeddings + self.num_rbf = num_rbf + self.top_k = top_k + self.augment_eps = augment_eps + self.atom37_order = atom37_order + self.device = device + self.atom_context_num = atom_context_num + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + # deal with oxygen index + # ------ 
+ self.N_idx = 0 + self.CA_idx = 1 + self.C_idx = 2 + + if atom37_order: + self.O_idx = 4 + else: + self.O_idx = 3 + # ------- + self.positional_embeddings = PositionalEncodings(num_positional_embeddings) + + # Features for the encoder + enc_node_in = 21 # alphabet for the sequence + enc_edge_in = ( + num_positional_embeddings + num_rbf * 25 + ) # positional + distance features + + self.enc_node_in = enc_node_in + self.enc_edge_in = enc_edge_in + + self.enc_edge_embedding = nn.Linear(enc_edge_in, edge_features, bias=False) + self.enc_norm_edges = nn.LayerNorm(edge_features) + self.enc_node_embedding = nn.Linear(enc_node_in, node_features, bias=False) + self.enc_norm_nodes = nn.LayerNorm(node_features) + + # Features for the decoder + dec_node_in = 14 * atom_context_num * num_rbf + dec_edge_in = num_rbf * 14 * 14 + 42 + + self.dec_node_in = dec_node_in + self.dec_edge_in = dec_edge_in + + self.W_XY_project_down1 = nn.Linear(num_rbf + 120, num_rbf, bias=True) + self.dec_edge_embedding1 = nn.Linear(dec_edge_in, edge_features, bias=False) + self.dec_norm_edges1 = nn.LayerNorm(edge_features) + self.dec_node_embedding1 = nn.Linear(dec_node_in, node_features, bias=False) + self.dec_norm_nodes1 = nn.LayerNorm(node_features) + + self.node_project_down = nn.Linear( + 5 * num_rbf + 64 + 4, node_features, bias=True + ) + self.norm_nodes = nn.LayerNorm(node_features) + + self.type_linear = nn.Linear(147, 64) + + self.y_nodes = nn.Linear(147, node_features, bias=False) + self.y_edges = nn.Linear(num_rbf, node_features, bias=False) + + self.norm_y_edges = nn.LayerNorm(node_features) + self.norm_y_nodes = nn.LayerNorm(node_features) + + self.periodic_table_features = torch.tensor( + [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 
51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + ], + [ + 0, + 1, + 18, + 1, + 2, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + ], + [ + 0, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + ], + ], + dtype=torch.long, + device=device, + ) + + def _dist(self, X, mask, eps=1e-6): + mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) + dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2) + D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps) + D_max, _ = torch.max(D, -1, keepdim=True) + D_adjust 
= D + (1.0 - mask_2D) * D_max + sampled_top_k = self.top_k + D_neighbors, E_idx = torch.topk( + D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False + ) + return D_neighbors, E_idx + + def _make_angle_features(self, A, B, C, Y): + v1 = A - B + v2 = C - B + e1 = torch.nn.functional.normalize(v1, dim=-1) + e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None] + u2 = v2 - e1 * e1_v2_dot + e2 = torch.nn.functional.normalize(u2, dim=-1) + e3 = torch.cross(e1, e2, dim=-1) + R_residue = torch.cat( + (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1 + ) + + local_vectors = torch.einsum( + "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :] + ) + + rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8) + f1 = local_vectors[..., 0] / rxy + f2 = local_vectors[..., 1] / rxy + rxyz = torch.norm(local_vectors, dim=-1) + 1e-8 + f3 = rxy / rxyz + f4 = local_vectors[..., 2] / rxyz + + f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1) + return f + + def _rbf( + self, + D, + D_mu_shape=[1, 1, 1, -1], + lower_bound=0.0, + upper_bound=20.0, + num_bins=16, + ): + device = D.device + D_min, D_max, D_count = lower_bound, upper_bound, num_bins + D_mu = torch.linspace(D_min, D_max, D_count, device=device) + D_mu = D_mu.view(D_mu_shape) + D_sigma = (D_max - D_min) / D_count + D_expand = torch.unsqueeze(D, -1) + RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2)) + return RBF + + def _get_rbf( + self, + A, + B, + E_idx, + D_mu_shape=[1, 1, 1, -1], + lower_bound=2.0, + upper_bound=22.0, + num_bins=16, + ): + D_A_B = torch.sqrt( + torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6 + ) # [B, L, L] + D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[ + :, :, :, 0 + ] # [B,L,K] + RBF_A_B = self._rbf( + D_A_B_neighbors, + D_mu_shape=D_mu_shape, + lower_bound=lower_bound, + upper_bound=upper_bound, + num_bins=num_bins, + ) + return RBF_A_B + + def features_encode(self, 
features): + """ + make protein graph and encode backbone + """ + S = features["S"] + X = features["X"] + Y = features["Y"] + Y_m = features["Y_m"] + Y_t = features["Y_t"] + mask = features["mask"] + R_idx = features["R_idx"] + chain_labels = features["chain_labels"] + + if self.training and self.augment_eps > 0: + X = X + self.augment_eps * torch.randn_like(X) + + Ca = X[:, :, self.CA_idx, :] + N = X[:, :, self.N_idx, :] + C = X[:, :, self.C_idx, :] + O = X[:, :, self.O_idx, :] + + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca # shift from CA + + _, E_idx = self._dist(Ca, mask) + + backbone_coords_list = [N, Ca, C, O, Cb] + + RBF_all = [] + for atom_1 in backbone_coords_list: + for atom_2 in backbone_coords_list: + RBF_all.append( + self._get_rbf( + atom_1, + atom_2, + E_idx, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + ) + RBF_all = torch.cat(tuple(RBF_all), dim=-1) + + offset = R_idx[:, :, None] - R_idx[:, None, :] + offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] + + d_chains = ( + (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 + ).long() # find self vs non-self interaction + E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] + E_positional = self.positional_embeddings(offset.long(), E_chains) + E = torch.cat((E_positional, RBF_all), -1) + E = self.enc_edge_embedding(E) + E = self.enc_norm_edges(E) + + V = torch.nn.functional.one_hot(S, self.enc_node_in).float() + V = self.enc_node_embedding(V) + V = self.enc_norm_nodes(V) + + Y_t = Y_t.long() + Y_t_g = self.periodic_table_features[1][Y_t] # group; 19 categories including 0 + Y_t_p = self.periodic_table_features[2][Y_t] # period; 8 categories including 0 + + Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19) # [B, L, M, 19] + Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8) # [B, L, M, 8] + Y_t_1hot_ = 
torch.nn.functional.one_hot(Y_t, 120) # [B, L, M, 120] + + Y_t_1hot_ = torch.cat( + [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1 + ) # [B, L, M, 147] + Y_t_1hot = self.type_linear(Y_t_1hot_.float()) + + D_N_Y = torch.sqrt( + torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6 + ) # [B, L, M, num_bins] + D_N_Y = self._rbf( + D_N_Y, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + + D_Ca_Y = torch.sqrt( + torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6 + ) # [B, L, M, num_bins] + D_Ca_Y = self._rbf( + D_Ca_Y, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + + D_C_Y = torch.sqrt( + torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6 + ) # [B, L, M, num_bins] + D_C_Y = self._rbf( + D_C_Y, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + + D_O_Y = torch.sqrt( + torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6 + ) # [B, L, M, num_bins] + D_O_Y = self._rbf( + D_O_Y, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + + D_Cb_Y = torch.sqrt( + torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6 + ) # [B, L, M, num_bins] + D_Cb_Y = self._rbf( + D_Cb_Y, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + + f_angles = self._make_angle_features(N, Ca, C, Y) + + D_all = torch.cat( + (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1 + ) # [B,L,M,5*num_bins+5] + E_context = self.node_project_down(D_all) # [B, L, M, node_features] + E_context = self.norm_nodes(E_context) + + Y_edges = self._rbf( + torch.sqrt( + torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6 + ) + ) # [B, L, M, M, num_bins] + + Y_edges = self.y_edges(Y_edges) + Y_nodes = self.y_nodes(Y_t_1hot_.float()) + + Y_edges = 
self.norm_y_edges(Y_edges) + Y_nodes = self.norm_y_nodes(Y_nodes) + + return V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m + + def features_decode(self, features): + """ + Make features for decoding. Explicit side chain atom and other atom distances. + """ + + S = features["S"] + X = features["X"] + X_m = features["X_m"] + mask = features["mask"] + E_idx = features["E_idx"] + + Y = features["Y"][:, :, : self.atom_context_num] + Y_m = features["Y_m"][:, :, : self.atom_context_num] + Y_t = features["Y_t"][:, :, : self.atom_context_num] + + X_m = X_m * mask[:, :, None] + device = S.device + + B, L, _, _ = X.shape + + RBF_sidechain = [] + X_m_gathered = gather_nodes(X_m, E_idx) # [B, L, K, 14] + + for i in range(14): + for j in range(14): + rbf_features = self._get_rbf( + X[:, :, i, :], + X[:, :, j, :], + E_idx, + D_mu_shape=[1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) + rbf_features = ( + rbf_features + * X_m[:, :, i, None, None] + * X_m_gathered[:, :, :, j, None] + ) + RBF_sidechain.append(rbf_features) + + D_XY = torch.sqrt( + torch.sum((X[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6 + ) # [B, L, 14, atom_context_num] + XY_features = self._rbf( + D_XY, + D_mu_shape=[1, 1, 1, 1, -1], + lower_bound=self.lower_bound, + upper_bound=self.upper_bound, + num_bins=self.num_rbf, + ) # [B, L, 14, atom_context_num, num_rbf] + XY_features = XY_features * X_m[:, :, :, None, None] * Y_m[:, :, None, :, None] + + Y_t_1hot = torch.nn.functional.one_hot( + Y_t.long(), 120 + ).float() # [B, L, atom_context_num, 120] + XY_Y_t = torch.cat( + [XY_features, Y_t_1hot[:, :, None, :, :].repeat(1, 1, 14, 1, 1)], -1 + ) # [B, L, 14, atom_context_num, num_rbf+120] + XY_Y_t = self.W_XY_project_down1( + XY_Y_t + ) # [B, L, 14, atom_context_num, num_rbf] + XY_features = XY_Y_t.view([B, L, -1]) + + V = self.dec_node_embedding1(XY_features) + V = self.dec_norm_nodes1(V) + + S_1h = torch.nn.functional.one_hot(S, 
self.enc_node_in).float() + S_1h_gathered = gather_nodes(S_1h, E_idx) # [B, L, K, 21] + S_features = torch.cat( + [S_1h[:, :, None, :].repeat(1, 1, E_idx.shape[2], 1), S_1h_gathered], -1 + ) # [B, L, K, 42] + + F = torch.cat( + tuple(RBF_sidechain), dim=-1 + ) # [B,L,atom_context_num,14*14*num_rbf] + F = torch.cat([F, S_features], -1) + F = self.dec_edge_embedding1(F) + F = self.dec_norm_edges1(F) + return V, F diff --git a/score.py b/score.py new file mode 100644 index 0000000..9ef7449 --- /dev/null +++ b/score.py @@ -0,0 +1,549 @@ +import argparse +import json +import os.path +import random +import sys + +import numpy as np +import torch + +from data_utils import ( + element_dict_rev, + alphabet, + restype_int_to_str, + featurize, + parse_PDB, +) +from model_utils import ProteinMPNN + + +def main(args) -> None: + """ + Inference function + """ + if args.seed: + seed = args.seed + else: + seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0]) + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") + folder_for_outputs = args.out_folder + base_folder = folder_for_outputs + if base_folder[-1] != "/": + base_folder = base_folder + "/" + if not os.path.exists(base_folder): + os.makedirs(base_folder, exist_ok=True) + if args.model_type == "protein_mpnn": + checkpoint_path = args.checkpoint_protein_mpnn + elif args.model_type == "ligand_mpnn": + checkpoint_path = args.checkpoint_ligand_mpnn + elif args.model_type == "per_residue_label_membrane_mpnn": + checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn + elif args.model_type == "global_label_membrane_mpnn": + checkpoint_path = args.checkpoint_global_label_membrane_mpnn + elif args.model_type == "soluble_mpnn": + checkpoint_path = args.checkpoint_soluble_mpnn + else: + print("Choose one of the available models") + sys.exit() + checkpoint = torch.load(checkpoint_path, map_location=device) + if args.model_type 
== "ligand_mpnn": + atom_context_num = checkpoint["atom_context_num"] + ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context + k_neighbors = checkpoint["num_edges"] + else: + atom_context_num = 1 + ligand_mpnn_use_side_chain_context = 0 + k_neighbors = checkpoint["num_edges"] + + model = ProteinMPNN( + node_features=128, + edge_features=128, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + k_neighbors=k_neighbors, + device=device, + atom_context_num=atom_context_num, + model_type=args.model_type, + ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context, + ) + + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + if args.pdb_path_multi: + with open(args.pdb_path_multi, "r") as fh: + pdb_paths = list(json.load(fh)) + else: + pdb_paths = [args.pdb_path] + + if args.fixed_residues_multi: + with open(args.fixed_residues_multi, "r") as fh: + fixed_residues_multi = json.load(fh) + else: + fixed_residues = [item for item in args.fixed_residues.split()] + fixed_residues_multi = {} + for pdb in pdb_paths: + fixed_residues_multi[pdb] = fixed_residues + + if args.redesigned_residues_multi: + with open(args.redesigned_residues_multi, "r") as fh: + redesigned_residues_multi = json.load(fh) + else: + redesigned_residues = [item for item in args.redesigned_residues.split()] + redesigned_residues_multi = {} + for pdb in pdb_paths: + redesigned_residues_multi[pdb] = redesigned_residues + + # loop over PDB paths + for pdb in pdb_paths: + if args.verbose: + print("Designing protein from this path:", pdb) + fixed_residues = fixed_residues_multi[pdb] + redesigned_residues = redesigned_residues_multi[pdb] + protein_dict, backbone, other_atoms, icodes, _ = parse_PDB( + pdb, + device=device, + chains=args.parse_these_chains_only, + parse_all_atoms=args.ligand_mpnn_use_side_chain_context, + parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy + ) + # make chain_letter + 
residue_idx + insertion_code mapping to integers + R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices + chain_letters_list = list(protein_dict["chain_letters"]) # chain letters + encoded_residues = [] + for i, R_idx_item in enumerate(R_idx_list): + tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i] + encoded_residues.append(tmp) + encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues)))) + encoded_residue_dict_rev = dict( + zip(list(range(len(encoded_residues))), encoded_residues) + ) + + fixed_positions = torch.tensor( + [int(item not in fixed_residues) for item in encoded_residues], + device=device, + ) + redesigned_positions = torch.tensor( + [int(item not in redesigned_residues) for item in encoded_residues], + device=device, + ) + + # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model + if args.transmembrane_buried: + buried_residues = [item for item in args.transmembrane_buried.split()] + buried_positions = torch.tensor( + [int(item in buried_residues) for item in encoded_residues], + device=device, + ) + else: + buried_positions = torch.zeros_like(fixed_positions) + + if args.transmembrane_interface: + interface_residues = [item for item in args.transmembrane_interface.split()] + interface_positions = torch.tensor( + [int(item in interface_residues) for item in encoded_residues], + device=device, + ) + else: + interface_positions = torch.zeros_like(fixed_positions) + protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * ( + 1 - interface_positions + ) + 1 * interface_positions * (1 - buried_positions) + + if args.model_type == "global_label_membrane_mpnn": + protein_dict["membrane_per_residue_labels"] = ( + args.global_transmembrane_label + 0 * fixed_positions + ) + if type(args.chains_to_design) == str: + chains_to_design_list = args.chains_to_design.split(",") + else: + chains_to_design_list = protein_dict["chain_letters"] + chain_mask = 
torch.tensor( + np.array( + [ + item in chains_to_design_list + for item in protein_dict["chain_letters"] + ], + dtype=np.int32, + ), + device=device, + ) + + # create chain_mask to notify which residues are fixed (0) and which need to be designed (1) + if redesigned_residues: + protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions) + elif fixed_residues: + protein_dict["chain_mask"] = chain_mask * fixed_positions + else: + protein_dict["chain_mask"] = chain_mask + + if args.verbose: + PDB_residues_to_be_redesigned = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 1 + ] + PDB_residues_to_be_fixed = [ + encoded_residue_dict_rev[item] + for item in range(protein_dict["chain_mask"].shape[0]) + if protein_dict["chain_mask"][item] == 0 + ] + print("These residues will be redesigned: ", PDB_residues_to_be_redesigned) + print("These residues will be fixed: ", PDB_residues_to_be_fixed) + + # specify which residues are linked + if args.symmetry_residues: + symmetry_residues_list_of_lists = [ + x.split(",") for x in args.symmetry_residues.split("|") + ] + remapped_symmetry_residues = [] + for t_list in symmetry_residues_list_of_lists: + tmp_list = [] + for t in t_list: + tmp_list.append(encoded_residue_dict[t]) + remapped_symmetry_residues.append(tmp_list) + else: + remapped_symmetry_residues = [[]] + + if args.homo_oligomer: + if args.verbose: + print("Designing HOMO-OLIGOMER") + chain_letters_set = list(set(chain_letters_list)) + reference_chain = chain_letters_set[0] + lc = len(reference_chain) + residue_indices = [ + item[lc:] for item in encoded_residues if item[:lc] == reference_chain + ] + remapped_symmetry_residues = [] + for res in residue_indices: + tmp_list = [] + tmp_w_list = [] + for chain in chain_letters_set: + name = chain + res + tmp_list.append(encoded_residue_dict[name]) + tmp_w_list.append(1 / len(chain_letters_set)) + 
remapped_symmetry_residues.append(tmp_list) + + # set other atom bfactors to 0.0 + if other_atoms: + other_bfactors = other_atoms.getBetas() + other_atoms.setBetas(other_bfactors * 0.0) + + # adjust input PDB name by dropping .pdb if it does exist + name = pdb[pdb.rfind("/") + 1 :] + if name[-4:] == ".pdb": + name = name[:-4] + + with torch.no_grad(): + # run featurize to remap R_idx and add batch dimension + if args.verbose: + if "Y" in list(protein_dict): + atom_coords = protein_dict["Y"].cpu().numpy() + atom_types = list(protein_dict["Y_t"].cpu().numpy()) + atom_mask = list(protein_dict["Y_m"].cpu().numpy()) + number_of_atoms_parsed = np.sum(atom_mask) + else: + print("No ligand atoms parsed") + number_of_atoms_parsed = 0 + atom_types = "" + atom_coords = [] + if number_of_atoms_parsed == 0: + print("No ligand atoms parsed") + elif args.model_type == "ligand_mpnn": + print( + f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}" + ) + for i, atom_type in enumerate(atom_types): + print( + f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}" + ) + feature_dict = featurize( + protein_dict, + cutoff_for_score=args.ligand_mpnn_cutoff_for_score, + use_atom_context=args.ligand_mpnn_use_atom_context, + number_of_ligand_atoms=atom_context_num, + model_type=args.model_type, + ) + feature_dict["batch_size"] = args.batch_size + B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now. 
+ # add additional keys to the feature dictionary + feature_dict["symmetry_residues"] = remapped_symmetry_residues + + logits_list = [] + probs_list = [] + log_probs_list = [] + decoding_order_list = [] + for _ in range(args.number_of_batches): + feature_dict["randn"] = torch.randn( + [feature_dict["batch_size"], feature_dict["mask"].shape[1]], + device=device, + ) + if args.autoregressive_score: + score_dict = model.score(feature_dict, use_sequence=args.use_sequence) + elif args.single_aa_score: + score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence) + else: + print("Set either autoregressive_score or single_aa_score to True") + sys.exit() + logits_list.append(score_dict["logits"]) + log_probs_list.append(score_dict["log_probs"]) + probs_list.append(torch.exp(score_dict["log_probs"])) + decoding_order_list.append(score_dict["decoding_order"]) + log_probs_stack = torch.cat(log_probs_list, 0) + logits_stack = torch.cat(logits_list, 0) + probs_stack = torch.cat(probs_list, 0) + decoding_order_stack = torch.cat(decoding_order_list, 0) + + output_stats_path = base_folder + name + args.file_ending + ".pt" + out_dict = {} + out_dict["logits"] = logits_stack.cpu().numpy() + out_dict["probs"] = probs_stack.cpu().numpy() + out_dict["log_probs"] = log_probs_stack.cpu().numpy() + out_dict["decoding_order"] = decoding_order_stack.cpu().numpy() + out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy() + out_dict["mask"] = feature_dict["mask"][0].cpu().numpy() + out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy() #this affects decoding order + out_dict["seed"] = seed + out_dict["alphabet"] = alphabet + out_dict["residue_names"] = encoded_residue_dict_rev + + mean_probs = np.mean(out_dict["probs"], 0) + std_probs = np.std(out_dict["probs"], 0) + sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]] + mean_dict = {} + std_dict = {} + for residue in range(L): + mean_dict_ = dict(zip(alphabet, 
mean_probs[residue])) + mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_ + std_dict_ = dict(zip(alphabet, std_probs[residue])) + std_dict[encoded_residue_dict_rev[residue]] = std_dict_ + + out_dict["sequence"] = sequence + out_dict["mean_of_probs"] = mean_dict + out_dict["std_of_probs"] = std_dict + torch.save(out_dict, output_stats_path) + + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + argparser.add_argument( + "--model_type", + type=str, + default="protein_mpnn", + help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn", + ) + # protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms + # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB + # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed + # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane + # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids + argparser.add_argument( + "--checkpoint_protein_mpnn", + type=str, + default="./model_params/proteinmpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_ligand_mpnn", + type=str, + default="./model_params/ligandmpnn_v_32_010_25.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_per_residue_label_membrane_mpnn", + type=str, + default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + "--checkpoint_global_label_membrane_mpnn", + type=str, + default="./model_params/global_label_membrane_mpnn_v_48_020.pt", + help="Path to model weights.", + ) + argparser.add_argument( + 
"--checkpoint_soluble_mpnn", + type=str, + default="./model_params/solublempnn_v_48_020.pt", + help="Path to model weights.", + ) + + argparser.add_argument("--verbose", type=int, default=1, help="Print stuff") + + argparser.add_argument( + "--pdb_path", type=str, default="", help="Path to the input PDB." + ) + argparser.add_argument( + "--pdb_path_multi", + type=str, + default="", + help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.", + ) + + argparser.add_argument( + "--fixed_residues", + type=str, + default="", + help="Provide fixed residues, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--fixed_residues_multi", + type=str, + default="", + help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--redesigned_residues", + type=str, + default="", + help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--redesigned_residues_multi", + type=str, + default="", + help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}", + ) + + argparser.add_argument( + "--symmetry_residues", + type=str, + default="", + help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'", + ) + + argparser.add_argument( + "--homo_oligomer", + type=int, + default=0, + help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.", + ) + + argparser.add_argument( + "--out_folder", + type=str, + help="Path to a folder to output scores, e.g. 
/home/out/", + ) + argparser.add_argument( + "--file_ending", type=str, default="", help="adding_string_to_the_end" + ) + argparser.add_argument( + "--zero_indexed", + type=str, + default=0, + help="1 - to start output PDB numbering with 0", + ) + argparser.add_argument( + "--seed", + type=int, + default=0, + help="Set seed for torch, numpy, and python random.", + ) + argparser.add_argument( + "--batch_size", + type=int, + default=1, + help="Number of sequence to generate per one pass.", + ) + argparser.add_argument( + "--number_of_batches", + type=int, + default=1, + help="Number of times to design sequence using a chosen batch size.", + ) + + argparser.add_argument( + "--ligand_mpnn_use_atom_context", + type=int, + default=1, + help="1 - use atom context, 0 - do not use atom context.", + ) + + argparser.add_argument( + "--ligand_mpnn_use_side_chain_context", + type=int, + default=0, + help="Flag to use side chain atoms as ligand context for the fixed residues", + ) + + argparser.add_argument( + "--ligand_mpnn_cutoff_for_score", + type=float, + default=8.0, + help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.", + ) + + argparser.add_argument( + "--chains_to_design", + type=str, + default=None, + help="Specify which chains to redesign, all others will be kept fixed.", + ) + + argparser.add_argument( + "--parse_these_chains_only", + type=str, + default="", + help="Provide chains letters for parsing backbones, 'ABCF'", + ) + + argparser.add_argument( + "--transmembrane_buried", + type=str, + default="", + help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + argparser.add_argument( + "--transmembrane_interface", + type=str, + default="", + help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25", + ) + + argparser.add_argument( + "--global_transmembrane_label", + type=int, + default=0, + 
help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble", + ) + + argparser.add_argument( + "--parse_atoms_with_zero_occupancy", + type=int, + default=0, + help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy", + ) + + argparser.add_argument( + "--use_sequence", + type=int, + default=1, + help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only", + ) + + argparser.add_argument( + "--autoregressive_score", + type=int, + default=0, + help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False", + ) + + argparser.add_argument( + "--single_aa_score", + type=int, + default=1, + help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False", + ) + + args = argparser.parse_args() + main(args)