Add LigandMPNN Nextflow pipeline for protein sequence design

This commit is contained in:
2026-03-18 22:31:13 +01:00
commit e7261ba7ce
15 changed files with 6825 additions and 0 deletions

96
Dockerfile Normal file
View File

@@ -0,0 +1,96 @@
FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04

# Suppress interactive apt prompts at build time only; ARG (unlike ENV) does
# not leak DEBIAN_FRONTEND into the runtime environment.
ARG DEBIAN_FRONTEND=noninteractive

# Runtime Python behavior: unbuffered stdout/stderr, no .pyc files.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# System dependencies, including build tools needed to compile ProDy.
# Packages sorted alphabetically; apt lists removed in the same layer.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        curl \
        git \
        procps \
        python3-pip \
        python3.11 \
        python3.11-dev \
        python3.11-venv \
        wget \
    && rm -rf /var/lib/apt/lists/* \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 \
    && update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1

# Upgrade packaging tools for the Python 3.11 interpreter selected above.
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel

# Clone LigandMPNN (shallow: history is not needed inside the image).
# NOTE(review): the clone is unpinned, so rebuilds are not reproducible —
# consider checking out a specific commit SHA.
RUN git clone --depth 1 https://github.com/dauparas/LigandMPNN.git /app/LigandMPNN

WORKDIR /app/LigandMPNN

# Pinned Python dependencies. Use `python3 -m pip` (not bare pip3) so the
# install always targets the interpreter configured via update-alternatives.
RUN python3 -m pip install --no-cache-dir \
    biopython==1.79 \
    dm-tree==0.1.8 \
    filelock==3.13.1 \
    fsspec==2024.3.1 \
    Jinja2==3.1.3 \
    MarkupSafe==2.1.5 \
    ml-collections==0.1.1 \
    mpmath==1.3.0 \
    networkx==3.2.1 \
    numpy==1.23.5 \
    ProDy==2.4.1 \
    pyparsing==3.1.1 \
    scipy==1.12.0 \
    sympy==1.12 \
    typing_extensions==4.10.0

# PyTorch built against CUDA 12.1, matching the base image CUDA version.
RUN python3 -m pip install --no-cache-dir \
    torch==2.2.1 \
    --index-url https://download.pytorch.org/whl/cu121

# Download model parameters. `|| exit 1` aborts the build on the first failed
# download (a bare `for` loop would only propagate the last command's status).
# NOTE(review): upstream publishes no checksums for these weights; add
# digest verification if the host ever provides them.
RUN mkdir -p /app/LigandMPNN/model_params && \
    for weights in \
        proteinmpnn_v_48_002.pt \
        proteinmpnn_v_48_010.pt \
        proteinmpnn_v_48_020.pt \
        proteinmpnn_v_48_030.pt \
        ligandmpnn_v_32_005_25.pt \
        ligandmpnn_v_32_010_25.pt \
        ligandmpnn_v_32_020_25.pt \
        ligandmpnn_v_32_030_25.pt \
        solublempnn_v_48_002.pt \
        solublempnn_v_48_010.pt \
        solublempnn_v_48_020.pt \
        solublempnn_v_48_030.pt \
        per_residue_label_membrane_mpnn_v_48_020.pt \
        global_label_membrane_mpnn_v_48_020.pt \
        ligandmpnn_sc_v_32_002_16.pt \
    ; do \
        wget -q "https://files.ipd.uw.edu/pub/ligandmpnn/${weights}" \
            -O "/app/LigandMPNN/model_params/${weights}" || exit 1; \
    done && \
    ls -la /app/LigandMPNN/model_params/

# Wrapper scripts (absolute paths). `exec` replaces the bash wrapper with
# Python so the process receives signals from the container runtime directly.
RUN printf '#!/bin/bash\nexec python /app/LigandMPNN/run.py "$@"\n' > /usr/local/bin/ligandmpnn && \
    chmod +x /usr/local/bin/ligandmpnn && \
    printf '#!/bin/bash\nexec python /app/LigandMPNN/score.py "$@"\n' > /usr/local/bin/ligandmpnn-score && \
    chmod +x /usr/local/bin/ligandmpnn-score

# Conventional bind-mount locations for pipeline inputs/outputs.
RUN mkdir -p /app/inputs /app/outputs

# Runtime configuration: model weight directory and import path.
ENV LIGANDMPNN_MODEL_DIR=/app/LigandMPNN/model_params \
    PYTHONPATH=/app/LigandMPNN:${PYTHONPATH:-}

# NOTE(review): image runs as root; Nextflow usually overrides the container
# user at runtime, but consider a non-root USER for standalone use.

CMD ["python", "run.py", "--help"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Justas Dauparas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

636
README.md Normal file
View File

@@ -0,0 +1,636 @@
## LigandMPNN
This package provides inference code for [LigandMPNN](https://www.biorxiv.org/content/10.1101/2023.12.22.573103v1) & [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) models. The code and model parameters are available under the MIT license.
Third party code: side chain packing uses helper functions from [Openfold](https://github.com/aqlaboratory/openfold).
### Running the code
```
git clone https://github.com/dauparas/LigandMPNN.git
cd LigandMPNN
bash get_model_params.sh "./model_params"
#setup your conda/or other environment
#conda create -n ligandmpnn_env python=3.11
#pip3 install -r requirements.txt
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/default"
```
### Dependencies
To run the model you will need to have Python>=3.0, PyTorch, Numpy installed, and to read/write PDB files you will need [Prody](https://pypi.org/project/ProDy/).
For example to make a new conda environment for LigandMPNN run:
```
conda create -n ligandmpnn_env python=3.11
pip3 install -r requirements.txt
```
### Main differences compared with [ProteinMPNN](https://github.com/dauparas/ProteinMPNN) code
- Input PDBs are parsed using [Prody](https://pypi.org/project/ProDy/) preserving protein residue indices, chain letters, and insertion codes. If there are missing residues in the input structure the output fasta file won't have added `X` to fill the gaps. The script outputs .fasta and .pdb files. It's recommended to use .pdb files since they will hold information about chain letters and residue indices.
- Adding bias, fixing residues, and selecting residues to be redesigned now can be done using residue indices directly, e.g. A23 (means chain A residue with index 23), B42D (chain B, residue 42, insertion code D).
- Model writes to fasta files: `overall_confidence`, `ligand_confidence` which reflect the average confidence/probability (with T=1.0) over the redesigned residues `overall_confidence=exp[-mean_over_residues(log_probs)]`. Higher numbers mean the model is more confident about that sequence. min_value=0.0; max_value=1.0. Sequence recovery with respect to the input sequence is calculated only over the redesigned residues.
### Model parameters
To download model parameters run:
```
bash get_model_params.sh "./model_params"
```
### Available models
To run the model of your choice specify `--model_type` and optionally the model checkpoint path. Available models:
- ProteinMPNN
```
--model_type "protein_mpnn"
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_002.pt" #noised with 0.02A Gaussian noise
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_010.pt" #noised with 0.10A Gaussian noise
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_030.pt" #noised with 0.30A Gaussian noise
```
- LigandMPNN
```
--model_type "ligand_mpnn"
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" #noised with 0.05A Gaussian noise
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_010_25.pt" #noised with 0.10A Gaussian noise
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_020_25.pt" #noised with 0.20A Gaussian noise
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_030_25.pt" #noised with 0.30A Gaussian noise
```
- SolubleMPNN
```
--model_type "soluble_mpnn"
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_002.pt" #noised with 0.02A Gaussian noise
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_010.pt" #noised with 0.10A Gaussian noise
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_020.pt" #noised with 0.20A Gaussian noise
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_030.pt" #noised with 0.30A Gaussian noise
```
- ProteinMPNN with global membrane label
```
--model_type "global_label_membrane_mpnn"
--checkpoint_global_label_membrane_mpnn "./model_params/global_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
```
- ProteinMPNN with per residue membrane label
```
--model_type "per_residue_label_membrane_mpnn"
--checkpoint_per_residue_label_membrane_mpnn "./model_params/per_residue_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
```
- Side chain packing model
```
--checkpoint_path_sc "./model_params/ligandmpnn_sc_v_32_002_16.pt"
```
## Design examples
### 1 default
Default settings will run ProteinMPNN.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/default"
```
### 2 --temperature
`--temperature 0.05` Change sampling temperature (higher temperature gives more sequence diversity).
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--temperature 0.05 \
--out_folder "./outputs/temperature"
```
### 3 --seed
`--seed` Not selecting a seed will run with a random seed. Running this multiple times will give different results.
```
python run.py \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/random_seed"
```
### 4 --verbose
`--verbose 0` Do not print any statements.
```
python run.py \
--seed 111 \
--verbose 0 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/verbose"
```
### 5 --save_stats
`--save_stats 1` Save sequence design statistics.
```
#['generated_sequences', 'sampling_probs', 'log_probs', 'decoding_order', 'native_sequence', 'mask', 'chain_mask', 'seed', 'temperature']
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/save_stats" \
--save_stats 1
```
### 6 --fixed_residues
`--fixed_residues` Fixing specific amino acids. This example fixes the first 10 residues in chain C and adds global bias towards A (alanine). The output should have all alanines except the first 10 residues should be the same as in the input sequence since those are fixed.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/fix_residues" \
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
--bias_AA "A:10.0"
```
### 7 --redesigned_residues
`--redesigned_residues` Specifying which residues need to be designed. This example redesigns the first 10 residues while fixing everything else.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/redesign_residues" \
--redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
--bias_AA "A:10.0"
```
### 8 --number_of_batches
Design 15 sequences; with batch size 3 (can be 1 when using CPUs) and the number of batches 5.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/batch_size" \
--batch_size 3 \
--number_of_batches 5
```
### 9 --bias_AA
Global amino acid bias. In this example, output sequences are biased towards W, P, C and away from A.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
--out_folder "./outputs/global_bias"
```
### 10 --bias_AA_per_residue
Specify per residue amino acid bias, e.g. make residues C1, C3, C5, and C7 to be prolines.
```
# {
# "C1": {"G": -0.3, "C": -2.0, "P": 10.8},
# "C3": {"P": 10.0},
# "C5": {"G": -1.3, "P": 10.0},
# "C7": {"G": -1.3, "P": 10.0}
# }
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
--out_folder "./outputs/per_residue_bias"
```
### 11 --omit_AA
Global amino acid restrictions. This is equivalent to using `--bias_AA` and setting bias to be a large negative number. The output should be just made of E, K, A.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--omit_AA "CDFGHILMNPQRSTVWY" \
--out_folder "./outputs/global_omit"
```
### 12 --omit_AA_per_residue
Per residue amino acid restrictions.
```
# {
# "C1": "ACDEFGHIKLMNPQRSTVW",
# "C3": "ACDEFGHIKLMNPQRSTVW",
# "C5": "ACDEFGHIKLMNPQRSTVW",
# "C7": "ACDEFGHIKLMNPQRSTVW"
# }
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
--out_folder "./outputs/per_residue_omit"
```
### 13 --symmetry_residues
### 13 --symmetry_weights
Designing sequences with symmetry, e.g. homooligomer/2-state proteins, etc. In this example make C1=C2=C3, also C4=C5, and C6=C7.
```
#total_logits += symmetry_weights[t]*logits
#probs = torch.nn.functional.softmax((total_logits+bias_t) / temperature, dim=-1)
#total_logits_123 = 0.33*logits_1+0.33*logits_2+0.33*logits_3
#output should be ***ooxx
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/symmetry" \
--symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
--symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"
```
### 14 --homo_oligomer
Design homooligomer sequences. This automatically sets `--symmetry_residues` and `--symmetry_weights` assuming equal weighting from all chains.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/4GYT.pdb" \
--out_folder "./outputs/homooligomer" \
--homo_oligomer 1 \
--number_of_batches 2
```
### 15 --file_ending
Outputs will have a specified ending; e.g. `1BC8_xyz.fa` instead of `1BC8.fa`
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/file_ending" \
--file_ending "_xyz"
```
### 16 --zero_indexed
Zero indexed names in /backbones/1BC8_0.pdb, 1BC8_1.pdb, 1BC8_2.pdb etc
```
python run.py \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/zero_indexed" \
--zero_indexed 1 \
--number_of_batches 2
```
### 17 --chains_to_design
Specify which chains (e.g. "A,B,C") need to be redesigned, other chains will be kept fixed. Outputs in seqs/backbones will still have atoms/sequences for the whole input PDB.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/4GYT.pdb" \
--out_folder "./outputs/chains_to_design" \
--chains_to_design "A,B"
```
### 18 --parse_these_chains_only
Parse and design only specified chains (e.g. "A,B,C"). Outputs will have only specified chains.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/4GYT.pdb" \
--out_folder "./outputs/parse_these_chains_only" \
--parse_these_chains_only "A,B"
```
### 19 --model_type "ligand_mpnn"
Run LigandMPNN with default settings.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/ligandmpnn_default"
```
### 20 --checkpoint_ligand_mpnn
Run LigandMPNN using 0.05A model by specifying `--checkpoint_ligand_mpnn` flag.
```
python run.py \
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/ligandmpnn_v_32_005_25"
```
### 21 --ligand_mpnn_use_atom_context
Setting `--ligand_mpnn_use_atom_context 0` will mask all ligand atoms. This can be used to assess how much ligand atoms affect AA probabilities.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/ligandmpnn_no_context" \
--ligand_mpnn_use_atom_context 0
```
### 22 --ligand_mpnn_use_side_chain_context
Use fixed residue side chain atoms as extra ligand atoms.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
--ligand_mpnn_use_side_chain_context 1 \
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"
```
### 23 --model_type "soluble_mpnn"
Run SolubleMPNN (ProteinMPNN-like model with only soluble proteins in the training dataset).
```
python run.py \
--model_type "soluble_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/soluble_mpnn_default"
```
### 24 --model_type "global_label_membrane_mpnn"
Run global label membrane MPNN (trained with extra input - binary label soluble vs not) `--global_transmembrane_label #1 - membrane, 0 - soluble`.
```
python run.py \
--model_type "global_label_membrane_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/global_label_membrane_mpnn_0" \
--global_transmembrane_label 0
```
### 25 --model_type "per_residue_label_membrane_mpnn"
Run per residue label membrane MPNN (trained with extra input per residue specifying buried (hydrophobic), interface (polar), or other type residues; 3 classes).
```
python run.py \
--model_type "per_residue_label_membrane_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
--transmembrane_buried "C1 C2 C3 C11" \
--transmembrane_interface "C4 C5 C6 C22"
```
### 26 --fasta_seq_separation
Choose a symbol to put between different chains in fasta output format. It's recommended to use the PDB output format to deal with residue jumps and multiple chain parsing.
```
python run.py \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/fasta_seq_separation" \
--fasta_seq_separation ":"
```
### 27 --pdb_path_multi
Specify multiple PDB input paths. This is more efficient since the model needs to be loaded from the checkpoint once.
```
#{
#"./inputs/1BC8.pdb": "",
#"./inputs/4GYT.pdb": ""
#}
python run.py \
--pdb_path_multi "./inputs/pdb_ids.json" \
--out_folder "./outputs/pdb_path_multi" \
--seed 111
```
### 28 --fixed_residues_multi
Specify fixed residues when using `--pdb_path_multi` flag.
```
#{
#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10 C22",
#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A11 A12 A13 B38"
#}
python run.py \
--pdb_path_multi "./inputs/pdb_ids.json" \
--fixed_residues_multi "./inputs/fix_residues_multi.json" \
--out_folder "./outputs/fixed_residues_multi" \
--seed 111
```
### 29 --redesigned_residues_multi
Specify which residues need to be redesigned when using `--pdb_path_multi` flag.
```
#{
#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10",
#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A12 A13 B38"
#}
python run.py \
--pdb_path_multi "./inputs/pdb_ids.json" \
--redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
--out_folder "./outputs/redesigned_residues_multi" \
--seed 111
```
### 30 --omit_AA_per_residue_multi
Specify which residues need to be omitted when using `--pdb_path_multi` flag.
```
#{
#"./inputs/1BC8.pdb": {"C1":"ACDEFGHILMNPQRSTVWY", "C2":"ACDEFGHILMNPQRSTVWY", "C3":"ACDEFGHILMNPQRSTVWY"},
#"./inputs/4GYT.pdb": {"A7":"ACDEFGHILMNPQRSTVWY", "A8":"ACDEFGHILMNPQRSTVWY"}
#}
python run.py \
--pdb_path_multi "./inputs/pdb_ids.json" \
--omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
--out_folder "./outputs/omit_AA_per_residue_multi" \
--seed 111
```
### 31 --bias_AA_per_residue_multi
Specify amino acid biases per residue when using `--pdb_path_multi` flag.
```
#{
#"./inputs/1BC8.pdb": {"C1":{"A":3.0, "P":-2.0}, "C2":{"W":10.0, "G":-0.43}},
#"./inputs/4GYT.pdb": {"A7":{"Y":5.0, "S":-2.0}, "A8":{"M":3.9, "G":-0.43}}
#}
python run.py \
--pdb_path_multi "./inputs/pdb_ids.json" \
--bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
--out_folder "./outputs/bias_AA_per_residue_multi" \
--seed 111
```
### 32 --ligand_mpnn_cutoff_for_score
This sets the cutoff distance in angstroms to select residues that are considered to be close to ligand atoms. This flag only affects the `num_ligand_res` and `ligand_confidence` in the output fasta files.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--ligand_mpnn_cutoff_for_score "6.0" \
--out_folder "./outputs/ligand_mpnn_cutoff_for_score"
```
### 33 specifying residues with insertion codes
You can specify a residue using chain_id + residue_number + insertion_code; e.g. redesign only residues B82, B82A, B82B, B82C.
```
python run.py \
--seed 111 \
--pdb_path "./inputs/2GFB.pdb" \
--out_folder "./outputs/insertion_code" \
--redesigned_residues "B82 B82A B82B B82C" \
--parse_these_chains_only "B"
```
### 34 parse atoms with zero occupancy
Parse atoms in the PDB files with zero occupancy too.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/parse_atoms_with_zero_occupancy" \
--parse_atoms_with_zero_occupancy 1
```
## Scoring examples
### Output dictionary
```
out_dict = {}
out_dict["logits"] - raw logits from the model
out_dict["probs"] - softmax(logits)
out_dict["log_probs"] - log_softmax(logits)
out_dict["decoding_order"] - decoding order used (logits will depend on the decoding order)
out_dict["native_sequence"] - parsed input sequence in integers
out_dict["mask"] - mask for missing residues (usually all ones)
out_dict["chain_mask"] - controls which residues are decoded first
out_dict["alphabet"] - amino acid alphabet used
out_dict["residue_names"] - dictionary to map integers to residue_names, e.g. {0: "C10", 1: "C11"}
out_dict["sequence"] - parsed input sequence in alphabet
out_dict["mean_of_probs"] - averaged over batch_size*number_of_batches probabilities, [protein_length, 21]
out_dict["std_of_probs"] - same as above, but std
```
### 1 autoregressive with sequence info
Get probabilities/scores for backbone-sequence pairs using autoregressive probabilities: p(AA_1|backbone), p(AA_2|backbone, AA_1) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--autoregressive_score 1\
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/autoregressive_score_w_seq" \
--use_sequence 1\
--batch_size 1 \
--number_of_batches 10
```
### 2 autoregressive with backbone info only
Get probabilities/scores for backbone using probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--autoregressive_score 1\
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/autoregressive_score_wo_seq" \
--use_sequence 0\
--batch_size 1 \
--number_of_batches 10
```
### 3 single amino acid score with sequence info
Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone, AA_{all except AA_1}), p(AA_2|backbone, AA_{all except AA_2}) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--single_aa_score 1\
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/single_aa_score_w_seq" \
--use_sequence 1\
--batch_size 1 \
--number_of_batches 10
```
### 4 single amino acid score with backbone info only
Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recommended to set number_of_batches to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--single_aa_score 1\
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/single_aa_score_wo_seq" \
--use_sequence 0\
--batch_size 1 \
--number_of_batches 10
```
## Side chain packing examples
### 1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
Design a new sequence using any of the available models and also pack side chains of the new sequence. Return only a single solution for the side chain packing.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_default_fast" \
--pack_side_chains 1 \
--number_of_packs_per_design 0 \
--pack_with_ligand_context 1
```
### 2 design a new sequence and pack side chains (return 4 side chain packing samples)
Same as above, but returns 4 independent samples for side chains. b-factor shows log prob density per chi angle group.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_default" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1
```
### 3 fix specific residues for sequence design and packing
This option will not repack side chains of the fixed residues, but use them as a context.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_fixed_residues" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1 \
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
--repack_everything 0
```
### 4 fix specific residues for sequence design but repack everything
This option repacks all the residues.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_fixed_residues_full_repack" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1 \
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
--repack_everything 1
```
### 5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
You can run side chain packing without taking into account context atoms like DNA atoms. This will most likely result in side chains clashing with context atoms, but it might be interesting to see how the model's uncertainty about side chain conformations changes when ligand atoms are present vs absent.
```
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_no_context" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 0
```
### Things to add
- Support for ProteinMPNN CA-only model.
- Examples for scoring sequences only.
- Side-chain packing scripts.
- TER
### Citing this work
If you use the code, please cite:
```
@article{dauparas2023atomic,
title={Atomic context-conditioned protein sequence design using LigandMPNN},
author={Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David},
journal={Biorxiv},
pages={2023--12},
year={2023},
publisher={Cold Spring Harbor Laboratory}
}
@article{dauparas2022robust,
title={Robust deep learning--based protein sequence design using ProteinMPNN},
author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
journal={Science},
volume={378},
number={6615},
pages={49--56},
year={2022},
publisher={American Association for the Advancement of Science}
}
```

988
data_utils.py Normal file
View File

@@ -0,0 +1,988 @@
from __future__ import print_function
import numpy as np
import torch
import torch.utils
from prody import *
confProDy(verbosity="none")
# Map one-letter residue codes to PDB three-letter residue names ("X" -> "UNK").
restype_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
}
# The model's 21-token amino-acid vocabulary: the 20 standard residues in
# alphabetical one-letter order, followed by "X" for unknown (token id 20).
_AA_ORDER = "ACDEFGHIKLMNPQRSTVWYX"
# One-letter code -> integer token id (A=0, C=1, ..., Y=19, X=20).
restype_str_to_int = {aa: idx for idx, aa in enumerate(_AA_ORDER)}
# Integer token id -> one-letter code (inverse of restype_str_to_int).
restype_int_to_str = {idx: aa for idx, aa in enumerate(_AA_ORDER)}
# Alphabet as a list, in token-id order.
alphabet = list(_AA_ORDER)
# Chemical element symbols ordered by atomic number (H = 1 ... 118).
# BUGFIX: upstream listed element 42 as "Mb"; the correct symbol is "Mo"
# (molybdenum), which does occur in real PDB ligands/cofactors. The fix keeps
# the list length and every other index unchanged.
# NOTE(review): the last four entries use pre-2016 placeholder names
# (Uut/Uup/Uus/Uuo, now Nh/Mc/Ts/Og); kept as-is to match upstream parsing.
element_list = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo",
]
# Symbols are stored uppercase for case-insensitive matching of PDB columns.
element_list = [item.upper() for item in element_list]
# element_dict = dict(zip(element_list, range(1,len(element_list))))
# Atomic number -> uppercase symbol. Because zip() pairs range(1, len) (117
# numbers) with 118 symbols, the final entry ("UUO", Z=118) is dropped;
# preserved as-is to keep upstream behavior.
element_dict_rev = dict(zip(range(1, len(element_list)), element_list))
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
    """
    Compute per-item sequence recovery (fraction of positions where the
    predicted residue matches the true residue) over the masked region.

    S : true sequence shape=[batch, length]
    S_pred : predicted sequence shape=[batch, length]
    mask : mask to compute average over the region shape=[batch, length]
    average : averaged sequence recovery shape=[batch]
    """
    match = S == S_pred
    # Add a small epsilon to the denominator (same convention as get_score)
    # so an all-zero mask yields 0.0 instead of NaN from 0/0.
    average = torch.sum(match * mask, dim=-1) / (torch.sum(mask, dim=-1) + 1e-8)
    return average
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
    """
    Compute categorical cross entropy of predicted log-probabilities against
    the true sequence, averaged over the masked region.

    S : true sequence shape=[batch, length]
    log_probs : predicted log probabilities shape=[batch, length, 21]
    mask : mask to compute average over the region shape=[batch, length]
    average_loss : averaged categorical cross entropy (CCE) [batch]
    loss_per_residue : per position CCE [batch, length]
    """
    # Select the log-probability assigned to the true residue at each
    # position via a one-hot mask; negating gives the per-position CCE.
    true_one_hot = torch.nn.functional.one_hot(S, 21)
    loss_per_residue = -(true_one_hot * log_probs).sum(-1)  # [B, L]
    masked_total = torch.sum(loss_per_residue * mask, dim=-1)
    denom = torch.sum(mask, dim=-1) + 1e-8  # epsilon guards all-zero masks
    average_loss = masked_total / denom
    return average_loss, loss_per_residue
def write_full_PDB(
    save_path: str,
    X: np.ndarray,
    X_m: np.ndarray,
    b_factors: np.ndarray,
    R_idx: np.ndarray,
    chain_letters: np.ndarray,
    S: np.ndarray,
    other_atoms=None,
    icodes=None,
    force_hetatm=False,
):
    """
    Write a protein structure (and optionally extra non-protein atoms) to a
    PDB file via prody.

    save_path : path where the PDB will be written to
    X : protein atom xyz coordinates shape=[length, 14, 3]
    X_m : protein atom mask shape=[length, 14]; only atoms with mask==1 are written
    b_factors: shape=[length, 14]
    R_idx: protein residue indices shape=[length]
    chain_letters: protein chain letters shape=[length]
    S : integer-encoded protein amino acid sequence shape=[length]
    other_atoms: other atoms parsed by prody; appended to the output if given
    icodes: a list of insertion codes for the PDB; e.g. antibody loops
    force_hetatm: if True, copy the "hetatm" flags from other_atoms so they
        are written as HETATM records
    """
    # 1-letter -> 3-letter residue name table ("X" -> "UNK").
    restype_1to3 = {
        "A": "ALA",
        "R": "ARG",
        "N": "ASN",
        "D": "ASP",
        "C": "CYS",
        "Q": "GLN",
        "E": "GLU",
        "G": "GLY",
        "H": "HIS",
        "I": "ILE",
        "L": "LEU",
        "K": "LYS",
        "M": "MET",
        "F": "PHE",
        "P": "PRO",
        "S": "SER",
        "T": "THR",
        "W": "TRP",
        "Y": "TYR",
        "V": "VAL",
        "X": "UNK",
    }
    # Integer class -> 1-letter residue (inverse of restype_STRtoINT).
    restype_INTtoSTR = {
        0: "A",
        1: "C",
        2: "D",
        3: "E",
        4: "F",
        5: "G",
        6: "H",
        7: "I",
        8: "K",
        9: "L",
        10: "M",
        11: "N",
        12: "P",
        13: "Q",
        14: "R",
        15: "S",
        16: "T",
        17: "V",
        18: "W",
        19: "Y",
        20: "X",
    }
    # Fixed atom14 layout: for each residue type, 14 atom-name slots padded
    # with "" — slot order must match the layout of X / X_m / b_factors.
    restype_name_to_atom14_names = {
        "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
        "ARG": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD",
            "NE",
            "CZ",
            "NH1",
            "NH2",
            "",
            "",
            "",
        ],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
        "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
        "GLN": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD",
            "OE1",
            "NE2",
            "",
            "",
            "",
            "",
            "",
        ],
        "GLU": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD",
            "OE1",
            "OE2",
            "",
            "",
            "",
            "",
            "",
        ],
        "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
        "HIS": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "ND1",
            "CD2",
            "CE1",
            "NE2",
            "",
            "",
            "",
            "",
        ],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
        "PHE": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD1",
            "CD2",
            "CE1",
            "CE2",
            "CZ",
            "",
            "",
            "",
        ],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
        "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
        "TRP": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD1",
            "CD2",
            "CE2",
            "CE3",
            "NE1",
            "CZ2",
            "CZ3",
            "CH2",
        ],
        "TYR": [
            "N",
            "CA",
            "C",
            "O",
            "CB",
            "CG",
            "CD1",
            "CD2",
            "CE1",
            "CE2",
            "CZ",
            "OH",
            "",
            "",
        ],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
        "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
    }
    # Integer sequence -> 3-letter residue names.
    S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]]
    X_list = []
    b_factor_list = []
    atom_name_list = []
    element_name_list = []
    residue_name_list = []
    residue_number_list = []
    chain_id_list = []
    icodes_list = []
    # Flatten the per-residue atom14 representation to per-atom records,
    # keeping only atoms present in the mask.
    for i, AA in enumerate(S_str):
        sel = X_m[i].astype(np.int32) == 1
        total = np.sum(sel)  # number of present atoms in this residue
        tmp = np.array(restype_name_to_atom14_names[AA])[sel]
        X_list.append(X[i][sel])
        b_factor_list.append(b_factors[i][sel])
        atom_name_list.append(tmp)
        # Element symbol = first character of the atom name (N/C/O/S — valid
        # for standard protein atoms). The comprehension's AA does not leak:
        # Python 3 comprehensions have their own scope.
        element_name_list += [AA[:1] for AA in list(tmp)]
        # `total * [x]` replicates the per-residue value once per atom.
        residue_name_list += total * [AA]
        residue_number_list += total * [R_idx[i]]
        chain_id_list += total * [chain_letters[i]]
        icodes_list += total * [icodes[i]]
    X_stack = np.concatenate(X_list, 0)
    b_factor_stack = np.concatenate(b_factor_list, 0)
    atom_name_stack = np.concatenate(atom_name_list, 0)
    # Build the prody AtomGroup for the designed protein.
    protein = prody.AtomGroup()
    protein.setCoords(X_stack)
    protein.setBetas(b_factor_stack)
    protein.setNames(atom_name_stack)
    protein.setResnames(residue_name_list)
    protein.setElements(element_name_list)
    protein.setOccupancies(np.ones([X_stack.shape[0]]))
    protein.setResnums(residue_number_list)
    protein.setChids(chain_id_list)
    protein.setIcodes(icodes_list)
    if other_atoms:
        # Copy the context atoms (ligands etc.) into a fresh group so the two
        # groups can be concatenated and written together.
        other_atoms_g = prody.AtomGroup()
        other_atoms_g.setCoords(other_atoms.getCoords())
        other_atoms_g.setNames(other_atoms.getNames())
        other_atoms_g.setResnames(other_atoms.getResnames())
        other_atoms_g.setElements(other_atoms.getElements())
        other_atoms_g.setOccupancies(other_atoms.getOccupancies())
        other_atoms_g.setResnums(other_atoms.getResnums())
        other_atoms_g.setChids(other_atoms.getChids())
        if force_hetatm:
            # Preserve HETATM record flags from the parsed input.
            other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
        writePDB(save_path, protein + other_atoms_g)
    else:
        writePDB(save_path, protein)
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
    """
    Collect coordinates for one atom type, aligned to the CA residue order.

    protein_atoms: prody atom group
    CA_dict: mapping between chain_residue_idx_icodes and integers
    atom_name: atom to be parsed; e.g. CA

    Returns:
        atom_coords_: [len(CA_dict), 3] float32 coordinates; zeros where the
            atom is missing
        atom_coords_m: [len(CA_dict)] int32 mask; 1 where the atom was found
    """
    atom_coords_ = np.zeros([len(CA_dict), 3], np.float32)
    atom_coords_m = np.zeros([len(CA_dict)], np.int32)
    # prody returns None when the selection matches nothing; in that case the
    # all-zero arrays are returned as-is (single guard instead of two).
    atom_atoms = protein_atoms.select(f"name {atom_name}")
    if atom_atoms is not None:
        atom_coords = atom_atoms.getCoords()
        atom_resnums = atom_atoms.getResnums()
        atom_chain_ids = atom_atoms.getChids()
        atom_icodes = atom_atoms.getIcodes()
        for i in range(len(atom_resnums)):
            # Key format must match the one built for CA_dict in parse_PDB:
            # "<chain>_<resnum>_<icode>".
            code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i]
            # O(1) dict membership instead of materializing the key list.
            if code in CA_dict:
                atom_coords_[CA_dict[code], :] = atom_coords[i]
                atom_coords_m[CA_dict[code]] = 1
    return atom_coords_, atom_coords_m
def parse_PDB(
    input_path: str,
    device: str = "cpu",
    chains: list = [],
    parse_all_atoms: bool = False,
    parse_atoms_with_zero_occupancy: bool = False
):
    """
    Parse a PDB file into tensors for the MPNN models.

    input_path : path for the input PDB
    device: device for the torch.Tensor
    chains: a list specifying which chains need to be parsed; e.g. ["A", "B"]
        (mutable default is safe here: the list is only read, never mutated)
    parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms
    parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed

    Returns:
        output_dict: dict with tensors X [L,4,3], mask [L], ligand context
            Y/Y_t/Y_m, R_idx, chain_labels, chain_letters, per-chain boolean
            masks mask_c, chain_list, integer sequence S, xyz_37 [L,37,3] and
            xyz_37_m [L,37]
        backbone: prody selection of backbone atoms
        other_atoms: prody selection of non-protein, non-water atoms (may be None)
        CA_icodes: insertion codes aligned with the CA atoms
        water_atoms: prody selection of water atoms
    """
    # Element symbols in atomic-number order (duplicated from module level).
    # NOTE(review): "Mb" looks like a typo for "Mo" (molybdenum) but is kept:
    # the element->index encoding is baked into the trained weights.
    element_list = [
        "H",
        "He",
        "Li",
        "Be",
        "B",
        "C",
        "N",
        "O",
        "F",
        "Ne",
        "Na",
        "Mg",
        "Al",
        "Si",
        "P",
        "S",
        "Cl",
        "Ar",
        "K",
        "Ca",
        "Sc",
        "Ti",
        "V",
        "Cr",
        "Mn",
        "Fe",
        "Co",
        "Ni",
        "Cu",
        "Zn",
        "Ga",
        "Ge",
        "As",
        "Se",
        "Br",
        "Kr",
        "Rb",
        "Sr",
        "Y",
        "Zr",
        "Nb",
        "Mb",
        "Tc",
        "Ru",
        "Rh",
        "Pd",
        "Ag",
        "Cd",
        "In",
        "Sn",
        "Sb",
        "Te",
        "I",
        "Xe",
        "Cs",
        "Ba",
        "La",
        "Ce",
        "Pr",
        "Nd",
        "Pm",
        "Sm",
        "Eu",
        "Gd",
        "Tb",
        "Dy",
        "Ho",
        "Er",
        "Tm",
        "Yb",
        "Lu",
        "Hf",
        "Ta",
        "W",
        "Re",
        "Os",
        "Ir",
        "Pt",
        "Au",
        "Hg",
        "Tl",
        "Pb",
        "Bi",
        "Po",
        "At",
        "Rn",
        "Fr",
        "Ra",
        "Ac",
        "Th",
        "Pa",
        "U",
        "Np",
        "Pu",
        "Am",
        "Cm",
        "Bk",
        "Cf",
        "Es",
        "Fm",
        "Md",
        "No",
        "Lr",
        "Rf",
        "Db",
        "Sg",
        "Bh",
        "Hs",
        "Mt",
        "Ds",
        "Rg",
        "Cn",
        "Uut",
        "Fl",
        "Uup",
        "Lv",
        "Uus",
        "Uuo",
    ]
    element_list = [item.upper() for item in element_list]
    # 1-based mapping symbol -> atomic index ("H" -> 1).
    element_dict = dict(zip(element_list, range(1, len(element_list))))
    # 3-letter -> 1-letter residue names; anything missing becomes "X" below.
    restype_3to1 = {
        "ALA": "A",
        "ARG": "R",
        "ASN": "N",
        "ASP": "D",
        "CYS": "C",
        "GLN": "Q",
        "GLU": "E",
        "GLY": "G",
        "HIS": "H",
        "ILE": "I",
        "LEU": "L",
        "LYS": "K",
        "MET": "M",
        "PHE": "F",
        "PRO": "P",
        "SER": "S",
        "THR": "T",
        "TRP": "W",
        "TYR": "Y",
        "VAL": "V",
    }
    # 1-letter residue -> integer class (21 classes, "X" = unknown = 20).
    restype_STRtoINT = {
        "A": 0,
        "C": 1,
        "D": 2,
        "E": 3,
        "F": 4,
        "G": 5,
        "H": 6,
        "I": 7,
        "K": 8,
        "L": 9,
        "M": 10,
        "N": 11,
        "P": 12,
        "Q": 13,
        "R": 14,
        "S": 15,
        "T": 16,
        "V": 17,
        "W": 18,
        "Y": 19,
        "X": 20,
    }
    # Canonical atom37 slot index for each heavy-atom name.
    atom_order = {
        "N": 0,
        "CA": 1,
        "C": 2,
        "CB": 3,
        "O": 4,
        "CG": 5,
        "CG1": 6,
        "CG2": 7,
        "OG": 8,
        "OG1": 9,
        "SG": 10,
        "CD": 11,
        "CD1": 12,
        "CD2": 13,
        "ND1": 14,
        "ND2": 15,
        "OD1": 16,
        "OD2": 17,
        "SD": 18,
        "CE": 19,
        "CE1": 20,
        "CE2": 21,
        "CE3": 22,
        "NE": 23,
        "NE1": 24,
        "NE2": 25,
        "OE1": 26,
        "OE2": 27,
        "CH2": 28,
        "NH1": 29,
        "NH2": 30,
        "OH": 31,
        "CZ": 32,
        "CZ2": 33,
        "CZ3": 34,
        "NZ": 35,
        "OXT": 36,
    }
    # Either backbone-only or all side-chain heavy atoms (OXT is never parsed).
    if not parse_all_atoms:
        atom_types = ["N", "CA", "C", "O"]
    else:
        atom_types = [
            "N",
            "CA",
            "C",
            "CB",
            "O",
            "CG",
            "CG1",
            "CG2",
            "OG",
            "OG1",
            "SG",
            "CD",
            "CD1",
            "CD2",
            "ND1",
            "ND2",
            "OD1",
            "OD2",
            "SD",
            "CE",
            "CE1",
            "CE2",
            "CE3",
            "NE",
            "NE1",
            "NE2",
            "OE1",
            "OE2",
            "CH2",
            "NH1",
            "NH2",
            "OH",
            "CZ",
            "CZ2",
            "CZ3",
            "NZ",
        ]
    atoms = parsePDB(input_path)
    if not parse_atoms_with_zero_occupancy:
        # Drop zero-occupancy atoms (unresolved positions).
        atoms = atoms.select("occupancy > 0")
    if chains:
        # Build a prody selection like "chain A or chain B"; the slice strips
        # the leading space and the trailing " or".
        str_out = ""
        for item in chains:
            str_out += " chain " + item + " or"
        atoms = atoms.select(str_out[1:-3])
    protein_atoms = atoms.select("protein")
    backbone = protein_atoms.select("backbone")
    other_atoms = atoms.select("not protein and not water")
    water_atoms = atoms.select("water")
    # One CA per residue defines residue order; the key format
    # "<chain>_<resnum>_<icode>" must match get_aligned_coordinates.
    CA_atoms = protein_atoms.select("name CA")
    CA_resnums = CA_atoms.getResnums()
    CA_chain_ids = CA_atoms.getChids()
    CA_icodes = CA_atoms.getIcodes()
    CA_dict = {}
    for i in range(len(CA_resnums)):
        code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i]
        CA_dict[code] = i
    # Fill the atom37 coordinate/mask arrays one atom type at a time.
    xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
    xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
    for atom_name in atom_types:
        xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
        xyz_37[:, atom_order[atom_name], :] = xyz
        xyz_37_m[:, atom_order[atom_name]] = xyz_m
    N = xyz_37[:, atom_order["N"], :]
    CA = xyz_37[:, atom_order["CA"], :]
    C = xyz_37[:, atom_order["C"], :]
    O = xyz_37[:, atom_order["O"], :]
    N_m = xyz_37_m[:, atom_order["N"]]
    CA_m = xyz_37_m[:, atom_order["CA"]]
    C_m = xyz_37_m[:, atom_order["C"]]
    O_m = xyz_37_m[:, atom_order["O"]]
    mask = N_m * CA_m * C_m * O_m  # must all 4 atoms exist
    # Virtual CB built from the backbone N/CA/C with fixed ideal-geometry
    # coefficients (computed here but only CB from featurize is consumed —
    # NOTE(review): this local CB appears unused below; confirm).
    b = CA - N
    c = C - CA
    a = np.cross(b, c, axis=-1)
    CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
    chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
    R_idx = np.array(CA_resnums, dtype=np.int32)
    S = CA_atoms.getResnames()
    S = [restype_3to1[AA] if AA in list(restype_3to1) else "X" for AA in list(S)]
    S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32)
    X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1)
    try:
        # Ligand/context atoms: coordinates Y, element types Y_t, mask Y_m.
        Y = np.array(other_atoms.getCoords(), dtype=np.float32)
        Y_t = list(other_atoms.getElements())
        Y_t = np.array(
            [
                element_dict[y_t.upper()] if y_t.upper() in element_list else 0
                for y_t in Y_t
            ],
            dtype=np.int32,
        )
        # Exclude hydrogens (index 1) and unknown elements (index 0).
        Y_m = (Y_t != 1) * (Y_t != 0)
        Y = Y[Y_m, :]
        Y_t = Y_t[Y_m]
        Y_m = Y_m[Y_m]
    except:
        # NOTE(review): bare except — intended as a best-effort fallback when
        # other_atoms is None (no context atoms); a single placeholder atom
        # with mask 0 is used instead.
        Y = np.zeros([1, 3], np.float32)
        Y_t = np.zeros([1], np.int32)
        Y_m = np.zeros([1], np.int32)
    output_dict = {}
    output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
    output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
    output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
    output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
    output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)
    output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
    output_dict["chain_labels"] = torch.tensor(
        chain_labels, device=device, dtype=torch.int32
    )
    output_dict["chain_letters"] = CA_chain_ids
    # Per-chain boolean masks, in sorted chain-letter order.
    mask_c = []
    chain_list = list(set(output_dict["chain_letters"]))
    chain_list.sort()
    for chain in chain_list:
        mask_c.append(
            torch.tensor(
                [chain == item for item in output_dict["chain_letters"]],
                device=device,
                dtype=bool,
            )
        )
    output_dict["mask_c"] = mask_c
    output_dict["chain_list"] = chain_list
    output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)
    output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
    output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)
    return output_dict, backbone, other_atoms, CA_icodes, water_atoms
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
    """
    For every residue pick the `number_of_ligand_atoms` closest context atoms
    from Y, zero-padding when fewer context atoms exist.

    CB : [A, 3] residue (virtual CB) coordinates
    mask : [A] residue validity mask
    Y : [B, 3] context atom coordinates
    Y_t : [B] context atom element types
    Y_m : [B] context atom mask
    Returns per-residue tensors Y [A, K, 3], Y_t [A, K], Y_m [A, K] and the
    distance from each residue to its closest context atom [A].
    """
    device = CB.device
    num_residues = CB.shape[0]
    # A residue/atom pair is valid only when both sides are real.
    pair_mask = mask[:, None] * Y_m[None, :]  # [A, B]
    sq_dist = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1)
    # Push invalid pairs to a large distance so they sort to the back.
    sq_dist = sq_dist * pair_mask + (1 - pair_mask) * 1000.0
    nn_idx = torch.argsort(sq_dist, -1)[:, :number_of_ligand_atoms]
    # Euclidean distance to the single nearest context atom per residue.
    D_closest = torch.sqrt(torch.gather(sq_dist, 1, nn_idx)[:, 0])
    gathered_Y = torch.gather(
        Y[None, :, :].repeat(num_residues, 1, 1),
        1,
        nn_idx[:, :, None].repeat(1, 1, 3),
    )
    gathered_Y_t = torch.gather(Y_t[None, :].repeat(num_residues, 1), 1, nn_idx)
    gathered_Y_m = torch.gather(Y_m[None, :].repeat(num_residues, 1), 1, nn_idx)
    # Fixed-size outputs; the tail stays zero-padded when fewer than K
    # context atoms are available.
    Y_out = torch.zeros(
        [num_residues, number_of_ligand_atoms, 3], dtype=torch.float32, device=device
    )
    Y_t_out = torch.zeros(
        [num_residues, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    Y_m_out = torch.zeros(
        [num_residues, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    k = gathered_Y.shape[1]
    Y_out[:, :k] = gathered_Y
    Y_t_out[:, :k] = gathered_Y_t
    Y_m_out[:, :k] = gathered_Y_m
    return Y_out, Y_t_out, Y_m_out, D_closest
def featurize(
    input_dict,
    cutoff_for_score=8.0,
    use_atom_context=True,
    number_of_ligand_atoms=16,
    model_type="protein_mpnn",
):
    """
    Add a batch dimension to the parsed tensors and, for ligand_mpnn, gather
    the per-residue ligand atom context.

    input_dict : dict produced by parse_PDB (plus chain_mask etc.)
    cutoff_for_score : distance cutoff for the residue/ligand proximity mask
    use_atom_context : if False, zero out the ligand atom mask
    number_of_ligand_atoms : context atoms gathered per residue
    model_type : which MPNN variant the features are prepared for
    """
    feature_dict = {}
    if model_type == "ligand_mpnn":
        mask = input_dict["mask"]
        Y = input_dict["Y"]
        Y_t = input_dict["Y_t"]
        Y_m = input_dict["Y_m"]
        backbone = input_dict["X"]
        N = backbone[:, 0, :]
        CA = backbone[:, 1, :]
        C = backbone[:, 2, :]
        # Virtual CB from ideal backbone geometry.
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, axis=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
            CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms
        )
        # Residues whose closest context atom lies within the cutoff.
        mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0]
        feature_dict["mask_XY"] = mask_XY[None,]
        if "side_chain_mask" in input_dict:
            feature_dict["side_chain_mask"] = input_dict["side_chain_mask"][None,]
        feature_dict["Y"] = Y[None,]
        feature_dict["Y_t"] = Y_t[None,]
        feature_dict["Y_m"] = Y_m[None,]
        if not use_atom_context:
            feature_dict["Y_m"] = 0.0 * feature_dict["Y_m"]
    elif model_type in (
        "per_residue_label_membrane_mpnn",
        "global_label_membrane_mpnn",
    ):
        feature_dict["membrane_per_residue_labels"] = input_dict[
            "membrane_per_residue_labels"
        ][None,]
    # Renumber residue indices so repeated numbers (e.g. insertion codes)
    # become strictly increasing; every later residue shifts by the running
    # duplicate count.
    renumbered = []
    count = 0
    R_idx_prev = -100000
    for R_idx in list(input_dict["R_idx"]):
        if R_idx_prev == R_idx:
            count += 1
        renumbered.append(R_idx + count)
        R_idx_prev = R_idx
    feature_dict["R_idx"] = torch.tensor(renumbered, device=R_idx.device)[None,]
    feature_dict["R_idx_original"] = input_dict["R_idx"][None,]
    # Batch the remaining per-residue tensors as-is.
    for key in ("chain_labels", "S", "chain_mask", "mask", "X"):
        feature_dict[key] = input_dict[key][None,]
    if "xyz_37" in input_dict:
        feature_dict["xyz_37"] = input_dict["xyz_37"][None,]
        feature_dict["xyz_37_m"] = input_dict["xyz_37_m"][None,]
    return feature_dict

47
get_model_params.sh Normal file
View File

@@ -0,0 +1,47 @@
#!/bin/bash
# Download LigandMPNN / ProteinMPNN model weights into a target directory.
# e.g. bash get_model_params.sh "./model_params"
# Fail fast on errors (including a failed download) and on unset variables.
set -euo pipefail

# Require exactly one argument: the output directory.
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <output_directory>" >&2
    exit 1
fi

#make new directory for model parameters
mkdir -p "$1"
#Original ProteinMPNN weights
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_002.pt -O "$1/proteinmpnn_v_48_002.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_010.pt -O "$1/proteinmpnn_v_48_010.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt -O "$1/proteinmpnn_v_48_020.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_030.pt -O "$1/proteinmpnn_v_48_030.pt"
#ProteinMPNN with num_edges=32
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_002.pt -O "$1/proteinmpnn_v_32_002.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_010.pt -O "$1/proteinmpnn_v_32_010.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_020.pt -O "$1/proteinmpnn_v_32_020.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_030.pt -O "$1/proteinmpnn_v_32_030.pt"
#LigandMPNN with num_edges=32; atom_context_num=25
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_25.pt -O "$1/ligandmpnn_v_32_005_25.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt -O "$1/ligandmpnn_v_32_010_25.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_25.pt -O "$1/ligandmpnn_v_32_020_25.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt -O "$1/ligandmpnn_v_32_030_25.pt"
#LigandMPNN with num_edges=32; atom_context_num=16
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_16.pt -O "$1/ligandmpnn_v_32_005_16.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_16.pt -O "$1/ligandmpnn_v_32_010_16.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_16.pt -O "$1/ligandmpnn_v_32_020_16.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_16.pt -O "$1/ligandmpnn_v_32_030_16.pt"
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/publication_version_ligandmpnn_v_32_010_25.pt -O "$1/publication_version_ligandmpnn_v_32_010_25.pt"
#Per residue label membrane ProteinMPNN
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/per_residue_label_membrane_mpnn_v_48_020.pt -O "$1/per_residue_label_membrane_mpnn_v_48_020.pt"
#Global label membrane ProteinMPNN
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/global_label_membrane_mpnn_v_48_020.pt -O "$1/global_label_membrane_mpnn_v_48_020.pt"
#SolubleMPNN
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_002.pt -O "$1/solublempnn_v_48_002.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_010.pt -O "$1/solublempnn_v_48_010.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt -O "$1/solublempnn_v_48_020.pt"
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_030.pt -O "$1/solublempnn_v_48_030.pt"
#LigandMPNN for side-chain packing (multi-step denoising model)
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_sc_v_32_002_16.pt -O "$1/ligandmpnn_sc_v_32_002_16.pt"

67
main.nf Normal file
View File

@@ -0,0 +1,67 @@
#!/usr/bin/env nextflow
// LigandMPNN protein sequence design pipeline: runs one containerized
// LigandMPNN job per input PDB file.
nextflow.enable.dsl=2

// Default parameters; override on the CLI (--pdb, --outdir, ...) or via
// nextflow.config / -params-file.
params.pdb = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb'
params.outdir = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/output'
params.model_type = 'ligand_mpnn'
params.temperature = 0.1
params.seed = 111
params.batch_size = 1
params.number_of_batches = 1
params.chains_to_design = ''
params.fixed_residues = ''
params.pack_side_chains = 0

process LIGANDMPNN {
    // NOTE(review): assumes a `ligandmpnn` executable exists on PATH inside
    // the image and that /mnt is available on the host — confirm against the
    // Dockerfile / deployment host.
    container 'ligandmpnn:latest'
    containerOptions '--rm --gpus all -v /mnt:/mnt'
    publishDir params.outdir, mode: 'copy'
    stageInMode 'copy'

    input:
    path pdb

    output:
    path "${pdb.simpleName}/seqs/*.fa"
    path "${pdb.simpleName}/backbones/*.pdb"
    // Packed side-chain PDBs only exist when pack_side_chains is enabled.
    path "${pdb.simpleName}/packed/*.pdb", optional: true
    path "run.log"

    script:
    // Set checkpoint path based on model type
    def checkpoint_arg = ''
    if (params.model_type == 'ligand_mpnn') {
        checkpoint_arg = '--checkpoint_ligand_mpnn /app/LigandMPNN/model_params/ligandmpnn_v_32_010_25.pt'
    } else if (params.model_type == 'protein_mpnn') {
        checkpoint_arg = '--checkpoint_protein_mpnn /app/LigandMPNN/model_params/proteinmpnn_v_48_020.pt'
    } else if (params.model_type == 'soluble_mpnn') {
        checkpoint_arg = '--checkpoint_soluble_mpnn /app/LigandMPNN/model_params/solublempnn_v_48_020.pt'
    } else if (params.model_type == 'global_label_membrane_mpnn') {
        checkpoint_arg = '--checkpoint_global_label_membrane_mpnn /app/LigandMPNN/model_params/global_label_membrane_mpnn_v_48_020.pt'
    } else if (params.model_type == 'per_residue_label_membrane_mpnn') {
        checkpoint_arg = '--checkpoint_per_residue_label_membrane_mpnn /app/LigandMPNN/model_params/per_residue_label_membrane_mpnn_v_48_020.pt'
    }
    """
    mkdir -p ${pdb.simpleName}/seqs ${pdb.simpleName}/backbones ${pdb.simpleName}/packed
    ligandmpnn \\
        --pdb_path \$PWD/${pdb} \\
        --out_folder \$PWD/${pdb.simpleName} \\
        --model_type ${params.model_type} \\
        ${checkpoint_arg} \\
        --temperature ${params.temperature} \\
        --seed ${params.seed} \\
        --batch_size ${params.batch_size} \\
        --number_of_batches ${params.number_of_batches} \\
        --pack_side_chains ${params.pack_side_chains} \\
        ${params.chains_to_design ? "--chains_to_design ${params.chains_to_design}" : ''} \\
        ${params.fixed_residues ? "--fixed_residues \"${params.fixed_residues}\"" : ''} \\
        2>&1 | tee run.log
    """
}

// Entry workflow: fan out over the (possibly glob) pdb path.
workflow {
    LIGANDMPNN(Channel.fromPath(params.pdb))
}

1772
model_utils.py Normal file

File diff suppressed because it is too large Load Diff

42
nextflow.config Normal file
View File

@@ -0,0 +1,42 @@
// Manifest for Nextflow metadata
manifest {
    name = 'LigandMPNN-Nextflow'
    author = 'Generated from LigandMPNN repository'
    homePage = 'https://github.com/dauparas/LigandMPNN'
    description = 'Nextflow pipeline for LigandMPNN - Protein sequence design with ligand context'
    mainScript = 'main.nf'
    version = '1.0.0'
}

// Global default parameters (must stay consistent with main.nf defaults)
params {
    pdb = "/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb"
    outdir = "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output"
    model_type = "ligand_mpnn"
    temperature = 0.1
    seed = 111
    batch_size = 1
    number_of_batches = 1
    chains_to_design = ""
    fixed_residues = ""
    // Added for consistency: main.nf passes --pack_side_chains; 0 disables
    // side-chain packing by default.
    pack_side_chains = 0
}

// Container configurations
docker {
    enabled = true
    runOptions = '--gpus all'
}

// Process configurations
process {
    cpus = 4
    memory = '16 GB'
}

// Execution configurations
executor {
    $local {
        cpus = 8
        memory = '32 GB'
    }
}

131
params.json Normal file
View File

@@ -0,0 +1,131 @@
{
"params": {
"pdb": {
"type": "file",
"description": "Path to input PDB file for protein sequence design",
"default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb",
"required": true,
"pipeline_io": "input",
"var_name": "params.pdb",
"examples": [
"/mnt/workflow/input/protein.pdb",
"/mnt/workflow/input/*.pdb"
],
"pattern": ".*\\.pdb$",
"enum": [],
"validation": {},
"notes": "Input PDB file containing the protein structure for sequence design."
},
"outdir": {
"type": "folder",
"description": "Directory for LigandMPNN output results",
"default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output",
"required": true,
"pipeline_io": "output",
"var_name": "params.outdir",
"examples": [
"/mnt/workflow/output",
"/path/to/results"
],
"pattern": ".*",
"enum": [],
"validation": {},
"notes": "Directory where designed sequences and backbone PDBs will be saved."
},
"model_type": {
"type": "string",
"description": "Type of MPNN model to use for sequence design",
"default": "ligand_mpnn",
"required": false,
"pipeline_io": "parameter",
"var_name": "params.model_type",
"examples": [
"protein_mpnn",
"ligand_mpnn",
"soluble_mpnn"
],
"pattern": "^(protein_mpnn|ligand_mpnn|soluble_mpnn|global_label_membrane_mpnn|per_residue_label_membrane_mpnn)$",
"enum": ["protein_mpnn", "ligand_mpnn", "soluble_mpnn", "global_label_membrane_mpnn", "per_residue_label_membrane_mpnn"],
"validation": {},
"notes": "protein_mpnn: Original ProteinMPNN. ligand_mpnn: Context-aware with ligands. soluble_mpnn: Trained on soluble proteins."
},
"temperature": {
"type": "number",
"description": "Sampling temperature for sequence generation",
"default": 0.1,
"required": false,
"pipeline_io": "parameter",
"var_name": "params.temperature",
"examples": [0.05, 0.1, 0.2],
"pattern": null,
"enum": [],
"validation": {},
"notes": "Higher temperature gives more sequence diversity. Recommended range: 0.05-0.5"
},
"seed": {
"type": "integer",
"description": "Random seed for reproducibility",
"default": 111,
"required": false,
"pipeline_io": "parameter",
"var_name": "params.seed",
"examples": [111, 42, 12345],
"pattern": null,
"enum": [],
"validation": {},
"notes": "Set for reproducible results."
},
"batch_size": {
"type": "integer",
"description": "Number of sequences to generate per batch",
"default": 1,
"required": false,
"pipeline_io": "parameter",
"var_name": "params.batch_size",
"examples": [1, 3, 5],
"pattern": null,
"enum": [],
"validation": {},
"notes": "Higher batch sizes require more GPU memory."
},
"number_of_batches": {
"type": "integer",
"description": "Number of batches to run",
"default": 1,
"required": false,
"pipeline_io": "parameter",
"var_name": "params.number_of_batches",
"examples": [1, 5, 10],
"pattern": null,
"enum": [],
"validation": {},
"notes": "Total sequences = batch_size × number_of_batches"
},
"chains_to_design": {
"type": "string",
"description": "Comma-separated chain IDs to redesign",
"default": "",
"required": false,
"pipeline_io": "parameter",
"var_name": "params.chains_to_design",
"examples": ["A", "A,B", "A,B,C"],
"pattern": "^([A-Z],?)*$",
"enum": [],
"validation": {},
"notes": "Leave empty to design all chains."
},
"fixed_residues": {
"type": "string",
"description": "Space-separated list of residues to keep fixed",
"default": "",
"required": false,
"pipeline_io": "parameter",
"var_name": "params.fixed_residues",
"examples": ["A1 A2 A3", "A12 B25 B26"],
"pattern": null,
"enum": [],
"validation": {},
"notes": "Format: ChainResidue (e.g., A12). Leave empty to design all residues."
}
}
}

29
requirements.txt Normal file
View File

@@ -0,0 +1,29 @@
biopython==1.79
filelock==3.13.1
fsspec==2024.3.1
Jinja2==3.1.3
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.2.1
numpy==1.23.5
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
ProDy==2.4.1
pyparsing==3.1.1
scipy==1.12.0
sympy==1.12
torch==2.2.1
triton==2.2.0
typing_extensions==4.10.0
ml-collections==0.1.1
dm-tree==0.1.8

990
run.py Normal file
View File

@@ -0,0 +1,990 @@
import argparse
import copy
import json
import os.path
import random
import sys
import numpy as np
import torch
from data_utils import (
alphabet,
element_dict_rev,
featurize,
get_score,
get_seq_rec,
parse_PDB,
restype_1to3,
restype_int_to_str,
restype_str_to_int,
write_full_PDB,
)
from model_utils import ProteinMPNN
from prody import writePDB
from sc_utils import Packer, pack_side_chains
def main(args) -> None:
    """
    Inference function.

    Loads the requested (Ligand/Protein)MPNN checkpoint, parses one or more
    input PDB files, samples sequences for the designable positions, and
    writes FASTA + backbone PDB outputs (optionally packed side-chain PDBs
    and raw statistics tensors).
    """
    # --- reproducibility: use the provided seed, or draw a random one.
    # NOTE(review): `if args.seed:` treats an explicit `--seed 0` the same as
    # "no seed given" (a random seed is drawn) — confirm this is intended.
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    # --- output folder layout: <out>/seqs, <out>/backbones, <out>/packed (+ <out>/stats) ---
    folder_for_outputs = args.out_folder
    base_folder = folder_for_outputs
    if base_folder[-1] != "/":
        base_folder = base_folder + "/"
    if not os.path.exists(base_folder):
        os.makedirs(base_folder, exist_ok=True)
    if not os.path.exists(base_folder + "seqs"):
        os.makedirs(base_folder + "seqs", exist_ok=True)
    if not os.path.exists(base_folder + "backbones"):
        os.makedirs(base_folder + "backbones", exist_ok=True)
    if not os.path.exists(base_folder + "packed"):
        os.makedirs(base_folder + "packed", exist_ok=True)
    if args.save_stats:
        if not os.path.exists(base_folder + "stats"):
            os.makedirs(base_folder + "stats", exist_ok=True)
    # --- select the checkpoint file matching the requested model type ---
    if args.model_type == "protein_mpnn":
        checkpoint_path = args.checkpoint_protein_mpnn
    elif args.model_type == "ligand_mpnn":
        checkpoint_path = args.checkpoint_ligand_mpnn
    elif args.model_type == "per_residue_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
    elif args.model_type == "global_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_global_label_membrane_mpnn
    elif args.model_type == "soluble_mpnn":
        checkpoint_path = args.checkpoint_soluble_mpnn
    else:
        print("Choose one of the available models")
        sys.exit()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Only ligand_mpnn uses a real atom context; the other model types get a
    # dummy context of size 1 and never use side-chain context atoms.
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
        k_neighbors = checkpoint["num_edges"]
    else:
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
        k_neighbors = checkpoint["num_edges"]
    # --- build the sequence-design model and load weights ---
    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    # --- optional side-chain packing model (only loaded when requested) ---
    if args.pack_side_chains:
        model_sc = Packer(
            node_features=128,
            edge_features=128,
            num_positional_embeddings=16,
            num_chain_embeddings=16,
            num_rbf=16,
            hidden_dim=128,
            num_encoder_layers=3,
            num_decoder_layers=3,
            atom_context_num=16,
            lower_bound=0.0,
            upper_bound=20.0,
            top_k=32,
            dropout=0.0,
            augment_eps=0.0,
            atom37_order=False,
            device=device,
            num_mix=3,
        )
        checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device)
        model_sc.load_state_dict(checkpoint_sc["model_state_dict"])
        model_sc.to(device)
        model_sc.eval()
    # --- collect input PDB paths: either a JSON listing or a single path ---
    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]
    # Per-PDB fixed residues: from a JSON mapping, or the same CLI list for every PDB.
    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
        fixed_residues_multi = {key:value.split() for key,value in fixed_residues_multi.items()}
    else:
        fixed_residues = [item for item in args.fixed_residues.split()]
        fixed_residues_multi = {}
        for pdb in pdb_paths:
            fixed_residues_multi[pdb] = fixed_residues
    # Per-PDB redesigned residues, same two sources as above.
    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
        redesigned_residues_multi = {key:value.split() for key,value in redesigned_residues_multi.items()}
    else:
        redesigned_residues = [item for item in args.redesigned_residues.split()]
        redesigned_residues_multi = {}
        for pdb in pdb_paths:
            redesigned_residues_multi[pdb] = redesigned_residues
    # --- global amino-acid bias vector, parsed from "A:-1.0,P:2.3"-style input ---
    bias_AA = torch.zeros([21], device=device, dtype=torch.float32)
    if args.bias_AA:
        tmp = [item.split(":") for item in args.bias_AA.split(",")]
        a1 = [b[0] for b in tmp]
        a2 = [float(b[1]) for b in tmp]
        for i, AA in enumerate(a1):
            bias_AA[restype_str_to_int[AA]] = a2[i]
    # Per-residue bias: multi-PDB JSON takes precedence over the single-PDB JSON.
    # NOTE(review): if neither flag is given, bias_AA_per_residue_multi stays
    # undefined; the guarded use below makes that safe.
    if args.bias_AA_per_residue_multi:
        with open(args.bias_AA_per_residue_multi, "r") as fh:
            bias_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": {"G": 1.1}}}
    else:
        if args.bias_AA_per_residue:
            with open(args.bias_AA_per_residue, "r") as fh:
                bias_AA_per_residue = json.load(fh)  # {"A12": {"G": 1.1}}
            bias_AA_per_residue_multi = {}
            for pdb in pdb_paths:
                bias_AA_per_residue_multi[pdb] = bias_AA_per_residue
    # Per-residue omitted amino acids, same precedence scheme as the bias above.
    if args.omit_AA_per_residue_multi:
        with open(args.omit_AA_per_residue_multi, "r") as fh:
            omit_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": "PQR", "A13": "QS"}}
    else:
        if args.omit_AA_per_residue:
            with open(args.omit_AA_per_residue, "r") as fh:
                omit_AA_per_residue = json.load(fh)  # {"A12": "PG"}
            omit_AA_per_residue_multi = {}
            for pdb in pdb_paths:
                omit_AA_per_residue_multi[pdb] = omit_AA_per_residue
    # Global omit mask over the 21-letter alphabet (1.0 = never sample this type).
    omit_AA_list = args.omit_AA
    omit_AA = torch.tensor(
        np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32),
        device=device,
    )
    if len(args.parse_these_chains_only) != 0:
        parse_these_chains_only_list = args.parse_these_chains_only.split(",")
    else:
        parse_these_chains_only_list = []
    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        # Side-chain atoms are needed either as ligand context or to keep
        # fixed residues' rotamers during packing.
        parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or (
            args.pack_side_chains and not args.repack_everything
        )
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=parse_these_chains_only_list,
            parse_all_atoms=parse_all_atoms_flag,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = []
        for i, R_idx_item in enumerate(R_idx_list):
            tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            encoded_residues.append(tmp)
        # "A12" -> 0-based position, and the reverse mapping for reporting.
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(
            zip(list(range(len(encoded_residues))), encoded_residues)
        )
        # Dense per-residue bias matrix [num_residues, 21] built from the JSON input.
        bias_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.bias_AA_per_residue_multi or args.bias_AA_per_residue:
            bias_dict = bias_AA_per_residue_multi[pdb]
            for residue_name, v1 in bias_dict.items():
                if residue_name in encoded_residues:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid, v2 in v1.items():
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            bias_AA_per_residue[i1, j1] = v2
        # Dense per-residue omit matrix [num_residues, 21]; 1.0 marks forbidden types.
        omit_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.omit_AA_per_residue_multi or args.omit_AA_per_residue:
            omit_dict = omit_AA_per_residue_multi[pdb]
            for residue_name, v1 in omit_dict.items():
                if residue_name in encoded_residues:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid in v1:
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            omit_AA_per_residue[i1, j1] = 1.0
        # NB: these masks are 1 for residues NOT in the respective list
        # (i.e. fixed_positions==1 means "free to design").
        fixed_positions = torch.tensor(
            [int(item not in fixed_residues) for item in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residues) for item in encoded_residues],
            device=device,
        )
        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = [item for item in args.transmembrane_buried.split()]
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)
        if args.transmembrane_interface:
            interface_residues = [item for item in args.transmembrane_interface.split()]
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # Label encoding: 2 = buried-only, 1 = interface-only, 0 = neither
        # (a residue flagged as both ends up 0).
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)
        if args.model_type == "global_label_membrane_mpnn":
            # Broadcast the single global label to every residue.
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        if len(args.chains_to_design) != 0:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )
        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        # redesigned_residues takes precedence over fixed_residues when both are given.
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask
        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)
        # specify which residues are linked
        # "A12,A13|B1,B2" -> [[idx(A12), idx(A13)], [idx(B1), idx(B2)]]
        if args.symmetry_residues:
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = []
            for t_list in symmetry_residues_list_of_lists:
                tmp_list = []
                for t in t_list:
                    tmp_list.append(encoded_residue_dict[t])
                remapped_symmetry_residues.append(tmp_list)
        else:
            remapped_symmetry_residues = [[]]
        # specify linking weights
        if args.symmetry_weights:
            symmetry_weights = [
                [float(item) for item in x.split(",")]
                for x in args.symmetry_weights.split("|")
            ]
        else:
            symmetry_weights = [[]]
        # Homo-oligomer mode overrides the manual symmetry inputs: every chain is
        # tied to the first chain with equal weights 1/num_chains.
        # NOTE(review): set() ordering is arbitrary, so the "reference chain"
        # is not deterministic across runs — confirm this is acceptable.
        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            chain_letters_set = list(set(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            symmetry_weights = []
            for res in residue_indices:
                tmp_list = []
                tmp_w_list = []
                for chain in chain_letters_set:
                    name = chain + res
                    tmp_list.append(encoded_residue_dict[name])
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)
                symmetry_weights.append(tmp_w_list)
        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)
        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]
        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                if "Y" in list(protein_dict):
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    print("No ligand atoms parsed")
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            B, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["temperature"] = args.temperature
            # Combined logit bias: -1e8 effectively forbids omitted residue types,
            # the additive terms apply the global and per-residue biases.
            feature_dict["bias"] = (
                (-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1])
                + bias_AA_per_residue[None]
                - 1e8 * omit_AA_per_residue[None]
            )
            feature_dict["symmetry_residues"] = remapped_symmetry_residues
            feature_dict["symmetry_weights"] = symmetry_weights
            # --- sampling loop: accumulate per-batch outputs and scores ---
            sampling_probs_list = []
            log_probs_list = []
            decoding_order_list = []
            S_list = []
            loss_list = []
            loss_per_residue_list = []
            loss_XY_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                output_dict = model.sample(feature_dict)
                # compute confidence scores
                loss, loss_per_residue = get_score(
                    output_dict["S"],
                    output_dict["log_probs"],
                    feature_dict["mask"] * feature_dict["chain_mask"],
                )
                if args.model_type == "ligand_mpnn":
                    # restrict the ligand score to residues near context atoms
                    combined_mask = (
                        feature_dict["mask"]
                        * feature_dict["mask_XY"]
                        * feature_dict["chain_mask"]
                    )
                else:
                    combined_mask = feature_dict["mask"] * feature_dict["chain_mask"]
                loss_XY, _ = get_score(
                    output_dict["S"], output_dict["log_probs"], combined_mask
                )
                # -----
                S_list.append(output_dict["S"])
                log_probs_list.append(output_dict["log_probs"])
                sampling_probs_list.append(output_dict["sampling_probs"])
                decoding_order_list.append(output_dict["decoding_order"])
                loss_list.append(loss)
                loss_per_residue_list.append(loss_per_residue)
                loss_XY_list.append(loss_XY)
            # Stack all batches along dim 0: total designs = batch_size * number_of_batches.
            S_stack = torch.cat(S_list, 0)
            log_probs_stack = torch.cat(log_probs_list, 0)
            sampling_probs_stack = torch.cat(sampling_probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)
            loss_stack = torch.cat(loss_list, 0)
            loss_per_residue_stack = torch.cat(loss_per_residue_list, 0)
            loss_XY_stack = torch.cat(loss_XY_list, 0)
            # Sequence recovery vs. the native sequence over designable residues.
            rec_mask = feature_dict["mask"][:1] * feature_dict["chain_mask"][:1]
            rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask)
            native_seq = "".join(
                [restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()]
            )
            # Split the native sequence per chain, joined with the separator symbol.
            seq_np = np.array(list(native_seq))
            seq_out_str = []
            for mask in protein_dict["mask_c"]:
                seq_out_str += list(seq_np[mask.cpu().numpy()])
                seq_out_str += [args.fasta_seq_separation]
            seq_out_str = "".join(seq_out_str)[:-1]
            output_fasta = base_folder + "/seqs/" + name + args.file_ending + ".fa"
            output_backbones = base_folder + "/backbones/"
            output_packed = base_folder + "/packed/"
            output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt"
            # Raw tensors saved when --save_stats 1 is passed.
            out_dict = {}
            out_dict["generated_sequences"] = S_stack.cpu()
            out_dict["sampling_probs"] = sampling_probs_stack.cpu()
            out_dict["log_probs"] = log_probs_stack.cpu()
            out_dict["decoding_order"] = decoding_order_stack.cpu()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu()
            out_dict["mask"] = feature_dict["mask"][0].cpu()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
            out_dict["seed"] = seed
            out_dict["temperature"] = args.temperature
            if args.save_stats:
                torch.save(out_dict, output_stats_path)
            # --- optional side-chain packing for every generated sequence ---
            if args.pack_side_chains:
                if args.verbose:
                    print("Packing side chains...")
                feature_dict_ = featurize(
                    protein_dict,
                    cutoff_for_score=8.0,
                    use_atom_context=args.pack_with_ligand_context,
                    number_of_ligand_atoms=16,
                    model_type="ligand_mpnn",
                )
                sc_feature_dict = copy.deepcopy(feature_dict_)
                B = args.batch_size
                # Tile every tensor feature (except the sequence S) to batch size B.
                # NOTE(review): the bare except silently skips non-tensor entries;
                # it would also hide genuine repeat() failures.
                for k, v in sc_feature_dict.items():
                    if k != "S":
                        try:
                            num_dim = len(v.shape)
                            if num_dim == 2:
                                sc_feature_dict[k] = v.repeat(B, 1)
                            elif num_dim == 3:
                                sc_feature_dict[k] = v.repeat(B, 1, 1)
                            elif num_dim == 4:
                                sc_feature_dict[k] = v.repeat(B, 1, 1, 1)
                            elif num_dim == 5:
                                sc_feature_dict[k] = v.repeat(B, 1, 1, 1, 1)
                        except:
                            pass
                X_stack_list = []
                X_m_stack_list = []
                b_factor_stack_list = []
                for _ in range(args.number_of_packs_per_design):
                    X_list = []
                    X_m_list = []
                    b_factor_list = []
                    for c in range(args.number_of_batches):
                        sc_feature_dict["S"] = S_list[c]
                        sc_dict = pack_side_chains(
                            sc_feature_dict,
                            model_sc,
                            args.sc_num_denoising_steps,
                            args.sc_num_samples,
                            args.repack_everything,
                        )
                        X_list.append(sc_dict["X"])
                        X_m_list.append(sc_dict["X_m"])
                        b_factor_list.append(sc_dict["b_factors"])
                    X_stack = torch.cat(X_list, 0)
                    X_m_stack = torch.cat(X_m_list, 0)
                    b_factor_stack = torch.cat(b_factor_list, 0)
                    X_stack_list.append(X_stack)
                    X_m_stack_list.append(X_m_stack)
                    b_factor_stack_list.append(b_factor_stack)
            # --- write FASTA (native sequence header first, then each design) ---
            with open(output_fasta, "w") as f:
                f.write(
                    ">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format(
                        name,
                        args.temperature,
                        seed,
                        torch.sum(rec_mask).cpu().numpy(),
                        torch.sum(combined_mask[:1]).cpu().numpy(),
                        bool(args.ligand_mpnn_use_atom_context),
                        float(args.ligand_mpnn_cutoff_for_score),
                        args.batch_size,
                        args.number_of_batches,
                        checkpoint_path,
                        seq_out_str,
                    )
                )
                for ix in range(S_stack.shape[0]):
                    # Output indices are 1-based unless --zero_indexed is set.
                    ix_suffix = ix
                    if not args.zero_indexed:
                        ix_suffix += 1
                    seq_rec_print = np.format_float_positional(
                        rec_stack[ix].cpu().numpy(), unique=False, precision=4
                    )
                    # exp(-loss) converts mean negative log-likelihood into a
                    # probability-like confidence in (0, 1].
                    loss_np = np.format_float_positional(
                        np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4
                    )
                    loss_XY_np = np.format_float_positional(
                        np.exp(-loss_XY_stack[ix].cpu().numpy()),
                        unique=False,
                        precision=4,
                    )
                    seq = "".join(
                        [restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()]
                    )
                    # write new sequences into PDB with backbone coordinates
                    # repeat(4, 1): one 3-letter name / b-factor per backbone atom (N, CA, C, O)
                    seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[
                        None,
                    ].repeat(4, 1)
                    bfactor_prody = (
                        loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1)
                    )
                    backbone.setResnames(seq_prody)
                    backbone.setBetas(
                        np.exp(-bfactor_prody)
                        * (bfactor_prody > 0.01).astype(np.float32)
                    )
                    if other_atoms:
                        writePDB(
                            output_backbones
                            + name
                            + "_"
                            + str(ix_suffix)
                            + args.file_ending
                            + ".pdb",
                            backbone + other_atoms,
                        )
                    else:
                        writePDB(
                            output_backbones
                            + name
                            + "_"
                            + str(ix_suffix)
                            + args.file_ending
                            + ".pdb",
                            backbone,
                        )
                    # write full PDB files
                    if args.pack_side_chains:
                        for c_pack in range(args.number_of_packs_per_design):
                            X_stack = X_stack_list[c_pack]
                            X_m_stack = X_m_stack_list[c_pack]
                            b_factor_stack = b_factor_stack_list[c_pack]
                            write_full_PDB(
                                output_packed
                                + name
                                + args.packed_suffix
                                + "_"
                                + str(ix_suffix)
                                + "_"
                                + str(c_pack + 1)
                                + args.file_ending
                                + ".pdb",
                                X_stack[ix].cpu().numpy(),
                                X_m_stack[ix].cpu().numpy(),
                                b_factor_stack[ix].cpu().numpy(),
                                feature_dict["R_idx_original"][0].cpu().numpy(),
                                protein_dict["chain_letters"],
                                S_stack[ix].cpu().numpy(),
                                other_atoms=other_atoms,
                                icodes=icodes,
                                force_hetatm=args.force_hetatm,
                            )
                    # -----
                    # write fasta lines
                    seq_np = np.array(list(seq))
                    seq_out_str = []
                    for mask in protein_dict["mask_c"]:
                        seq_out_str += list(seq_np[mask.cpu().numpy()])
                        seq_out_str += [args.fasta_seq_separation]
                    seq_out_str = "".join(seq_out_str)[:-1]
                    if ix == S_stack.shape[0] - 1:
                        # final 2 lines
                        # last record omits the trailing newline
                        f.write(
                            ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format(
                                name,
                                ix_suffix,
                                args.temperature,
                                seed,
                                loss_np,
                                loss_XY_np,
                                seq_rec_print,
                                seq_out_str,
                            )
                        )
                    else:
                        f.write(
                            ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}\n".format(
                                name,
                                ix_suffix,
                                args.temperature,
                                seed,
                                loss_np,
                                loss_XY_np,
                                seq_rec_print,
                                seq_out_str,
                            )
                        )
if __name__ == "__main__":
    # Command-line interface for the inference entry point `main`.
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # --- model selection and checkpoint paths ---
    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
    # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
    # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )
    # --- input/output options ---
    argparser.add_argument(
        "--fasta_seq_separation",
        type=str,
        default=":",
        help="Symbol to use between sequences from different chains",
    )
    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )
    # --- residue-level design control ---
    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )
    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )
    # --- amino-acid composition control (bias / omit) ---
    argparser.add_argument(
        "--bias_AA",
        type=str,
        default="",
        help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'",
    )
    argparser.add_argument(
        "--bias_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}",
    )
    argparser.add_argument(
        "--bias_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}",
    )
    argparser.add_argument(
        "--omit_AA",
        type=str,
        default="",
        help="Bias generation of amino acids, e.g. 'ACG'",
    )
    argparser.add_argument(
        "--omit_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of bias {'A12': 'APQ', 'A13': 'QST'}",
    )
    argparser.add_argument(
        "--omit_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of bias {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}",
    )
    # --- symmetry / oligomer design ---
    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )
    argparser.add_argument(
        "--symmetry_weights",
        type=str,
        default="",
        help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'",
    )
    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
    )
    argparser.add_argument(
        "--out_folder",
        type=str,
        help="Path to a folder to output sequences, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    argparser.add_argument(
        "--zero_indexed",
        # BUGFIX: was `type=str`, which made an explicit `--zero_indexed 0`
        # produce the truthy string "0" and invert the documented behavior
        # in `if not args.zero_indexed:` inside main().
        type=int,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        # NOTE: main() treats seed 0 as "draw a random seed".
        help="Set seed for torch, numpy, and python random.",
    )
    # --- sampling options ---
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )
    argparser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Temperature to sample sequences.",
    )
    argparser.add_argument(
        "--save_stats", type=int, default=0, help="Save output statistics"
    )
    # --- ligand context options (ligand_mpnn) ---
    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )
    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )
    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )
    # --- chain selection ---
    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default="",
        help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'",
    )
    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'A,B,C,F'",
    )
    # --- membrane-model labels ---
    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )
    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )
    # --- side-chain packing options ---
    argparser.add_argument(
        "--pack_side_chains",
        type=int,
        default=0,
        help="1 - to run side chain packer, 0 - do not run it",
    )
    argparser.add_argument(
        "--checkpoint_path_sc",
        type=str,
        default="./model_params/ligandmpnn_sc_v_32_002_16.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--number_of_packs_per_design",
        type=int,
        default=4,
        help="Number of independent side chain packing samples to return per design",
    )
    argparser.add_argument(
        "--sc_num_denoising_steps",
        type=int,
        default=3,
        help="Number of denoising/recycling steps to make for side chain packing",
    )
    argparser.add_argument(
        "--sc_num_samples",
        type=int,
        default=16,
        help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.",
    )
    argparser.add_argument(
        "--repack_everything",
        type=int,
        default=0,
        help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues",
    )
    argparser.add_argument(
        "--force_hetatm",
        type=int,
        default=0,
        help="To force ligand atoms to be written as HETATM to PDB file after packing.",
    )
    argparser.add_argument(
        "--packed_suffix",
        type=str,
        default="_packed",
        help="Suffix for packed PDB paths",
    )
    argparser.add_argument(
        "--pack_with_ligand_context",
        type=int,
        default=1,
        help="1-pack side chains using ligand context, 0 - do not use it.",
    )
    args = argparser.parse_args()
    main(args)

244
run_examples.sh Normal file
View File

@@ -0,0 +1,244 @@
#!/bin/bash
# Example invocations of run.py covering every major CLI feature.
# Each example writes its results into a dedicated ./outputs/<name> folder.
# NOTE(review): invocations are not chained with `set -e`, so a failing
# example does not stop the remaining ones.
#1 default ProteinMPNN design with a fixed seed
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/default"
#2 lower sampling temperature (more conservative sequences)
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --temperature 0.05 \
        --out_folder "./outputs/temperature"
#3 no --seed given: run.py draws a random seed
python run.py \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/random_seed"
#4 silence progress printing
python run.py \
        --seed 111 \
        --verbose 0 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/verbose"
#5 also save raw statistics tensors (.pt) next to the sequences
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/save_stats" \
        --save_stats 1
#6 keep listed residues fixed; bias remaining positions toward alanine
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/fix_residues" \
        --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
        --bias_AA "A:10.0"
#7 inverse of #6: redesign only the listed residues, fix everything else
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/redesign_residues" \
        --redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
        --bias_AA "A:10.0"
#8 generate batch_size x number_of_batches = 15 sequences total
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/batch_size" \
        --batch_size 3 \
        --number_of_batches 5
#9 global amino-acid bias applied to every designed position
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
        --out_folder "./outputs/global_bias"
#10 per-residue amino-acid bias from a JSON file
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
        --out_folder "./outputs/per_residue_bias"
#11 globally forbid all listed amino-acid types
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --omit_AA "CDFGHILMNPQRSTVWY" \
        --out_folder "./outputs/global_omit"
#12 forbid amino-acid types per residue via a JSON file
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
        --out_folder "./outputs/per_residue_omit"
#13 tie groups of residues together with explicit weights
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/symmetry" \
        --symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
        --symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"
#14 homo-oligomer mode: symmetry across chains is set up automatically
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/homooligomer" \
        --homo_oligomer 1 \
        --number_of_batches 2
#15 append a custom suffix to all output file names
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/file_ending" \
        --file_ending "_xyz"
#16 number output designs starting from 0 instead of 1
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/zero_indexed" \
        --zero_indexed 1 \
        --number_of_batches 2
#17 redesign only chains A and B; other chains are kept fixed
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/chains_to_design" \
        --chains_to_design "A,B"
#18 parse only chains A and B from the input PDB
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/parse_these_chains_only" \
        --parse_these_chains_only "A,B"
#19 LigandMPNN with default checkpoint and settings
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_default"
#20 LigandMPNN with an alternative checkpoint
python run.py \
        --checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_v_32_005_25"
#21 LigandMPNN ignoring ligand atom context
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_no_context" \
        --ligand_mpnn_use_atom_context 0
#22 use side-chain atoms of the fixed residues as extra context
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
        --ligand_mpnn_use_side_chain_context 1 \
        --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"
#23 SolubleMPNN variant (trained on soluble proteins only)
python run.py \
        --model_type "soluble_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/soluble_mpnn_default"
#24 membrane model with a single global label (0 = soluble)
python run.py \
        --model_type "global_label_membrane_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/global_label_membrane_mpnn_0" \
        --global_transmembrane_label 0
#25 membrane model with explicit per-residue buried/interface labels
python run.py \
        --model_type "per_residue_label_membrane_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
        --transmembrane_buried "C1 C2 C3 C11" \
        --transmembrane_interface "C4 C5 C6 C22"
#26 choose the separator written between chains in the FASTA output
python run.py \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/fasta_seq_separation" \
        --fasta_seq_separation ":"
#27 design multiple PDBs listed in a JSON file
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --out_folder "./outputs/pdb_path_multi" \
        --seed 111
#28 per-PDB fixed residues from a JSON mapping
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --fixed_residues_multi "./inputs/fix_residues_multi.json" \
        --out_folder "./outputs/fixed_residues_multi" \
        --seed 111
#29 per-PDB redesigned residues from a JSON mapping
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
        --out_folder "./outputs/redesigned_residues_multi" \
        --seed 111
#30 per-PDB omit lists from a JSON mapping
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
        --out_folder "./outputs/omit_AA_per_residue_multi" \
        --seed 111
#31 per-PDB bias tables from a JSON mapping
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
        --out_folder "./outputs/bias_AA_per_residue_multi" \
        --seed 111
#32 tighter distance cutoff for the reported ligand score
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --ligand_mpnn_cutoff_for_score "6.0" \
        --out_folder "./outputs/ligand_mpnn_cutoff_for_score"
#33 residues with insertion codes (B82A, B82B, ...) are addressable too
python run.py \
        --seed 111 \
        --pdb_path "./inputs/2GFB.pdb" \
        --out_folder "./outputs/insertion_code" \
        --redesigned_residues "B82 B82A B82B B82C" \
        --parse_these_chains_only "B"

55
sc_examples.sh Normal file
View File

@@ -0,0 +1,55 @@
# Side-chain packing examples for LigandMPNN (run.py with --pack_side_chains 1).
#1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
# NOTE(review): this example passes --number_of_packs_per_design 0 while the comment
# above says one sample is returned - confirm that 0 selects the fast single-sample mode.
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_default_fast" \
--pack_side_chains 1 \
--number_of_packs_per_design 0 \
--pack_with_ligand_context 1
#2 design a new sequence and pack side chains (return 4 side chain packing samples)
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_default" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1
#3 fix specific residues for design and packing (--repack_everything 0 keeps fixed residues' side chains)
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_fixed_residues" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1 \
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
--repack_everything 0
#4 fix specific residues for sequence design but repack everything
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_fixed_residues_full_repack" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 1 \
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
--repack_everything 1
#5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
python run.py \
--model_type "ligand_mpnn" \
--seed 111 \
--pdb_path "./inputs/1BC8.pdb" \
--out_folder "./outputs/sc_no_context" \
--pack_side_chains 1 \
--number_of_packs_per_design 4 \
--pack_with_ligand_context 0

1158
sc_utils.py Normal file

File diff suppressed because it is too large Load Diff

549
score.py Normal file
View File

@@ -0,0 +1,549 @@
import argparse
import json
import os.path
import random
import sys
import numpy as np
import torch
from data_utils import (
element_dict_rev,
alphabet,
restype_int_to_str,
featurize,
parse_PDB,
)
from model_utils import ProteinMPNN
def main(args) -> None:
    """
    Score protein sequences with a ProteinMPNN-family model.

    Loads the checkpoint selected by ``args.model_type``, parses one or more
    input PDB files, builds per-residue masks (fixed / redesigned / membrane
    labels / symmetry groups), runs either autoregressive or single-amino-acid
    scoring, and writes one ``<pdb_name><file_ending>.pt`` file per input PDB
    into ``args.out_folder`` containing logits, (log-)probabilities, decoding
    order, masks, and per-residue mean/std probability summaries.
    """
    # Seed everything for reproducibility. NOTE(review): a user-supplied seed
    # of 0 is falsy and therefore replaced by a random seed — confirm intended.
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    folder_for_outputs = args.out_folder
    base_folder = folder_for_outputs
    # Normalize the output folder to end with "/" and make sure it exists.
    if base_folder[-1] != "/":
        base_folder = base_folder + "/"
    if not os.path.exists(base_folder):
        os.makedirs(base_folder, exist_ok=True)
    # Select the checkpoint file matching the requested model variant.
    if args.model_type == "protein_mpnn":
        checkpoint_path = args.checkpoint_protein_mpnn
    elif args.model_type == "ligand_mpnn":
        checkpoint_path = args.checkpoint_ligand_mpnn
    elif args.model_type == "per_residue_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
    elif args.model_type == "global_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_global_label_membrane_mpnn
    elif args.model_type == "soluble_mpnn":
        checkpoint_path = args.checkpoint_soluble_mpnn
    else:
        print("Choose one of the available models")
        sys.exit()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Only the ligand-aware model consumes atom/side-chain context; for the
    # others a dummy atom_context_num of 1 is used.
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
        k_neighbors = checkpoint["num_edges"]
    else:
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
        k_neighbors = checkpoint["num_edges"]
    # Build the model and restore trained weights in eval mode.
    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    # Input PDBs: either the keys of a JSON file, or a single CLI path.
    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]
    # Fixed residues per PDB: JSON mapping, or the same CLI list for every PDB.
    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
    else:
        fixed_residues = [item for item in args.fixed_residues.split()]
        fixed_residues_multi = {}
        for pdb in pdb_paths:
            fixed_residues_multi[pdb] = fixed_residues
    # Redesigned residues per PDB, built the same way.
    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
    else:
        redesigned_residues = [item for item in args.redesigned_residues.split()]
        redesigned_residues_multi = {}
        for pdb in pdb_paths:
            redesigned_residues_multi[pdb] = redesigned_residues
    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        # Parse coordinates; side-chain atoms are only parsed when they will be
        # used as context.
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=args.parse_these_chains_only,
            parse_all_atoms=args.ligand_mpnn_use_side_chain_context,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = []
        for i, R_idx_item in enumerate(R_idx_list):
            # Residue name format is e.g. "A12" or "B82A" (with insertion code).
            tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            encoded_residues.append(tmp)
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(
            zip(list(range(len(encoded_residues))), encoded_residues)
        )
        # 0 where the residue is listed as fixed, 1 elsewhere.
        fixed_positions = torch.tensor(
            [int(item not in fixed_residues) for item in encoded_residues],
            device=device,
        )
        # 0 where the residue is listed as redesigned, 1 elsewhere.
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residues) for item in encoded_residues],
            device=device,
        )
        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = [item for item in args.transmembrane_buried.split()]
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)
        if args.transmembrane_interface:
            interface_residues = [item for item in args.transmembrane_interface.split()]
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # Per-residue label: 2 = buried-only, 1 = interface-only, 0 = neither
        # (a residue flagged as both buried and interface also gets 0).
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)
        if args.model_type == "global_label_membrane_mpnn":
            # Broadcast the single global label over all residues.
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        # Chains to design: an explicit comma-separated list, or every chain.
        if type(args.chains_to_design) == str:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )
        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        # redesigned_residues takes precedence over fixed_residues when both are given.
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask
        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)
        # specify which residues are linked
        # Format: "A12,A13|B2,B3" -> [[idx(A12), idx(A13)], [idx(B2), idx(B3)]]
        if args.symmetry_residues:
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = []
            for t_list in symmetry_residues_list_of_lists:
                tmp_list = []
                for t in t_list:
                    tmp_list.append(encoded_residue_dict[t])
                remapped_symmetry_residues.append(tmp_list)
        else:
            remapped_symmetry_residues = [[]]
        # Homo-oligomer mode: tie each residue position across all chains.
        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            chain_letters_set = list(set(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            for res in residue_indices:
                tmp_list = []
                tmp_w_list = []
                for chain in chain_letters_set:
                    name = chain + res
                    tmp_list.append(encoded_residue_dict[name])
                    # NOTE(review): tmp_w_list (equal per-chain symmetry weights)
                    # is built here but never consumed in this script.
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)
        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)
        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]
        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                # Report how many non-protein context atoms were parsed.
                if "Y" in list(protein_dict):
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    print("No ligand atoms parsed")
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            B, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["symmetry_residues"] = remapped_symmetry_residues
            logits_list = []
            probs_list = []
            log_probs_list = []
            decoding_order_list = []
            # Accumulate scores over the requested number of batches; each batch
            # draws fresh noise to randomize the decoding order.
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                if args.autoregressive_score:
                    score_dict = model.score(feature_dict, use_sequence=args.use_sequence)
                elif args.single_aa_score:
                    score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence)
                else:
                    print("Set either autoregressive_score or single_aa_score to True")
                    sys.exit()
                logits_list.append(score_dict["logits"])
                log_probs_list.append(score_dict["log_probs"])
                probs_list.append(torch.exp(score_dict["log_probs"]))
                decoding_order_list.append(score_dict["decoding_order"])
            # Stack all batches along dim 0 and assemble the output payload.
            log_probs_stack = torch.cat(log_probs_list, 0)
            logits_stack = torch.cat(logits_list, 0)
            probs_stack = torch.cat(probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)
            output_stats_path = base_folder + name + args.file_ending + ".pt"
            out_dict = {}
            out_dict["logits"] = logits_stack.cpu().numpy()
            out_dict["probs"] = probs_stack.cpu().numpy()
            out_dict["log_probs"] = log_probs_stack.cpu().numpy()
            out_dict["decoding_order"] = decoding_order_stack.cpu().numpy()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy()
            out_dict["mask"] = feature_dict["mask"][0].cpu().numpy()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy()  # this affects decoding order
            out_dict["seed"] = seed
            out_dict["alphabet"] = alphabet
            out_dict["residue_names"] = encoded_residue_dict_rev
            # Per-residue mean/std of probabilities across all scored batches.
            mean_probs = np.mean(out_dict["probs"], 0)
            std_probs = np.std(out_dict["probs"], 0)
            sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]]
            mean_dict = {}
            std_dict = {}
            for residue in range(L):
                mean_dict_ = dict(zip(alphabet, mean_probs[residue]))
                mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_
                std_dict_ = dict(zip(alphabet, std_probs[residue]))
                std_dict[encoded_residue_dict_rev[residue]] = std_dict_
            out_dict["sequence"] = sequence
            out_dict["mean_of_probs"] = mean_dict
            out_dict["std_of_probs"] = std_dict
            torch.save(out_dict, output_stats_path)
if __name__ == "__main__":
    # Command-line entry point: declare all flags, then run main().
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB excluding non-protein atoms
    # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
    # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    # --- model checkpoint paths ---
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
    # --- input selection ---
    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )
    # --- residue selection (fixed / redesigned), single-PDB and per-PDB forms ---
    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )
    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )
    # --- symmetry / oligomer options ---
    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )
    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
    )
    # --- output options ---
    argparser.add_argument(
        "--out_folder",
        type=str,
        help="Path to a folder to output scores, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    argparser.add_argument(
        "--zero_indexed",
        # Fixed: was type=str with an integer default; every sibling integer
        # flag here uses type=int and the help text describes a 0/1 switch.
        # NOTE(review): this flag is not referenced by main() in this script.
        type=int,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set seed for torch, numpy, and python random.",
    )
    # --- batching ---
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )
    # --- LigandMPNN context options ---
    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )
    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )
    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )
    # --- chain selection ---
    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default=None,
        help="Specify which chains to redesign, all others will be kept fixed.",
    )
    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'ABCF'",
    )
    # --- membrane model labels ---
    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )
    # --- parsing / scoring behavior ---
    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )
    argparser.add_argument(
        "--use_sequence",
        type=int,
        default=1,
        help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only",
    )
    argparser.add_argument(
        "--autoregressive_score",
        type=int,
        default=0,
        help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False",
    )
    argparser.add_argument(
        "--single_aa_score",
        type=int,
        default=1,
        help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False",
    )
    args = argparser.parse_args()
    main(args)