Add LigandMPNN Nextflow pipeline for protein sequence design
This commit is contained in:
96
Dockerfile
Normal file
96
Dockerfile
Normal file
@@ -0,0 +1,96 @@
|
||||
FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04
|
||||
|
||||
# Set environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies including build tools for ProDy compilation
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
git \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.11 \
|
||||
python3.11-venv \
|
||||
python3.11-dev \
|
||||
python3-pip \
|
||||
procps \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 \
|
||||
&& update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
|
||||
|
||||
# Upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
|
||||
|
||||
# Clone LigandMPNN repository
|
||||
RUN git clone https://github.com/dauparas/LigandMPNN.git /app/LigandMPNN
|
||||
|
||||
# Set working directory to LigandMPNN
|
||||
WORKDIR /app/LigandMPNN
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip3 install --no-cache-dir \
|
||||
biopython==1.79 \
|
||||
filelock==3.13.1 \
|
||||
fsspec==2024.3.1 \
|
||||
Jinja2==3.1.3 \
|
||||
MarkupSafe==2.1.5 \
|
||||
mpmath==1.3.0 \
|
||||
networkx==3.2.1 \
|
||||
numpy==1.23.5 \
|
||||
ProDy==2.4.1 \
|
||||
pyparsing==3.1.1 \
|
||||
scipy==1.12.0 \
|
||||
sympy==1.12 \
|
||||
typing_extensions==4.10.0 \
|
||||
ml-collections==0.1.1 \
|
||||
dm-tree==0.1.8
|
||||
|
||||
# Install PyTorch with CUDA support
|
||||
RUN pip3 install --no-cache-dir \
|
||||
torch==2.2.1 \
|
||||
--index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# Download model parameters
|
||||
RUN mkdir -p /app/LigandMPNN/model_params && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_002.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_002.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_010.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_010.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_020.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_030.pt -O /app/LigandMPNN/model_params/proteinmpnn_v_48_030.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_005_25.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_010_25.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_020_25.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt -O /app/LigandMPNN/model_params/ligandmpnn_v_32_030_25.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_002.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_002.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_010.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_010.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_020.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_030.pt -O /app/LigandMPNN/model_params/solublempnn_v_48_030.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/per_residue_label_membrane_mpnn_v_48_020.pt -O /app/LigandMPNN/model_params/per_residue_label_membrane_mpnn_v_48_020.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/global_label_membrane_mpnn_v_48_020.pt -O /app/LigandMPNN/model_params/global_label_membrane_mpnn_v_48_020.pt && \
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_sc_v_32_002_16.pt -O /app/LigandMPNN/model_params/ligandmpnn_sc_v_32_002_16.pt && \
|
||||
ls -la /app/LigandMPNN/model_params/
|
||||
|
||||
# Create wrapper script for ligandmpnn (using absolute path)
|
||||
RUN printf '#!/bin/bash\npython /app/LigandMPNN/run.py "$@"\n' > /usr/local/bin/ligandmpnn && \
|
||||
chmod +x /usr/local/bin/ligandmpnn
|
||||
|
||||
# Create wrapper script for scoring (using absolute path)
|
||||
RUN printf '#!/bin/bash\npython /app/LigandMPNN/score.py "$@"\n' > /usr/local/bin/ligandmpnn-score && \
|
||||
chmod +x /usr/local/bin/ligandmpnn-score
|
||||
|
||||
# Create input/output directories
|
||||
RUN mkdir -p /app/inputs /app/outputs
|
||||
|
||||
# Set environment variables for model paths
|
||||
ENV LIGANDMPNN_MODEL_DIR=/app/LigandMPNN/model_params
|
||||
ENV PYTHONPATH=/app/LigandMPNN:${PYTHONPATH:-}
|
||||
|
||||
WORKDIR /app/LigandMPNN
|
||||
|
||||
CMD ["python", "run.py", "--help"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Justas Dauparas
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
636
README.md
Normal file
636
README.md
Normal file
@@ -0,0 +1,636 @@
|
||||
## LigandMPNN
|
||||
|
||||
This package provides inference code for [LigandMPNN](https://www.biorxiv.org/content/10.1101/2023.12.22.573103v1) & [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) models. The code and model parameters are available under the MIT license.
|
||||
|
||||
Third party code: side chain packing uses helper functions from [Openfold](https://github.com/aqlaboratory/openfold).
|
||||
|
||||
### Running the code
|
||||
```
|
||||
git clone https://github.com/dauparas/LigandMPNN.git
|
||||
cd LigandMPNN
|
||||
bash get_model_params.sh "./model_params"
|
||||
|
||||
#setup your conda/or other environment
|
||||
#conda create -n ligandmpnn_env python=3.11
|
||||
#pip3 install -r requirements.txt
|
||||
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/default"
|
||||
```
|
||||
|
||||
### Dependencies
|
||||
To run the model you will need to have Python>=3.0, PyTorch, Numpy installed, and to read/write PDB files you will need [Prody](https://pypi.org/project/ProDy/).
|
||||
|
||||
For example to make a new conda environment for LigandMPNN run:
|
||||
```
|
||||
conda create -n ligandmpnn_env python=3.11
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
### Main differences compared with [ProteinMPNN](https://github.com/dauparas/ProteinMPNN) code
|
||||
- Input PDBs are parsed using [Prody](https://pypi.org/project/ProDy/) preserving protein residue indices, chain letters, and insertion codes. If there are missing residues in the input structure the output fasta file won't have added `X` to fill the gaps. The script outputs .fasta and .pdb files. It's recommended to use .pdb files since they will hold information about chain letters and residue indices.
|
||||
- Adding bias, fixing residues, and selecting residues to be redesigned now can be done using residue indices directly, e.g. A23 (means chain A residue with index 23), B42D (chain B, residue 42, insertion code D).
|
||||
- Model writes to fasta files: `overall_confidence`, `ligand_confidence` which reflect the average confidence/probability (with T=1.0) over the redesigned residues `overall_confidence=exp[-mean_over_residues(log_probs)]`. Higher numbers mean the model is more confident about that sequence. min_value=0.0; max_value=1.0. Sequence recovery with respect to the input sequence is calculated only over the redesigned residues.
|
||||
|
||||
### Model parameters
|
||||
To download model parameters run:
|
||||
```
|
||||
bash get_model_params.sh "./model_params"
|
||||
```
|
||||
|
||||
### Available models
|
||||
|
||||
To run the model of your choice specify `--model_type` and optionally the model checkpoint path. Available models:
|
||||
- ProteinMPNN
|
||||
```
|
||||
--model_type "protein_mpnn"
|
||||
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_002.pt" #noised with 0.02A Gaussian noise
|
||||
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_010.pt" #noised with 0.10A Gaussian noise
|
||||
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
|
||||
--checkpoint_protein_mpnn "./model_params/proteinmpnn_v_48_030.pt" #noised with 0.30A Gaussian noise
|
||||
```
|
||||
- LigandMPNN
|
||||
```
|
||||
--model_type "ligand_mpnn"
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" #noised with 0.05A Gaussian noise
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_010_25.pt" #noised with 0.10A Gaussian noise
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_020_25.pt" #noised with 0.20A Gaussian noise
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_030_25.pt" #noised with 0.30A Gaussian noise
|
||||
```
|
||||
- SolubleMPNN
|
||||
```
|
||||
--model_type "soluble_mpnn"
|
||||
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_002.pt" #noised with 0.02A Gaussian noise
|
||||
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_010.pt" #noised with 0.10A Gaussian noise
|
||||
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_020.pt" #noised with 0.20A Gaussian noise
|
||||
--checkpoint_soluble_mpnn "./model_params/solublempnn_v_48_030.pt" #noised with 0.30A Gaussian noise
|
||||
```
|
||||
- ProteinMPNN with global membrane label
|
||||
```
|
||||
--model_type "global_label_membrane_mpnn"
|
||||
--checkpoint_global_label_membrane_mpnn "./model_params/global_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
|
||||
```
|
||||
- ProteinMPNN with per residue membrane label
|
||||
```
|
||||
--model_type "per_residue_label_membrane_mpnn"
|
||||
--checkpoint_per_residue_label_membrane_mpnn "./model_params/per_residue_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
|
||||
```
|
||||
- Side chain packing model
|
||||
```
|
||||
--checkpoint_path_sc "./model_params/ligandmpnn_sc_v_32_002_16.pt"
|
||||
```
|
||||
## Design examples
|
||||
### 1 default
|
||||
Default settings will run ProteinMPNN.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/default"
|
||||
```
|
||||
### 2 --temperature
|
||||
`--temperature 0.05` Change sampling temperature (higher temperature gives more sequence diversity).
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--temperature 0.05 \
|
||||
--out_folder "./outputs/temperature"
|
||||
```
|
||||
### 3 --seed
|
||||
`--seed` Not selecting a seed will run with a random seed. Running this multiple times will give different results.
|
||||
```
|
||||
python run.py \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/random_seed"
|
||||
```
|
||||
### 4 --verbose
|
||||
`--verbose 0` Do not print any statements.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--verbose 0 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/verbose"
|
||||
```
|
||||
### 5 --save_stats
|
||||
`--save_stats 1` Save sequence design statistics.
|
||||
```
|
||||
#['generated_sequences', 'sampling_probs', 'log_probs', 'decoding_order', 'native_sequence', 'mask', 'chain_mask', 'seed', 'temperature']
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/save_stats" \
|
||||
--save_stats 1
|
||||
```
|
||||
### 6 --fixed_residues
|
||||
`--fixed_residues` Fixing specific amino acids. This example fixes the first 10 residues in chain C and adds global bias towards A (alanine). The output should have all alanines except the first 10 residues should be the same as in the input sequence since those are fixed.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/fix_residues" \
|
||||
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
||||
--bias_AA "A:10.0"
|
||||
```
|
||||
|
||||
### 7 --redesigned_residues
|
||||
`--redesigned_residues` Specifying which residues need to be designed. This example redesigns the first 10 residues while fixing everything else.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/redesign_residues" \
|
||||
--redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
||||
--bias_AA "A:10.0"
|
||||
```
|
||||
|
||||
### 8 --number_of_batches
|
||||
Design 15 sequences; with batch size 3 (can be 1 when using CPUs) and the number of batches 5.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/batch_size" \
|
||||
--batch_size 3 \
|
||||
--number_of_batches 5
|
||||
```
|
||||
### 9 --bias_AA
|
||||
Global amino acid bias. In this example, output sequences are biased towards W, P, C and away from A.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
|
||||
--out_folder "./outputs/global_bias"
|
||||
```
|
||||
### 10 --bias_AA_per_residue
|
||||
Specify per residue amino acid bias, e.g. make residues C1, C3, C5, and C7 to be prolines.
|
||||
```
|
||||
# {
|
||||
# "C1": {"G": -0.3, "C": -2.0, "P": 10.8},
|
||||
# "C3": {"P": 10.0},
|
||||
# "C5": {"G": -1.3, "P": 10.0},
|
||||
# "C7": {"G": -1.3, "P": 10.0}
|
||||
# }
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
|
||||
--out_folder "./outputs/per_residue_bias"
|
||||
```
|
||||
### 11 --omit_AA
|
||||
Global amino acid restrictions. This is equivalent to using `--bias_AA` and setting bias to be a large negative number. The output should be just made of E, K, A.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--omit_AA "CDFGHILMNPQRSTVWY" \
|
||||
--out_folder "./outputs/global_omit"
|
||||
```
|
||||
|
||||
### 12 --omit_AA_per_residue
|
||||
Per residue amino acid restrictions.
|
||||
```
|
||||
# {
|
||||
# "C1": "ACDEFGHIKLMNPQRSTVW",
|
||||
# "C3": "ACDEFGHIKLMNPQRSTVW",
|
||||
# "C5": "ACDEFGHIKLMNPQRSTVW",
|
||||
# "C7": "ACDEFGHIKLMNPQRSTVW"
|
||||
# }
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
|
||||
--out_folder "./outputs/per_residue_omit"
|
||||
```
|
||||
### 13 --symmetry_residues
|
||||
### 13 --symmetry_weights
|
||||
Designing sequences with symmetry, e.g. homooligomer/2-state proteins, etc. In this example make C1=C2=C3, also C4=C5, and C6=C7.
|
||||
```
|
||||
#total_logits += symmetry_weights[t]*logits
|
||||
#probs = torch.nn.functional.softmax((total_logits+bias_t) / temperature, dim=-1)
|
||||
#total_logits_123 = 0.33*logits_1+0.33*logits_2+0.33*logits_3
|
||||
#output should be ***ooxx
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/symmetry" \
|
||||
--symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
|
||||
--symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"
|
||||
```
|
||||
|
||||
|
||||
### 14 --homo_oligomer
|
||||
Design homooligomer sequences. This automatically sets `--symmetry_residues` and `--symmetry_weights` assuming equal weighting from all chains.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/homooligomer" \
|
||||
--homo_oligomer 1 \
|
||||
--number_of_batches 2
|
||||
```
|
||||
|
||||
### 15 --file_ending
|
||||
Outputs will have a specified ending; e.g. `1BC8_xyz.fa` instead of `1BC8.fa`
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/file_ending" \
|
||||
--file_ending "_xyz"
|
||||
```
|
||||
|
||||
### 16 --zero_indexed
|
||||
Zero indexed names in /backbones/1BC8_0.pdb, 1BC8_1.pdb, 1BC8_2.pdb etc
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/zero_indexed" \
|
||||
--zero_indexed 1 \
|
||||
--number_of_batches 2
|
||||
```
|
||||
|
||||
### 17 --chains_to_design
|
||||
Specify which chains (e.g. "A,B,C") need to be redesigned, other chains will be kept fixed. Outputs in seqs/backbones will still have atoms/sequences for the whole input PDB.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/chains_to_design" \
|
||||
--chains_to_design "A,B"
|
||||
```
|
||||
### 18 --parse_these_chains_only
|
||||
Parse and design only specified chains (e.g. "A,B,C"). Outputs will have only specified chains.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/parse_these_chains_only" \
|
||||
--parse_these_chains_only "A,B"
|
||||
```
|
||||
|
||||
### 19 --model_type "ligand_mpnn"
|
||||
Run LigandMPNN with default settings.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_default"
|
||||
```
|
||||
|
||||
### 20 --checkpoint_ligand_mpnn
|
||||
Run LigandMPNN using 0.05A model by specifying `--checkpoint_ligand_mpnn` flag.
|
||||
```
|
||||
python run.py \
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_v_32_005_25"
|
||||
```
|
||||
### 21 --ligand_mpnn_use_atom_context
|
||||
Setting `--ligand_mpnn_use_atom_context 0` will mask all ligand atoms. This can be used to assess how much ligand atoms affect AA probabilities.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_no_context" \
|
||||
--ligand_mpnn_use_atom_context 0
|
||||
```
|
||||
|
||||
### 22 --ligand_mpnn_use_side_chain_context
|
||||
Use fixed residue side chain atoms as extra ligand atoms.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
|
||||
--ligand_mpnn_use_side_chain_context 1 \
|
||||
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"
|
||||
```
|
||||
|
||||
### 23 --model_type "soluble_mpnn"
|
||||
Run SolubleMPNN (ProteinMPNN-like model with only soluble proteins in the training dataset).
|
||||
```
|
||||
python run.py \
|
||||
--model_type "soluble_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/soluble_mpnn_default"
|
||||
```
|
||||
|
||||
### 24 --model_type "global_label_membrane_mpnn"
|
||||
Run global label membrane MPNN (trained with extra input - binary label soluble vs not) `--global_transmembrane_label #1 - membrane, 0 - soluble`.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "global_label_membrane_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/global_label_membrane_mpnn_0" \
|
||||
--global_transmembrane_label 0
|
||||
```
|
||||
|
||||
### 25 --model_type "per_residue_label_membrane_mpnn"
|
||||
Run per residue label membrane MPNN (trained with extra input per residue specifying buried (hydrophobic), interface (polar), or other type residues; 3 classes).
|
||||
```
|
||||
python run.py \
|
||||
--model_type "per_residue_label_membrane_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
|
||||
--transmembrane_buried "C1 C2 C3 C11" \
|
||||
--transmembrane_interface "C4 C5 C6 C22"
|
||||
```
|
||||
|
||||
### 26 --fasta_seq_separation
|
||||
Choose a symbol to put between different chains in fasta output format. It's recommended to PDB output format to deal with residue jumps and multiple chain parsing.
|
||||
```
|
||||
python run.py \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/fasta_seq_separation" \
|
||||
--fasta_seq_separation ":"
|
||||
```
|
||||
|
||||
### 27 --pdb_path_multi
|
||||
Specify multiple PDB input paths. This is more efficient since the model needs to be loaded from the checkpoint once.
|
||||
```
|
||||
#{
|
||||
#"./inputs/1BC8.pdb": "",
|
||||
#"./inputs/4GYT.pdb": ""
|
||||
#}
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--out_folder "./outputs/pdb_path_multi" \
|
||||
--seed 111
|
||||
```
|
||||
|
||||
### 28 --fixed_residues_multi
|
||||
Specify fixed residues when using `--pdb_path_multi` flag.
|
||||
```
|
||||
#{
|
||||
#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10 C22",
|
||||
#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A11 A12 A13 B38"
|
||||
#}
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--fixed_residues_multi "./inputs/fix_residues_multi.json" \
|
||||
--out_folder "./outputs/fixed_residues_multi" \
|
||||
--seed 111
|
||||
```
|
||||
|
||||
### 29 --redesigned_residues_multi
|
||||
Specify which residues need to be redesigned when using `--pdb_path_multi` flag.
|
||||
```
|
||||
#{
|
||||
#"./inputs/1BC8.pdb": "C1 C2 C3 C4 C5 C10",
|
||||
#"./inputs/4GYT.pdb": "A7 A8 A9 A10 A12 A13 B38"
|
||||
#}
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
|
||||
--out_folder "./outputs/redesigned_residues_multi" \
|
||||
--seed 111
|
||||
```
|
||||
|
||||
### 30 --omit_AA_per_residue_multi
|
||||
Specify which residues need to be omitted when using `--pdb_path_multi` flag.
|
||||
```
|
||||
#{
|
||||
#"./inputs/1BC8.pdb": {"C1":"ACDEFGHILMNPQRSTVWY", "C2":"ACDEFGHILMNPQRSTVWY", "C3":"ACDEFGHILMNPQRSTVWY"},
|
||||
#"./inputs/4GYT.pdb": {"A7":"ACDEFGHILMNPQRSTVWY", "A8":"ACDEFGHILMNPQRSTVWY"}
|
||||
#}
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
|
||||
--out_folder "./outputs/omit_AA_per_residue_multi" \
|
||||
--seed 111
|
||||
```
|
||||
|
||||
### 31 --bias_AA_per_residue_multi
|
||||
Specify amino acid biases per residue when using `--pdb_path_multi` flag.
|
||||
```
|
||||
#{
|
||||
#"./inputs/1BC8.pdb": {"C1":{"A":3.0, "P":-2.0}, "C2":{"W":10.0, "G":-0.43}},
|
||||
#"./inputs/4GYT.pdb": {"A7":{"Y":5.0, "S":-2.0}, "A8":{"M":3.9, "G":-0.43}}
|
||||
#}
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
|
||||
--out_folder "./outputs/bias_AA_per_residue_multi" \
|
||||
--seed 111
|
||||
```
|
||||
|
||||
### 32 --ligand_mpnn_cutoff_for_score
|
||||
This sets the cutoff distance in angstroms to select residues that are considered to be close to ligand atoms. This flag only affects the `num_ligand_res` and `ligand_confidence` in the output fasta files.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--ligand_mpnn_cutoff_for_score "6.0" \
|
||||
--out_folder "./outputs/ligand_mpnn_cutoff_for_score"
|
||||
```
|
||||
|
||||
### 33 specifying residues with insertion codes
|
||||
You can specify residue using chain_id + residue_number + insersion_code; e.g. redesign only residue B82, B82A, B82B, B82C.
|
||||
```
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/2GFB.pdb" \
|
||||
--out_folder "./outputs/insertion_code" \
|
||||
--redesigned_residues "B82 B82A B82B B82C" \
|
||||
--parse_these_chains_only "B"
|
||||
```
|
||||
|
||||
### 34 parse atoms with zero occupancy
|
||||
Parse atoms in the PDB files with zero occupancy too.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/parse_atoms_with_zero_occupancy" \
|
||||
--parse_atoms_with_zero_occupancy 1
|
||||
```
|
||||
|
||||
## Scoring examples
|
||||
### Output dictionary
|
||||
```
|
||||
out_dict = {}
|
||||
out_dict["logits"] - raw logits from the model
|
||||
out_dict["probs"] - softmax(logits)
|
||||
out_dict["log_probs"] - log_softmax(logits)
|
||||
out_dict["decoding_order"] - decoding order used (logits will depend on the decoding order)
|
||||
out_dict["native_sequence"] - parsed input sequence in integers
|
||||
out_dict["mask"] - mask for missing residues (usually all ones)
|
||||
out_dict["chain_mask"] - controls which residues are decoded first
|
||||
out_dict["alphabet"] - amino acid alphabet used
|
||||
out_dict["residue_names"] - dictionary to map integers to residue_names, e.g. {0: "C10", 1: "C11"}
|
||||
out_dict["sequence"] - parsed input sequence in alphabet
|
||||
out_dict["mean_of_probs"] - averaged over batch_size*number_of_batches probabilities, [protein_length, 21]
|
||||
out_dict["std_of_probs"] - same as above, but std
|
||||
```
|
||||
|
||||
### 1 autoregressive with sequence info
|
||||
Get probabilities/scores for backbone-sequence pairs using autoregressive probabilities: p(AA_1|backbone), p(AA_2|backbone, AA_1) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10.
|
||||
```
|
||||
python score.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--autoregressive_score 1\
|
||||
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
|
||||
--out_folder "./outputs/autoregressive_score_w_seq" \
|
||||
--use_sequence 1\
|
||||
--batch_size 1 \
|
||||
--number_of_batches 10
|
||||
```
|
||||
### 2 autoregressive with backbone info only
|
||||
Get probabilities/scores for backbone using probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10.
|
||||
```
|
||||
python score.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--autoregressive_score 1\
|
||||
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
|
||||
--out_folder "./outputs/autoregressive_score_wo_seq" \
|
||||
--use_sequence 0\
|
||||
--batch_size 1 \
|
||||
--number_of_batches 10
|
||||
```
|
||||
### 3 single amino acid score with sequence info
|
||||
Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone, AA_{all except AA_1}), p(AA_2|backbone, AA_{all except AA_2}) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10.
|
||||
```
|
||||
python score.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--single_aa_score 1\
|
||||
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
|
||||
--out_folder "./outputs/single_aa_score_w_seq" \
|
||||
--use_sequence 1\
|
||||
--batch_size 1 \
|
||||
--number_of_batches 10
|
||||
```
|
||||
### 4 single amino acid score with backbone info only
|
||||
Get probabilities/scores for backbone-sequence pairs using single aa probabilities: p(AA_1|backbone), p(AA_2|backbone) etc. These probabilities will depend on the decoding order, so it's recomended to set number_of_batches to at least 10.
|
||||
```
|
||||
python score.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--single_aa_score 1\
|
||||
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
|
||||
--out_folder "./outputs/single_aa_score_wo_seq" \
|
||||
--use_sequence 0\
|
||||
--batch_size 1 \
|
||||
--number_of_batches 10
|
||||
```
|
||||
|
||||
## Side chain packing examples
|
||||
|
||||
### 1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
|
||||
Design a new sequence using any of the available models and also pack side chains of the new sequence. Return only a single solution for the side chain packing.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_default_fast" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 0 \
|
||||
--pack_with_ligand_context 1
|
||||
```
|
||||
### 2 design a new sequence and pack side chains (return 4 side chain packing samples)
|
||||
Same as above, but returns 4 independent samples for side chains. b-factor shows log prob density per chi angle group.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_default" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1
|
||||
```
|
||||
|
||||
### 3 fix specific residues fors sequence design and packing
|
||||
This option will not repack side chains of the fixed residues, but use them as a context.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_fixed_residues" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1 \
|
||||
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
||||
--repack_everything 0
|
||||
```
|
||||
### 4 fix specific residues for sequence design but repack everything
|
||||
This option will repacks all the residues.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_fixed_residues_full_repack" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1 \
|
||||
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
||||
--repack_everything 1
|
||||
```
|
||||
|
||||
### 5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
|
||||
You can run side chain packing without taking into account context atoms like DNA atoms. This most likely will results in side chain clashing with context atoms, but it might be interesting to see how model's uncertainty changes when ligand atoms are present vs not for side chain conformations.
|
||||
```
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_no_context" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 0
|
||||
```
|
||||
|
||||
### Things to add
|
||||
- Support for ProteinMPNN CA-only model.
|
||||
- Examples for scoring sequences only.
|
||||
- Side-chain packing scripts.
|
||||
- TER
|
||||
|
||||
|
||||
### Citing this work
|
||||
If you use the code, please cite:
|
||||
```
|
||||
@article{dauparas2023atomic,
|
||||
title={Atomic context-conditioned protein sequence design using LigandMPNN},
|
||||
author={Dauparas, Justas and Lee, Gyu Rie and Pecoraro, Robert and An, Linna and Anishchenko, Ivan and Glasscock, Cameron and Baker, David},
|
||||
journal={Biorxiv},
|
||||
pages={2023--12},
|
||||
year={2023},
|
||||
publisher={Cold Spring Harbor Laboratory}
|
||||
}
|
||||
|
||||
@article{dauparas2022robust,
|
||||
title={Robust deep learning--based protein sequence design using ProteinMPNN},
|
||||
author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
|
||||
journal={Science},
|
||||
volume={378},
|
||||
number={6615},
|
||||
pages={49--56},
|
||||
year={2022},
|
||||
publisher={American Association for the Advancement of Science}
|
||||
}
|
||||
```
|
||||
988
data_utils.py
Normal file
988
data_utils.py
Normal file
@@ -0,0 +1,988 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils
|
||||
from prody import *
|
||||
|
||||
confProDy(verbosity="none")
|
||||
|
||||
restype_1to3 = {
|
||||
"A": "ALA",
|
||||
"R": "ARG",
|
||||
"N": "ASN",
|
||||
"D": "ASP",
|
||||
"C": "CYS",
|
||||
"Q": "GLN",
|
||||
"E": "GLU",
|
||||
"G": "GLY",
|
||||
"H": "HIS",
|
||||
"I": "ILE",
|
||||
"L": "LEU",
|
||||
"K": "LYS",
|
||||
"M": "MET",
|
||||
"F": "PHE",
|
||||
"P": "PRO",
|
||||
"S": "SER",
|
||||
"T": "THR",
|
||||
"W": "TRP",
|
||||
"Y": "TYR",
|
||||
"V": "VAL",
|
||||
"X": "UNK",
|
||||
}
|
||||
restype_str_to_int = {
|
||||
"A": 0,
|
||||
"C": 1,
|
||||
"D": 2,
|
||||
"E": 3,
|
||||
"F": 4,
|
||||
"G": 5,
|
||||
"H": 6,
|
||||
"I": 7,
|
||||
"K": 8,
|
||||
"L": 9,
|
||||
"M": 10,
|
||||
"N": 11,
|
||||
"P": 12,
|
||||
"Q": 13,
|
||||
"R": 14,
|
||||
"S": 15,
|
||||
"T": 16,
|
||||
"V": 17,
|
||||
"W": 18,
|
||||
"Y": 19,
|
||||
"X": 20,
|
||||
}
|
||||
restype_int_to_str = {
|
||||
0: "A",
|
||||
1: "C",
|
||||
2: "D",
|
||||
3: "E",
|
||||
4: "F",
|
||||
5: "G",
|
||||
6: "H",
|
||||
7: "I",
|
||||
8: "K",
|
||||
9: "L",
|
||||
10: "M",
|
||||
11: "N",
|
||||
12: "P",
|
||||
13: "Q",
|
||||
14: "R",
|
||||
15: "S",
|
||||
16: "T",
|
||||
17: "V",
|
||||
18: "W",
|
||||
19: "Y",
|
||||
20: "X",
|
||||
}
|
||||
alphabet = list(restype_str_to_int)
|
||||
|
||||
element_list = [
|
||||
"H",
|
||||
"He",
|
||||
"Li",
|
||||
"Be",
|
||||
"B",
|
||||
"C",
|
||||
"N",
|
||||
"O",
|
||||
"F",
|
||||
"Ne",
|
||||
"Na",
|
||||
"Mg",
|
||||
"Al",
|
||||
"Si",
|
||||
"P",
|
||||
"S",
|
||||
"Cl",
|
||||
"Ar",
|
||||
"K",
|
||||
"Ca",
|
||||
"Sc",
|
||||
"Ti",
|
||||
"V",
|
||||
"Cr",
|
||||
"Mn",
|
||||
"Fe",
|
||||
"Co",
|
||||
"Ni",
|
||||
"Cu",
|
||||
"Zn",
|
||||
"Ga",
|
||||
"Ge",
|
||||
"As",
|
||||
"Se",
|
||||
"Br",
|
||||
"Kr",
|
||||
"Rb",
|
||||
"Sr",
|
||||
"Y",
|
||||
"Zr",
|
||||
"Nb",
|
||||
"Mb",
|
||||
"Tc",
|
||||
"Ru",
|
||||
"Rh",
|
||||
"Pd",
|
||||
"Ag",
|
||||
"Cd",
|
||||
"In",
|
||||
"Sn",
|
||||
"Sb",
|
||||
"Te",
|
||||
"I",
|
||||
"Xe",
|
||||
"Cs",
|
||||
"Ba",
|
||||
"La",
|
||||
"Ce",
|
||||
"Pr",
|
||||
"Nd",
|
||||
"Pm",
|
||||
"Sm",
|
||||
"Eu",
|
||||
"Gd",
|
||||
"Tb",
|
||||
"Dy",
|
||||
"Ho",
|
||||
"Er",
|
||||
"Tm",
|
||||
"Yb",
|
||||
"Lu",
|
||||
"Hf",
|
||||
"Ta",
|
||||
"W",
|
||||
"Re",
|
||||
"Os",
|
||||
"Ir",
|
||||
"Pt",
|
||||
"Au",
|
||||
"Hg",
|
||||
"Tl",
|
||||
"Pb",
|
||||
"Bi",
|
||||
"Po",
|
||||
"At",
|
||||
"Rn",
|
||||
"Fr",
|
||||
"Ra",
|
||||
"Ac",
|
||||
"Th",
|
||||
"Pa",
|
||||
"U",
|
||||
"Np",
|
||||
"Pu",
|
||||
"Am",
|
||||
"Cm",
|
||||
"Bk",
|
||||
"Cf",
|
||||
"Es",
|
||||
"Fm",
|
||||
"Md",
|
||||
"No",
|
||||
"Lr",
|
||||
"Rf",
|
||||
"Db",
|
||||
"Sg",
|
||||
"Bh",
|
||||
"Hs",
|
||||
"Mt",
|
||||
"Ds",
|
||||
"Rg",
|
||||
"Cn",
|
||||
"Uut",
|
||||
"Fl",
|
||||
"Uup",
|
||||
"Lv",
|
||||
"Uus",
|
||||
"Uuo",
|
||||
]
|
||||
element_list = [item.upper() for item in element_list]
|
||||
# element_dict = dict(zip(element_list, range(1,len(element_list))))
|
||||
element_dict_rev = dict(zip(range(1, len(element_list)), element_list))
|
||||
|
||||
|
||||
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
|
||||
"""
|
||||
S : true sequence shape=[batch, length]
|
||||
S_pred : predicted sequence shape=[batch, length]
|
||||
mask : mask to compute average over the region shape=[batch, length]
|
||||
|
||||
average : averaged sequence recovery shape=[batch]
|
||||
"""
|
||||
match = S == S_pred
|
||||
average = torch.sum(match * mask, dim=-1) / torch.sum(mask, dim=-1)
|
||||
return average
|
||||
|
||||
|
||||
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
|
||||
"""
|
||||
S : true sequence shape=[batch, length]
|
||||
log_probs : predicted sequence shape=[batch, length]
|
||||
mask : mask to compute average over the region shape=[batch, length]
|
||||
|
||||
average_loss : averaged categorical cross entropy (CCE) [batch]
|
||||
loss_per_resdue : per position CCE [batch, length]
|
||||
"""
|
||||
S_one_hot = torch.nn.functional.one_hot(S, 21)
|
||||
loss_per_residue = -(S_one_hot * log_probs).sum(-1) # [B, L]
|
||||
average_loss = torch.sum(loss_per_residue * mask, dim=-1) / (
|
||||
torch.sum(mask, dim=-1) + 1e-8
|
||||
)
|
||||
return average_loss, loss_per_residue
|
||||
|
||||
|
||||
def write_full_PDB(
|
||||
save_path: str,
|
||||
X: np.ndarray,
|
||||
X_m: np.ndarray,
|
||||
b_factors: np.ndarray,
|
||||
R_idx: np.ndarray,
|
||||
chain_letters: np.ndarray,
|
||||
S: np.ndarray,
|
||||
other_atoms=None,
|
||||
icodes=None,
|
||||
force_hetatm=False,
|
||||
):
|
||||
"""
|
||||
save_path : path where the PDB will be written to
|
||||
X : protein atom xyz coordinates shape=[length, 14, 3]
|
||||
X_m : protein atom mask shape=[length, 14]
|
||||
b_factors: shape=[length, 14]
|
||||
R_idx: protein residue indices shape=[length]
|
||||
chain_letters: protein chain letters shape=[length]
|
||||
S : protein amino acid sequence shape=[length]
|
||||
other_atoms: other atoms parsed by prody
|
||||
icodes: a list of insertion codes for the PDB; e.g. antibody loops
|
||||
"""
|
||||
|
||||
restype_1to3 = {
|
||||
"A": "ALA",
|
||||
"R": "ARG",
|
||||
"N": "ASN",
|
||||
"D": "ASP",
|
||||
"C": "CYS",
|
||||
"Q": "GLN",
|
||||
"E": "GLU",
|
||||
"G": "GLY",
|
||||
"H": "HIS",
|
||||
"I": "ILE",
|
||||
"L": "LEU",
|
||||
"K": "LYS",
|
||||
"M": "MET",
|
||||
"F": "PHE",
|
||||
"P": "PRO",
|
||||
"S": "SER",
|
||||
"T": "THR",
|
||||
"W": "TRP",
|
||||
"Y": "TYR",
|
||||
"V": "VAL",
|
||||
"X": "UNK",
|
||||
}
|
||||
restype_INTtoSTR = {
|
||||
0: "A",
|
||||
1: "C",
|
||||
2: "D",
|
||||
3: "E",
|
||||
4: "F",
|
||||
5: "G",
|
||||
6: "H",
|
||||
7: "I",
|
||||
8: "K",
|
||||
9: "L",
|
||||
10: "M",
|
||||
11: "N",
|
||||
12: "P",
|
||||
13: "Q",
|
||||
14: "R",
|
||||
15: "S",
|
||||
16: "T",
|
||||
17: "V",
|
||||
18: "W",
|
||||
19: "Y",
|
||||
20: "X",
|
||||
}
|
||||
restype_name_to_atom14_names = {
|
||||
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
|
||||
"ARG": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD",
|
||||
"NE",
|
||||
"CZ",
|
||||
"NH1",
|
||||
"NH2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
|
||||
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
|
||||
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
|
||||
"GLN": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD",
|
||||
"OE1",
|
||||
"NE2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"GLU": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD",
|
||||
"OE1",
|
||||
"OE2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
|
||||
"HIS": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"ND1",
|
||||
"CD2",
|
||||
"CE1",
|
||||
"NE2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
|
||||
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
|
||||
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
|
||||
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
|
||||
"PHE": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD1",
|
||||
"CD2",
|
||||
"CE1",
|
||||
"CE2",
|
||||
"CZ",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
|
||||
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
|
||||
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
|
||||
"TRP": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD1",
|
||||
"CD2",
|
||||
"CE2",
|
||||
"CE3",
|
||||
"NE1",
|
||||
"CZ2",
|
||||
"CZ3",
|
||||
"CH2",
|
||||
],
|
||||
"TYR": [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"O",
|
||||
"CB",
|
||||
"CG",
|
||||
"CD1",
|
||||
"CD2",
|
||||
"CE1",
|
||||
"CE2",
|
||||
"CZ",
|
||||
"OH",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
|
||||
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
|
||||
}
|
||||
|
||||
S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]]
|
||||
|
||||
X_list = []
|
||||
b_factor_list = []
|
||||
atom_name_list = []
|
||||
element_name_list = []
|
||||
residue_name_list = []
|
||||
residue_number_list = []
|
||||
chain_id_list = []
|
||||
icodes_list = []
|
||||
for i, AA in enumerate(S_str):
|
||||
sel = X_m[i].astype(np.int32) == 1
|
||||
total = np.sum(sel)
|
||||
tmp = np.array(restype_name_to_atom14_names[AA])[sel]
|
||||
X_list.append(X[i][sel])
|
||||
b_factor_list.append(b_factors[i][sel])
|
||||
atom_name_list.append(tmp)
|
||||
element_name_list += [AA[:1] for AA in list(tmp)]
|
||||
residue_name_list += total * [AA]
|
||||
residue_number_list += total * [R_idx[i]]
|
||||
chain_id_list += total * [chain_letters[i]]
|
||||
icodes_list += total * [icodes[i]]
|
||||
|
||||
X_stack = np.concatenate(X_list, 0)
|
||||
b_factor_stack = np.concatenate(b_factor_list, 0)
|
||||
atom_name_stack = np.concatenate(atom_name_list, 0)
|
||||
|
||||
protein = prody.AtomGroup()
|
||||
protein.setCoords(X_stack)
|
||||
protein.setBetas(b_factor_stack)
|
||||
protein.setNames(atom_name_stack)
|
||||
protein.setResnames(residue_name_list)
|
||||
protein.setElements(element_name_list)
|
||||
protein.setOccupancies(np.ones([X_stack.shape[0]]))
|
||||
protein.setResnums(residue_number_list)
|
||||
protein.setChids(chain_id_list)
|
||||
protein.setIcodes(icodes_list)
|
||||
|
||||
if other_atoms:
|
||||
other_atoms_g = prody.AtomGroup()
|
||||
other_atoms_g.setCoords(other_atoms.getCoords())
|
||||
other_atoms_g.setNames(other_atoms.getNames())
|
||||
other_atoms_g.setResnames(other_atoms.getResnames())
|
||||
other_atoms_g.setElements(other_atoms.getElements())
|
||||
other_atoms_g.setOccupancies(other_atoms.getOccupancies())
|
||||
other_atoms_g.setResnums(other_atoms.getResnums())
|
||||
other_atoms_g.setChids(other_atoms.getChids())
|
||||
if force_hetatm:
|
||||
other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
|
||||
writePDB(save_path, protein + other_atoms_g)
|
||||
else:
|
||||
writePDB(save_path, protein)
|
||||
|
||||
|
||||
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
|
||||
"""
|
||||
protein_atoms: prody atom group
|
||||
CA_dict: mapping between chain_residue_idx_icodes and integers
|
||||
atom_name: atom to be parsed; e.g. CA
|
||||
"""
|
||||
atom_atoms = protein_atoms.select(f"name {atom_name}")
|
||||
|
||||
if atom_atoms != None:
|
||||
atom_coords = atom_atoms.getCoords()
|
||||
atom_resnums = atom_atoms.getResnums()
|
||||
atom_chain_ids = atom_atoms.getChids()
|
||||
atom_icodes = atom_atoms.getIcodes()
|
||||
|
||||
atom_coords_ = np.zeros([len(CA_dict), 3], np.float32)
|
||||
atom_coords_m = np.zeros([len(CA_dict)], np.int32)
|
||||
if atom_atoms != None:
|
||||
for i in range(len(atom_resnums)):
|
||||
code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i]
|
||||
if code in list(CA_dict):
|
||||
atom_coords_[CA_dict[code], :] = atom_coords[i]
|
||||
atom_coords_m[CA_dict[code]] = 1
|
||||
return atom_coords_, atom_coords_m
|
||||
|
||||
|
||||
def parse_PDB(
|
||||
input_path: str,
|
||||
device: str = "cpu",
|
||||
chains: list = [],
|
||||
parse_all_atoms: bool = False,
|
||||
parse_atoms_with_zero_occupancy: bool = False
|
||||
):
|
||||
"""
|
||||
input_path : path for the input PDB
|
||||
device: device for the torch.Tensor
|
||||
chains: a list specifying which chains need to be parsed; e.g. ["A", "B"]
|
||||
parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms
|
||||
parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed
|
||||
"""
|
||||
element_list = [
|
||||
"H",
|
||||
"He",
|
||||
"Li",
|
||||
"Be",
|
||||
"B",
|
||||
"C",
|
||||
"N",
|
||||
"O",
|
||||
"F",
|
||||
"Ne",
|
||||
"Na",
|
||||
"Mg",
|
||||
"Al",
|
||||
"Si",
|
||||
"P",
|
||||
"S",
|
||||
"Cl",
|
||||
"Ar",
|
||||
"K",
|
||||
"Ca",
|
||||
"Sc",
|
||||
"Ti",
|
||||
"V",
|
||||
"Cr",
|
||||
"Mn",
|
||||
"Fe",
|
||||
"Co",
|
||||
"Ni",
|
||||
"Cu",
|
||||
"Zn",
|
||||
"Ga",
|
||||
"Ge",
|
||||
"As",
|
||||
"Se",
|
||||
"Br",
|
||||
"Kr",
|
||||
"Rb",
|
||||
"Sr",
|
||||
"Y",
|
||||
"Zr",
|
||||
"Nb",
|
||||
"Mb",
|
||||
"Tc",
|
||||
"Ru",
|
||||
"Rh",
|
||||
"Pd",
|
||||
"Ag",
|
||||
"Cd",
|
||||
"In",
|
||||
"Sn",
|
||||
"Sb",
|
||||
"Te",
|
||||
"I",
|
||||
"Xe",
|
||||
"Cs",
|
||||
"Ba",
|
||||
"La",
|
||||
"Ce",
|
||||
"Pr",
|
||||
"Nd",
|
||||
"Pm",
|
||||
"Sm",
|
||||
"Eu",
|
||||
"Gd",
|
||||
"Tb",
|
||||
"Dy",
|
||||
"Ho",
|
||||
"Er",
|
||||
"Tm",
|
||||
"Yb",
|
||||
"Lu",
|
||||
"Hf",
|
||||
"Ta",
|
||||
"W",
|
||||
"Re",
|
||||
"Os",
|
||||
"Ir",
|
||||
"Pt",
|
||||
"Au",
|
||||
"Hg",
|
||||
"Tl",
|
||||
"Pb",
|
||||
"Bi",
|
||||
"Po",
|
||||
"At",
|
||||
"Rn",
|
||||
"Fr",
|
||||
"Ra",
|
||||
"Ac",
|
||||
"Th",
|
||||
"Pa",
|
||||
"U",
|
||||
"Np",
|
||||
"Pu",
|
||||
"Am",
|
||||
"Cm",
|
||||
"Bk",
|
||||
"Cf",
|
||||
"Es",
|
||||
"Fm",
|
||||
"Md",
|
||||
"No",
|
||||
"Lr",
|
||||
"Rf",
|
||||
"Db",
|
||||
"Sg",
|
||||
"Bh",
|
||||
"Hs",
|
||||
"Mt",
|
||||
"Ds",
|
||||
"Rg",
|
||||
"Cn",
|
||||
"Uut",
|
||||
"Fl",
|
||||
"Uup",
|
||||
"Lv",
|
||||
"Uus",
|
||||
"Uuo",
|
||||
]
|
||||
element_list = [item.upper() for item in element_list]
|
||||
element_dict = dict(zip(element_list, range(1, len(element_list))))
|
||||
restype_3to1 = {
|
||||
"ALA": "A",
|
||||
"ARG": "R",
|
||||
"ASN": "N",
|
||||
"ASP": "D",
|
||||
"CYS": "C",
|
||||
"GLN": "Q",
|
||||
"GLU": "E",
|
||||
"GLY": "G",
|
||||
"HIS": "H",
|
||||
"ILE": "I",
|
||||
"LEU": "L",
|
||||
"LYS": "K",
|
||||
"MET": "M",
|
||||
"PHE": "F",
|
||||
"PRO": "P",
|
||||
"SER": "S",
|
||||
"THR": "T",
|
||||
"TRP": "W",
|
||||
"TYR": "Y",
|
||||
"VAL": "V",
|
||||
}
|
||||
restype_STRtoINT = {
|
||||
"A": 0,
|
||||
"C": 1,
|
||||
"D": 2,
|
||||
"E": 3,
|
||||
"F": 4,
|
||||
"G": 5,
|
||||
"H": 6,
|
||||
"I": 7,
|
||||
"K": 8,
|
||||
"L": 9,
|
||||
"M": 10,
|
||||
"N": 11,
|
||||
"P": 12,
|
||||
"Q": 13,
|
||||
"R": 14,
|
||||
"S": 15,
|
||||
"T": 16,
|
||||
"V": 17,
|
||||
"W": 18,
|
||||
"Y": 19,
|
||||
"X": 20,
|
||||
}
|
||||
|
||||
atom_order = {
|
||||
"N": 0,
|
||||
"CA": 1,
|
||||
"C": 2,
|
||||
"CB": 3,
|
||||
"O": 4,
|
||||
"CG": 5,
|
||||
"CG1": 6,
|
||||
"CG2": 7,
|
||||
"OG": 8,
|
||||
"OG1": 9,
|
||||
"SG": 10,
|
||||
"CD": 11,
|
||||
"CD1": 12,
|
||||
"CD2": 13,
|
||||
"ND1": 14,
|
||||
"ND2": 15,
|
||||
"OD1": 16,
|
||||
"OD2": 17,
|
||||
"SD": 18,
|
||||
"CE": 19,
|
||||
"CE1": 20,
|
||||
"CE2": 21,
|
||||
"CE3": 22,
|
||||
"NE": 23,
|
||||
"NE1": 24,
|
||||
"NE2": 25,
|
||||
"OE1": 26,
|
||||
"OE2": 27,
|
||||
"CH2": 28,
|
||||
"NH1": 29,
|
||||
"NH2": 30,
|
||||
"OH": 31,
|
||||
"CZ": 32,
|
||||
"CZ2": 33,
|
||||
"CZ3": 34,
|
||||
"NZ": 35,
|
||||
"OXT": 36,
|
||||
}
|
||||
|
||||
if not parse_all_atoms:
|
||||
atom_types = ["N", "CA", "C", "O"]
|
||||
else:
|
||||
atom_types = [
|
||||
"N",
|
||||
"CA",
|
||||
"C",
|
||||
"CB",
|
||||
"O",
|
||||
"CG",
|
||||
"CG1",
|
||||
"CG2",
|
||||
"OG",
|
||||
"OG1",
|
||||
"SG",
|
||||
"CD",
|
||||
"CD1",
|
||||
"CD2",
|
||||
"ND1",
|
||||
"ND2",
|
||||
"OD1",
|
||||
"OD2",
|
||||
"SD",
|
||||
"CE",
|
||||
"CE1",
|
||||
"CE2",
|
||||
"CE3",
|
||||
"NE",
|
||||
"NE1",
|
||||
"NE2",
|
||||
"OE1",
|
||||
"OE2",
|
||||
"CH2",
|
||||
"NH1",
|
||||
"NH2",
|
||||
"OH",
|
||||
"CZ",
|
||||
"CZ2",
|
||||
"CZ3",
|
||||
"NZ",
|
||||
]
|
||||
|
||||
atoms = parsePDB(input_path)
|
||||
if not parse_atoms_with_zero_occupancy:
|
||||
atoms = atoms.select("occupancy > 0")
|
||||
if chains:
|
||||
str_out = ""
|
||||
for item in chains:
|
||||
str_out += " chain " + item + " or"
|
||||
atoms = atoms.select(str_out[1:-3])
|
||||
|
||||
protein_atoms = atoms.select("protein")
|
||||
backbone = protein_atoms.select("backbone")
|
||||
other_atoms = atoms.select("not protein and not water")
|
||||
water_atoms = atoms.select("water")
|
||||
|
||||
CA_atoms = protein_atoms.select("name CA")
|
||||
CA_resnums = CA_atoms.getResnums()
|
||||
CA_chain_ids = CA_atoms.getChids()
|
||||
CA_icodes = CA_atoms.getIcodes()
|
||||
|
||||
CA_dict = {}
|
||||
for i in range(len(CA_resnums)):
|
||||
code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i]
|
||||
CA_dict[code] = i
|
||||
|
||||
xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
|
||||
xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
|
||||
for atom_name in atom_types:
|
||||
xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
|
||||
xyz_37[:, atom_order[atom_name], :] = xyz
|
||||
xyz_37_m[:, atom_order[atom_name]] = xyz_m
|
||||
|
||||
N = xyz_37[:, atom_order["N"], :]
|
||||
CA = xyz_37[:, atom_order["CA"], :]
|
||||
C = xyz_37[:, atom_order["C"], :]
|
||||
O = xyz_37[:, atom_order["O"], :]
|
||||
|
||||
N_m = xyz_37_m[:, atom_order["N"]]
|
||||
CA_m = xyz_37_m[:, atom_order["CA"]]
|
||||
C_m = xyz_37_m[:, atom_order["C"]]
|
||||
O_m = xyz_37_m[:, atom_order["O"]]
|
||||
|
||||
mask = N_m * CA_m * C_m * O_m # must all 4 atoms exist
|
||||
|
||||
b = CA - N
|
||||
c = C - CA
|
||||
a = np.cross(b, c, axis=-1)
|
||||
CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
|
||||
|
||||
chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
|
||||
R_idx = np.array(CA_resnums, dtype=np.int32)
|
||||
S = CA_atoms.getResnames()
|
||||
S = [restype_3to1[AA] if AA in list(restype_3to1) else "X" for AA in list(S)]
|
||||
S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32)
|
||||
X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1)
|
||||
|
||||
try:
|
||||
Y = np.array(other_atoms.getCoords(), dtype=np.float32)
|
||||
Y_t = list(other_atoms.getElements())
|
||||
Y_t = np.array(
|
||||
[
|
||||
element_dict[y_t.upper()] if y_t.upper() in element_list else 0
|
||||
for y_t in Y_t
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
Y_m = (Y_t != 1) * (Y_t != 0)
|
||||
|
||||
Y = Y[Y_m, :]
|
||||
Y_t = Y_t[Y_m]
|
||||
Y_m = Y_m[Y_m]
|
||||
except:
|
||||
Y = np.zeros([1, 3], np.float32)
|
||||
Y_t = np.zeros([1], np.int32)
|
||||
Y_m = np.zeros([1], np.int32)
|
||||
|
||||
output_dict = {}
|
||||
output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
|
||||
output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
|
||||
output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
|
||||
output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
|
||||
output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)
|
||||
|
||||
output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
|
||||
output_dict["chain_labels"] = torch.tensor(
|
||||
chain_labels, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
output_dict["chain_letters"] = CA_chain_ids
|
||||
|
||||
mask_c = []
|
||||
chain_list = list(set(output_dict["chain_letters"]))
|
||||
chain_list.sort()
|
||||
for chain in chain_list:
|
||||
mask_c.append(
|
||||
torch.tensor(
|
||||
[chain == item for item in output_dict["chain_letters"]],
|
||||
device=device,
|
||||
dtype=bool,
|
||||
)
|
||||
)
|
||||
|
||||
output_dict["mask_c"] = mask_c
|
||||
output_dict["chain_list"] = chain_list
|
||||
|
||||
output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)
|
||||
|
||||
output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
|
||||
output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)
|
||||
|
||||
return output_dict, backbone, other_atoms, CA_icodes, water_atoms
|
||||
|
||||
|
||||
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
|
||||
device = CB.device
|
||||
mask_CBY = mask[:, None] * Y_m[None, :] # [A,B]
|
||||
L2_AB = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1)
|
||||
L2_AB = L2_AB * mask_CBY + (1 - mask_CBY) * 1000.0
|
||||
|
||||
nn_idx = torch.argsort(L2_AB, -1)[:, :number_of_ligand_atoms]
|
||||
L2_AB_nn = torch.gather(L2_AB, 1, nn_idx)
|
||||
D_AB_closest = torch.sqrt(L2_AB_nn[:, 0])
|
||||
|
||||
Y_r = Y[None, :, :].repeat(CB.shape[0], 1, 1)
|
||||
Y_t_r = Y_t[None, :].repeat(CB.shape[0], 1)
|
||||
Y_m_r = Y_m[None, :].repeat(CB.shape[0], 1)
|
||||
|
||||
Y_tmp = torch.gather(Y_r, 1, nn_idx[:, :, None].repeat(1, 1, 3))
|
||||
Y_t_tmp = torch.gather(Y_t_r, 1, nn_idx)
|
||||
Y_m_tmp = torch.gather(Y_m_r, 1, nn_idx)
|
||||
|
||||
Y = torch.zeros(
|
||||
[CB.shape[0], number_of_ligand_atoms, 3], dtype=torch.float32, device=device
|
||||
)
|
||||
Y_t = torch.zeros(
|
||||
[CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device
|
||||
)
|
||||
Y_m = torch.zeros(
|
||||
[CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
num_nn_update = Y_tmp.shape[1]
|
||||
Y[:, :num_nn_update] = Y_tmp
|
||||
Y_t[:, :num_nn_update] = Y_t_tmp
|
||||
Y_m[:, :num_nn_update] = Y_m_tmp
|
||||
|
||||
return Y, Y_t, Y_m, D_AB_closest
|
||||
|
||||
|
||||
def featurize(
|
||||
input_dict,
|
||||
cutoff_for_score=8.0,
|
||||
use_atom_context=True,
|
||||
number_of_ligand_atoms=16,
|
||||
model_type="protein_mpnn",
|
||||
):
|
||||
output_dict = {}
|
||||
if model_type == "ligand_mpnn":
|
||||
mask = input_dict["mask"]
|
||||
Y = input_dict["Y"]
|
||||
Y_t = input_dict["Y_t"]
|
||||
Y_m = input_dict["Y_m"]
|
||||
N = input_dict["X"][:, 0, :]
|
||||
CA = input_dict["X"][:, 1, :]
|
||||
C = input_dict["X"][:, 2, :]
|
||||
b = CA - N
|
||||
c = C - CA
|
||||
a = torch.cross(b, c, axis=-1)
|
||||
CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
|
||||
Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
|
||||
CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms
|
||||
)
|
||||
mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0]
|
||||
output_dict["mask_XY"] = mask_XY[None,]
|
||||
if "side_chain_mask" in list(input_dict):
|
||||
output_dict["side_chain_mask"] = input_dict["side_chain_mask"][None,]
|
||||
output_dict["Y"] = Y[None,]
|
||||
output_dict["Y_t"] = Y_t[None,]
|
||||
output_dict["Y_m"] = Y_m[None,]
|
||||
if not use_atom_context:
|
||||
output_dict["Y_m"] = 0.0 * output_dict["Y_m"]
|
||||
elif (
|
||||
model_type == "per_residue_label_membrane_mpnn"
|
||||
or model_type == "global_label_membrane_mpnn"
|
||||
):
|
||||
output_dict["membrane_per_residue_labels"] = input_dict[
|
||||
"membrane_per_residue_labels"
|
||||
][None,]
|
||||
|
||||
R_idx_list = []
|
||||
count = 0
|
||||
R_idx_prev = -100000
|
||||
for R_idx in list(input_dict["R_idx"]):
|
||||
if R_idx_prev == R_idx:
|
||||
count += 1
|
||||
R_idx_list.append(R_idx + count)
|
||||
R_idx_prev = R_idx
|
||||
R_idx_renumbered = torch.tensor(R_idx_list, device=R_idx.device)
|
||||
output_dict["R_idx"] = R_idx_renumbered[None,]
|
||||
output_dict["R_idx_original"] = input_dict["R_idx"][None,]
|
||||
output_dict["chain_labels"] = input_dict["chain_labels"][None,]
|
||||
output_dict["S"] = input_dict["S"][None,]
|
||||
output_dict["chain_mask"] = input_dict["chain_mask"][None,]
|
||||
output_dict["mask"] = input_dict["mask"][None,]
|
||||
|
||||
output_dict["X"] = input_dict["X"][None,]
|
||||
|
||||
if "xyz_37" in list(input_dict):
|
||||
output_dict["xyz_37"] = input_dict["xyz_37"][None,]
|
||||
output_dict["xyz_37_m"] = input_dict["xyz_37_m"][None,]
|
||||
|
||||
return output_dict
|
||||
47
get_model_params.sh
Normal file
47
get_model_params.sh
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
|
||||
#make new directory for model parameters
|
||||
#e.g. bash get_model_params.sh "./model_params"
|
||||
|
||||
mkdir -p $1
|
||||
|
||||
#Original ProteinMPNN weights
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_002.pt -O $1"/proteinmpnn_v_48_002.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_010.pt -O $1"/proteinmpnn_v_48_010.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt -O $1"/proteinmpnn_v_48_020.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_030.pt -O $1"/proteinmpnn_v_48_030.pt"
|
||||
|
||||
#ProteinMPNN with num_edges=32
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_002.pt -O $1"/proteinmpnn_v_32_002.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_010.pt -O $1"/proteinmpnn_v_32_010.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_020.pt -O $1"/proteinmpnn_v_32_020.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_32_030.pt -O $1"/proteinmpnn_v_32_030.pt"
|
||||
|
||||
#LigandMPNN with num_edges=32; atom_context_num=25
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_25.pt -O $1"/ligandmpnn_v_32_005_25.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt -O $1"/ligandmpnn_v_32_010_25.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_25.pt -O $1"/ligandmpnn_v_32_020_25.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt -O $1"/ligandmpnn_v_32_030_25.pt"
|
||||
|
||||
#LigandMPNN with num_edges=32; atom_context_num=16
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_005_16.pt -O $1"/ligandmpnn_v_32_005_16.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_16.pt -O $1"/ligandmpnn_v_32_010_16.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_020_16.pt -O $1"/ligandmpnn_v_32_020_16.pt"
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_16.pt -O $1"/ligandmpnn_v_32_030_16.pt"
|
||||
|
||||
# wget -q https://files.ipd.uw.edu/pub/ligandmpnn/publication_version_ligandmpnn_v_32_010_25.pt -O $1"/publication_version_ligandmpnn_v_32_010_25.pt"
|
||||
|
||||
#Per residue label membrane ProteinMPNN
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/per_residue_label_membrane_mpnn_v_48_020.pt -O $1"/per_residue_label_membrane_mpnn_v_48_020.pt"
|
||||
|
||||
#Global label membrane ProteinMPNN
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/global_label_membrane_mpnn_v_48_020.pt -O $1"/global_label_membrane_mpnn_v_48_020.pt"
|
||||
|
||||
#SolubleMPNN
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_002.pt -O $1"/solublempnn_v_48_002.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_010.pt -O $1"/solublempnn_v_48_010.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt -O $1"/solublempnn_v_48_020.pt"
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_030.pt -O $1"/solublempnn_v_48_030.pt"
|
||||
|
||||
#LigandMPNN for side-chain packing (multi-step denoising model)
|
||||
wget -q https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_sc_v_32_002_16.pt -O $1"/ligandmpnn_sc_v_32_002_16.pt"
|
||||
67
main.nf
Normal file
67
main.nf
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env nextflow
|
||||
|
||||
nextflow.enable.dsl=2
|
||||
|
||||
params.pdb = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb'
|
||||
params.outdir = '/mnt/OmicNAS/private/old/olamide/ligandmpnn/output'
|
||||
params.model_type = 'ligand_mpnn'
|
||||
params.temperature = 0.1
|
||||
params.seed = 111
|
||||
params.batch_size = 1
|
||||
params.number_of_batches = 1
|
||||
params.chains_to_design = ''
|
||||
params.fixed_residues = ''
|
||||
params.pack_side_chains = 0
|
||||
|
||||
process LIGANDMPNN {
|
||||
container 'ligandmpnn:latest'
|
||||
containerOptions '--rm --gpus all -v /mnt:/mnt'
|
||||
publishDir params.outdir, mode: 'copy'
|
||||
stageInMode 'copy'
|
||||
|
||||
input:
|
||||
path pdb
|
||||
|
||||
output:
|
||||
path "${pdb.simpleName}/seqs/*.fa"
|
||||
path "${pdb.simpleName}/backbones/*.pdb"
|
||||
path "${pdb.simpleName}/packed/*.pdb", optional: true
|
||||
path "run.log"
|
||||
|
||||
script:
|
||||
// Set checkpoint path based on model type
|
||||
def checkpoint_arg = ''
|
||||
if (params.model_type == 'ligand_mpnn') {
|
||||
checkpoint_arg = '--checkpoint_ligand_mpnn /app/LigandMPNN/model_params/ligandmpnn_v_32_010_25.pt'
|
||||
} else if (params.model_type == 'protein_mpnn') {
|
||||
checkpoint_arg = '--checkpoint_protein_mpnn /app/LigandMPNN/model_params/proteinmpnn_v_48_020.pt'
|
||||
} else if (params.model_type == 'soluble_mpnn') {
|
||||
checkpoint_arg = '--checkpoint_soluble_mpnn /app/LigandMPNN/model_params/solublempnn_v_48_020.pt'
|
||||
} else if (params.model_type == 'global_label_membrane_mpnn') {
|
||||
checkpoint_arg = '--checkpoint_global_label_membrane_mpnn /app/LigandMPNN/model_params/global_label_membrane_mpnn_v_48_020.pt'
|
||||
} else if (params.model_type == 'per_residue_label_membrane_mpnn') {
|
||||
checkpoint_arg = '--checkpoint_per_residue_label_membrane_mpnn /app/LigandMPNN/model_params/per_residue_label_membrane_mpnn_v_48_020.pt'
|
||||
}
|
||||
|
||||
"""
|
||||
mkdir -p ${pdb.simpleName}/seqs ${pdb.simpleName}/backbones ${pdb.simpleName}/packed
|
||||
|
||||
ligandmpnn \\
|
||||
--pdb_path \$PWD/${pdb} \\
|
||||
--out_folder \$PWD/${pdb.simpleName} \\
|
||||
--model_type ${params.model_type} \\
|
||||
${checkpoint_arg} \\
|
||||
--temperature ${params.temperature} \\
|
||||
--seed ${params.seed} \\
|
||||
--batch_size ${params.batch_size} \\
|
||||
--number_of_batches ${params.number_of_batches} \\
|
||||
--pack_side_chains ${params.pack_side_chains} \\
|
||||
${params.chains_to_design ? "--chains_to_design ${params.chains_to_design}" : ''} \\
|
||||
${params.fixed_residues ? "--fixed_residues \"${params.fixed_residues}\"" : ''} \\
|
||||
2>&1 | tee run.log
|
||||
"""
|
||||
}
|
||||
|
||||
workflow {
|
||||
LIGANDMPNN(Channel.fromPath(params.pdb))
|
||||
}
|
||||
1772
model_utils.py
Normal file
1772
model_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
42
nextflow.config
Normal file
42
nextflow.config
Normal file
@@ -0,0 +1,42 @@
|
||||
// Manifest for Nextflow metadata
|
||||
manifest {
|
||||
name = 'LigandMPNN-Nextflow'
|
||||
author = 'Generated from LigandMPNN repository'
|
||||
homePage = 'https://github.com/dauparas/LigandMPNN'
|
||||
description = 'Nextflow pipeline for LigandMPNN - Protein sequence design with ligand context'
|
||||
mainScript = 'main.nf'
|
||||
version = '1.0.0'
|
||||
}
|
||||
|
||||
// Global default parameters
|
||||
params {
|
||||
pdb = "/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb"
|
||||
outdir = "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output"
|
||||
model_type = "ligand_mpnn"
|
||||
temperature = 0.1
|
||||
seed = 111
|
||||
batch_size = 1
|
||||
number_of_batches = 1
|
||||
chains_to_design = ""
|
||||
fixed_residues = ""
|
||||
}
|
||||
|
||||
// Container configurations
|
||||
docker {
|
||||
enabled = true
|
||||
runOptions = '--gpus all'
|
||||
}
|
||||
|
||||
// Process configurations
|
||||
process {
|
||||
cpus = 4
|
||||
memory = '16 GB'
|
||||
}
|
||||
|
||||
// Execution configurations
|
||||
executor {
|
||||
$local {
|
||||
cpus = 8
|
||||
memory = '32 GB'
|
||||
}
|
||||
}
|
||||
131
params.json
Normal file
131
params.json
Normal file
@@ -0,0 +1,131 @@
|
||||
{
|
||||
"params": {
|
||||
"pdb": {
|
||||
"type": "file",
|
||||
"description": "Path to input PDB file for protein sequence design",
|
||||
"default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/input/1BC8.pdb",
|
||||
"required": true,
|
||||
"pipeline_io": "input",
|
||||
"var_name": "params.pdb",
|
||||
"examples": [
|
||||
"/mnt/workflow/input/protein.pdb",
|
||||
"/mnt/workflow/input/*.pdb"
|
||||
],
|
||||
"pattern": ".*\\.pdb$",
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Input PDB file containing the protein structure for sequence design."
|
||||
},
|
||||
"outdir": {
|
||||
"type": "folder",
|
||||
"description": "Directory for LigandMPNN output results",
|
||||
"default": "/mnt/OmicNAS/private/old/olamide/ligandmpnn/output",
|
||||
"required": true,
|
||||
"pipeline_io": "output",
|
||||
"var_name": "params.outdir",
|
||||
"examples": [
|
||||
"/mnt/workflow/output",
|
||||
"/path/to/results"
|
||||
],
|
||||
"pattern": ".*",
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Directory where designed sequences and backbone PDBs will be saved."
|
||||
},
|
||||
"model_type": {
|
||||
"type": "string",
|
||||
"description": "Type of MPNN model to use for sequence design",
|
||||
"default": "ligand_mpnn",
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.model_type",
|
||||
"examples": [
|
||||
"protein_mpnn",
|
||||
"ligand_mpnn",
|
||||
"soluble_mpnn"
|
||||
],
|
||||
"pattern": "^(protein_mpnn|ligand_mpnn|soluble_mpnn|global_label_membrane_mpnn|per_residue_label_membrane_mpnn)$",
|
||||
"enum": ["protein_mpnn", "ligand_mpnn", "soluble_mpnn", "global_label_membrane_mpnn", "per_residue_label_membrane_mpnn"],
|
||||
"validation": {},
|
||||
"notes": "protein_mpnn: Original ProteinMPNN. ligand_mpnn: Context-aware with ligands. soluble_mpnn: Trained on soluble proteins."
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "Sampling temperature for sequence generation",
|
||||
"default": 0.1,
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.temperature",
|
||||
"examples": [0.05, 0.1, 0.2],
|
||||
"pattern": null,
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Higher temperature gives more sequence diversity. Recommended range: 0.05-0.5"
|
||||
},
|
||||
"seed": {
|
||||
"type": "integer",
|
||||
"description": "Random seed for reproducibility",
|
||||
"default": 111,
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.seed",
|
||||
"examples": [111, 42, 12345],
|
||||
"pattern": null,
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Set for reproducible results."
|
||||
},
|
||||
"batch_size": {
|
||||
"type": "integer",
|
||||
"description": "Number of sequences to generate per batch",
|
||||
"default": 1,
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.batch_size",
|
||||
"examples": [1, 3, 5],
|
||||
"pattern": null,
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Higher batch sizes require more GPU memory."
|
||||
},
|
||||
"number_of_batches": {
|
||||
"type": "integer",
|
||||
"description": "Number of batches to run",
|
||||
"default": 1,
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.number_of_batches",
|
||||
"examples": [1, 5, 10],
|
||||
"pattern": null,
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Total sequences = batch_size × number_of_batches"
|
||||
},
|
||||
"chains_to_design": {
|
||||
"type": "string",
|
||||
"description": "Comma-separated chain IDs to redesign",
|
||||
"default": "",
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.chains_to_design",
|
||||
"examples": ["A", "A,B", "A,B,C"],
|
||||
"pattern": "^([A-Z],?)*$",
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Leave empty to design all chains."
|
||||
},
|
||||
"fixed_residues": {
|
||||
"type": "string",
|
||||
"description": "Space-separated list of residues to keep fixed",
|
||||
"default": "",
|
||||
"required": false,
|
||||
"pipeline_io": "parameter",
|
||||
"var_name": "params.fixed_residues",
|
||||
"examples": ["A1 A2 A3", "A12 B25 B26"],
|
||||
"pattern": null,
|
||||
"enum": [],
|
||||
"validation": {},
|
||||
"notes": "Format: ChainResidue (e.g., A12). Leave empty to design all residues."
|
||||
}
|
||||
}
|
||||
}
|
||||
29
requirements.txt
Normal file
29
requirements.txt
Normal file
@@ -0,0 +1,29 @@
|
||||
biopython==1.79
|
||||
filelock==3.13.1
|
||||
fsspec==2024.3.1
|
||||
Jinja2==3.1.3
|
||||
MarkupSafe==2.1.5
|
||||
mpmath==1.3.0
|
||||
networkx==3.2.1
|
||||
numpy==1.23.5
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.19.3
|
||||
nvidia-nvjitlink-cu12==12.4.99
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
ProDy==2.4.1
|
||||
pyparsing==3.1.1
|
||||
scipy==1.12.0
|
||||
sympy==1.12
|
||||
torch==2.2.1
|
||||
triton==2.2.0
|
||||
typing_extensions==4.10.0
|
||||
ml-collections==0.1.1
|
||||
dm-tree==0.1.8
|
||||
990
run.py
Normal file
990
run.py
Normal file
@@ -0,0 +1,990 @@
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from data_utils import (
|
||||
alphabet,
|
||||
element_dict_rev,
|
||||
featurize,
|
||||
get_score,
|
||||
get_seq_rec,
|
||||
parse_PDB,
|
||||
restype_1to3,
|
||||
restype_int_to_str,
|
||||
restype_str_to_int,
|
||||
write_full_PDB,
|
||||
)
|
||||
from model_utils import ProteinMPNN
|
||||
from prody import writePDB
|
||||
from sc_utils import Packer, pack_side_chains
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
"""
|
||||
Inference function
|
||||
"""
|
||||
if args.seed:
|
||||
seed = args.seed
|
||||
else:
|
||||
seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
|
||||
folder_for_outputs = args.out_folder
|
||||
base_folder = folder_for_outputs
|
||||
if base_folder[-1] != "/":
|
||||
base_folder = base_folder + "/"
|
||||
if not os.path.exists(base_folder):
|
||||
os.makedirs(base_folder, exist_ok=True)
|
||||
if not os.path.exists(base_folder + "seqs"):
|
||||
os.makedirs(base_folder + "seqs", exist_ok=True)
|
||||
if not os.path.exists(base_folder + "backbones"):
|
||||
os.makedirs(base_folder + "backbones", exist_ok=True)
|
||||
if not os.path.exists(base_folder + "packed"):
|
||||
os.makedirs(base_folder + "packed", exist_ok=True)
|
||||
if args.save_stats:
|
||||
if not os.path.exists(base_folder + "stats"):
|
||||
os.makedirs(base_folder + "stats", exist_ok=True)
|
||||
if args.model_type == "protein_mpnn":
|
||||
checkpoint_path = args.checkpoint_protein_mpnn
|
||||
elif args.model_type == "ligand_mpnn":
|
||||
checkpoint_path = args.checkpoint_ligand_mpnn
|
||||
elif args.model_type == "per_residue_label_membrane_mpnn":
|
||||
checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
|
||||
elif args.model_type == "global_label_membrane_mpnn":
|
||||
checkpoint_path = args.checkpoint_global_label_membrane_mpnn
|
||||
elif args.model_type == "soluble_mpnn":
|
||||
checkpoint_path = args.checkpoint_soluble_mpnn
|
||||
else:
|
||||
print("Choose one of the available models")
|
||||
sys.exit()
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
if args.model_type == "ligand_mpnn":
|
||||
atom_context_num = checkpoint["atom_context_num"]
|
||||
ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
|
||||
k_neighbors = checkpoint["num_edges"]
|
||||
else:
|
||||
atom_context_num = 1
|
||||
ligand_mpnn_use_side_chain_context = 0
|
||||
k_neighbors = checkpoint["num_edges"]
|
||||
|
||||
model = ProteinMPNN(
|
||||
node_features=128,
|
||||
edge_features=128,
|
||||
hidden_dim=128,
|
||||
num_encoder_layers=3,
|
||||
num_decoder_layers=3,
|
||||
k_neighbors=k_neighbors,
|
||||
device=device,
|
||||
atom_context_num=atom_context_num,
|
||||
model_type=args.model_type,
|
||||
ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
|
||||
)
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.pack_side_chains:
|
||||
model_sc = Packer(
|
||||
node_features=128,
|
||||
edge_features=128,
|
||||
num_positional_embeddings=16,
|
||||
num_chain_embeddings=16,
|
||||
num_rbf=16,
|
||||
hidden_dim=128,
|
||||
num_encoder_layers=3,
|
||||
num_decoder_layers=3,
|
||||
atom_context_num=16,
|
||||
lower_bound=0.0,
|
||||
upper_bound=20.0,
|
||||
top_k=32,
|
||||
dropout=0.0,
|
||||
augment_eps=0.0,
|
||||
atom37_order=False,
|
||||
device=device,
|
||||
num_mix=3,
|
||||
)
|
||||
|
||||
checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device)
|
||||
model_sc.load_state_dict(checkpoint_sc["model_state_dict"])
|
||||
model_sc.to(device)
|
||||
model_sc.eval()
|
||||
|
||||
if args.pdb_path_multi:
|
||||
with open(args.pdb_path_multi, "r") as fh:
|
||||
pdb_paths = list(json.load(fh))
|
||||
else:
|
||||
pdb_paths = [args.pdb_path]
|
||||
|
||||
if args.fixed_residues_multi:
|
||||
with open(args.fixed_residues_multi, "r") as fh:
|
||||
fixed_residues_multi = json.load(fh)
|
||||
fixed_residues_multi = {key:value.split() for key,value in fixed_residues_multi.items()}
|
||||
else:
|
||||
fixed_residues = [item for item in args.fixed_residues.split()]
|
||||
fixed_residues_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
fixed_residues_multi[pdb] = fixed_residues
|
||||
|
||||
if args.redesigned_residues_multi:
|
||||
with open(args.redesigned_residues_multi, "r") as fh:
|
||||
redesigned_residues_multi = json.load(fh)
|
||||
redesigned_residues_multi = {key:value.split() for key,value in redesigned_residues_multi.items()}
|
||||
else:
|
||||
redesigned_residues = [item for item in args.redesigned_residues.split()]
|
||||
redesigned_residues_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
redesigned_residues_multi[pdb] = redesigned_residues
|
||||
|
||||
bias_AA = torch.zeros([21], device=device, dtype=torch.float32)
|
||||
if args.bias_AA:
|
||||
tmp = [item.split(":") for item in args.bias_AA.split(",")]
|
||||
a1 = [b[0] for b in tmp]
|
||||
a2 = [float(b[1]) for b in tmp]
|
||||
for i, AA in enumerate(a1):
|
||||
bias_AA[restype_str_to_int[AA]] = a2[i]
|
||||
|
||||
if args.bias_AA_per_residue_multi:
|
||||
with open(args.bias_AA_per_residue_multi, "r") as fh:
|
||||
bias_AA_per_residue_multi = json.load(
|
||||
fh
|
||||
) # {"pdb_path" : {"A12": {"G": 1.1}}}
|
||||
else:
|
||||
if args.bias_AA_per_residue:
|
||||
with open(args.bias_AA_per_residue, "r") as fh:
|
||||
bias_AA_per_residue = json.load(fh) # {"A12": {"G": 1.1}}
|
||||
bias_AA_per_residue_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
bias_AA_per_residue_multi[pdb] = bias_AA_per_residue
|
||||
|
||||
if args.omit_AA_per_residue_multi:
|
||||
with open(args.omit_AA_per_residue_multi, "r") as fh:
|
||||
omit_AA_per_residue_multi = json.load(
|
||||
fh
|
||||
) # {"pdb_path" : {"A12": "PQR", "A13": "QS"}}
|
||||
else:
|
||||
if args.omit_AA_per_residue:
|
||||
with open(args.omit_AA_per_residue, "r") as fh:
|
||||
omit_AA_per_residue = json.load(fh) # {"A12": "PG"}
|
||||
omit_AA_per_residue_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
omit_AA_per_residue_multi[pdb] = omit_AA_per_residue
|
||||
omit_AA_list = args.omit_AA
|
||||
omit_AA = torch.tensor(
|
||||
np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32),
|
||||
device=device,
|
||||
)
|
||||
|
||||
if len(args.parse_these_chains_only) != 0:
|
||||
parse_these_chains_only_list = args.parse_these_chains_only.split(",")
|
||||
else:
|
||||
parse_these_chains_only_list = []
|
||||
|
||||
|
||||
# loop over PDB paths
|
||||
for pdb in pdb_paths:
|
||||
if args.verbose:
|
||||
print("Designing protein from this path:", pdb)
|
||||
fixed_residues = fixed_residues_multi[pdb]
|
||||
redesigned_residues = redesigned_residues_multi[pdb]
|
||||
parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or (
|
||||
args.pack_side_chains and not args.repack_everything
|
||||
)
|
||||
protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
|
||||
pdb,
|
||||
device=device,
|
||||
chains=parse_these_chains_only_list,
|
||||
parse_all_atoms=parse_all_atoms_flag,
|
||||
parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
|
||||
)
|
||||
# make chain_letter + residue_idx + insertion_code mapping to integers
|
||||
R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices
|
||||
chain_letters_list = list(protein_dict["chain_letters"]) # chain letters
|
||||
encoded_residues = []
|
||||
for i, R_idx_item in enumerate(R_idx_list):
|
||||
tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
|
||||
encoded_residues.append(tmp)
|
||||
encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
|
||||
encoded_residue_dict_rev = dict(
|
||||
zip(list(range(len(encoded_residues))), encoded_residues)
|
||||
)
|
||||
|
||||
bias_AA_per_residue = torch.zeros(
|
||||
[len(encoded_residues), 21], device=device, dtype=torch.float32
|
||||
)
|
||||
if args.bias_AA_per_residue_multi or args.bias_AA_per_residue:
|
||||
bias_dict = bias_AA_per_residue_multi[pdb]
|
||||
for residue_name, v1 in bias_dict.items():
|
||||
if residue_name in encoded_residues:
|
||||
i1 = encoded_residue_dict[residue_name]
|
||||
for amino_acid, v2 in v1.items():
|
||||
if amino_acid in alphabet:
|
||||
j1 = restype_str_to_int[amino_acid]
|
||||
bias_AA_per_residue[i1, j1] = v2
|
||||
|
||||
omit_AA_per_residue = torch.zeros(
|
||||
[len(encoded_residues), 21], device=device, dtype=torch.float32
|
||||
)
|
||||
if args.omit_AA_per_residue_multi or args.omit_AA_per_residue:
|
||||
omit_dict = omit_AA_per_residue_multi[pdb]
|
||||
for residue_name, v1 in omit_dict.items():
|
||||
if residue_name in encoded_residues:
|
||||
i1 = encoded_residue_dict[residue_name]
|
||||
for amino_acid in v1:
|
||||
if amino_acid in alphabet:
|
||||
j1 = restype_str_to_int[amino_acid]
|
||||
omit_AA_per_residue[i1, j1] = 1.0
|
||||
|
||||
fixed_positions = torch.tensor(
|
||||
[int(item not in fixed_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
redesigned_positions = torch.tensor(
|
||||
[int(item not in redesigned_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
|
||||
# specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
|
||||
if args.transmembrane_buried:
|
||||
buried_residues = [item for item in args.transmembrane_buried.split()]
|
||||
buried_positions = torch.tensor(
|
||||
[int(item in buried_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
buried_positions = torch.zeros_like(fixed_positions)
|
||||
|
||||
if args.transmembrane_interface:
|
||||
interface_residues = [item for item in args.transmembrane_interface.split()]
|
||||
interface_positions = torch.tensor(
|
||||
[int(item in interface_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
interface_positions = torch.zeros_like(fixed_positions)
|
||||
protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
|
||||
1 - interface_positions
|
||||
) + 1 * interface_positions * (1 - buried_positions)
|
||||
|
||||
if args.model_type == "global_label_membrane_mpnn":
|
||||
protein_dict["membrane_per_residue_labels"] = (
|
||||
args.global_transmembrane_label + 0 * fixed_positions
|
||||
)
|
||||
if len(args.chains_to_design) != 0:
|
||||
chains_to_design_list = args.chains_to_design.split(",")
|
||||
else:
|
||||
chains_to_design_list = protein_dict["chain_letters"]
|
||||
|
||||
chain_mask = torch.tensor(
|
||||
np.array(
|
||||
[
|
||||
item in chains_to_design_list
|
||||
for item in protein_dict["chain_letters"]
|
||||
],
|
||||
dtype=np.int32,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
# create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
|
||||
if redesigned_residues:
|
||||
protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
|
||||
elif fixed_residues:
|
||||
protein_dict["chain_mask"] = chain_mask * fixed_positions
|
||||
else:
|
||||
protein_dict["chain_mask"] = chain_mask
|
||||
|
||||
if args.verbose:
|
||||
PDB_residues_to_be_redesigned = [
|
||||
encoded_residue_dict_rev[item]
|
||||
for item in range(protein_dict["chain_mask"].shape[0])
|
||||
if protein_dict["chain_mask"][item] == 1
|
||||
]
|
||||
PDB_residues_to_be_fixed = [
|
||||
encoded_residue_dict_rev[item]
|
||||
for item in range(protein_dict["chain_mask"].shape[0])
|
||||
if protein_dict["chain_mask"][item] == 0
|
||||
]
|
||||
print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
|
||||
print("These residues will be fixed: ", PDB_residues_to_be_fixed)
|
||||
|
||||
# specify which residues are linked
|
||||
if args.symmetry_residues:
|
||||
symmetry_residues_list_of_lists = [
|
||||
x.split(",") for x in args.symmetry_residues.split("|")
|
||||
]
|
||||
remapped_symmetry_residues = []
|
||||
for t_list in symmetry_residues_list_of_lists:
|
||||
tmp_list = []
|
||||
for t in t_list:
|
||||
tmp_list.append(encoded_residue_dict[t])
|
||||
remapped_symmetry_residues.append(tmp_list)
|
||||
else:
|
||||
remapped_symmetry_residues = [[]]
|
||||
|
||||
# specify linking weights
|
||||
if args.symmetry_weights:
|
||||
symmetry_weights = [
|
||||
[float(item) for item in x.split(",")]
|
||||
for x in args.symmetry_weights.split("|")
|
||||
]
|
||||
else:
|
||||
symmetry_weights = [[]]
|
||||
|
||||
if args.homo_oligomer:
|
||||
if args.verbose:
|
||||
print("Designing HOMO-OLIGOMER")
|
||||
chain_letters_set = list(set(chain_letters_list))
|
||||
reference_chain = chain_letters_set[0]
|
||||
lc = len(reference_chain)
|
||||
residue_indices = [
|
||||
item[lc:] for item in encoded_residues if item[:lc] == reference_chain
|
||||
]
|
||||
remapped_symmetry_residues = []
|
||||
symmetry_weights = []
|
||||
for res in residue_indices:
|
||||
tmp_list = []
|
||||
tmp_w_list = []
|
||||
for chain in chain_letters_set:
|
||||
name = chain + res
|
||||
tmp_list.append(encoded_residue_dict[name])
|
||||
tmp_w_list.append(1 / len(chain_letters_set))
|
||||
remapped_symmetry_residues.append(tmp_list)
|
||||
symmetry_weights.append(tmp_w_list)
|
||||
|
||||
# set other atom bfactors to 0.0
|
||||
if other_atoms:
|
||||
other_bfactors = other_atoms.getBetas()
|
||||
other_atoms.setBetas(other_bfactors * 0.0)
|
||||
|
||||
# adjust input PDB name by dropping .pdb if it does exist
|
||||
name = pdb[pdb.rfind("/") + 1 :]
|
||||
if name[-4:] == ".pdb":
|
||||
name = name[:-4]
|
||||
|
||||
with torch.no_grad():
|
||||
# run featurize to remap R_idx and add batch dimension
|
||||
if args.verbose:
|
||||
if "Y" in list(protein_dict):
|
||||
atom_coords = protein_dict["Y"].cpu().numpy()
|
||||
atom_types = list(protein_dict["Y_t"].cpu().numpy())
|
||||
atom_mask = list(protein_dict["Y_m"].cpu().numpy())
|
||||
number_of_atoms_parsed = np.sum(atom_mask)
|
||||
else:
|
||||
print("No ligand atoms parsed")
|
||||
number_of_atoms_parsed = 0
|
||||
atom_types = ""
|
||||
atom_coords = []
|
||||
if number_of_atoms_parsed == 0:
|
||||
print("No ligand atoms parsed")
|
||||
elif args.model_type == "ligand_mpnn":
|
||||
print(
|
||||
f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
|
||||
)
|
||||
for i, atom_type in enumerate(atom_types):
|
||||
print(
|
||||
f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
|
||||
)
|
||||
feature_dict = featurize(
|
||||
protein_dict,
|
||||
cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
|
||||
use_atom_context=args.ligand_mpnn_use_atom_context,
|
||||
number_of_ligand_atoms=atom_context_num,
|
||||
model_type=args.model_type,
|
||||
)
|
||||
feature_dict["batch_size"] = args.batch_size
|
||||
B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now.
|
||||
# add additional keys to the feature dictionary
|
||||
feature_dict["temperature"] = args.temperature
|
||||
feature_dict["bias"] = (
|
||||
(-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1])
|
||||
+ bias_AA_per_residue[None]
|
||||
- 1e8 * omit_AA_per_residue[None]
|
||||
)
|
||||
feature_dict["symmetry_residues"] = remapped_symmetry_residues
|
||||
feature_dict["symmetry_weights"] = symmetry_weights
|
||||
|
||||
sampling_probs_list = []
|
||||
log_probs_list = []
|
||||
decoding_order_list = []
|
||||
S_list = []
|
||||
loss_list = []
|
||||
loss_per_residue_list = []
|
||||
loss_XY_list = []
|
||||
for _ in range(args.number_of_batches):
|
||||
feature_dict["randn"] = torch.randn(
|
||||
[feature_dict["batch_size"], feature_dict["mask"].shape[1]],
|
||||
device=device,
|
||||
)
|
||||
output_dict = model.sample(feature_dict)
|
||||
|
||||
# compute confidence scores
|
||||
loss, loss_per_residue = get_score(
|
||||
output_dict["S"],
|
||||
output_dict["log_probs"],
|
||||
feature_dict["mask"] * feature_dict["chain_mask"],
|
||||
)
|
||||
if args.model_type == "ligand_mpnn":
|
||||
combined_mask = (
|
||||
feature_dict["mask"]
|
||||
* feature_dict["mask_XY"]
|
||||
* feature_dict["chain_mask"]
|
||||
)
|
||||
else:
|
||||
combined_mask = feature_dict["mask"] * feature_dict["chain_mask"]
|
||||
loss_XY, _ = get_score(
|
||||
output_dict["S"], output_dict["log_probs"], combined_mask
|
||||
)
|
||||
# -----
|
||||
S_list.append(output_dict["S"])
|
||||
log_probs_list.append(output_dict["log_probs"])
|
||||
sampling_probs_list.append(output_dict["sampling_probs"])
|
||||
decoding_order_list.append(output_dict["decoding_order"])
|
||||
loss_list.append(loss)
|
||||
loss_per_residue_list.append(loss_per_residue)
|
||||
loss_XY_list.append(loss_XY)
|
||||
S_stack = torch.cat(S_list, 0)
|
||||
log_probs_stack = torch.cat(log_probs_list, 0)
|
||||
sampling_probs_stack = torch.cat(sampling_probs_list, 0)
|
||||
decoding_order_stack = torch.cat(decoding_order_list, 0)
|
||||
loss_stack = torch.cat(loss_list, 0)
|
||||
loss_per_residue_stack = torch.cat(loss_per_residue_list, 0)
|
||||
loss_XY_stack = torch.cat(loss_XY_list, 0)
|
||||
rec_mask = feature_dict["mask"][:1] * feature_dict["chain_mask"][:1]
|
||||
rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask)
|
||||
|
||||
native_seq = "".join(
|
||||
[restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()]
|
||||
)
|
||||
seq_np = np.array(list(native_seq))
|
||||
seq_out_str = []
|
||||
for mask in protein_dict["mask_c"]:
|
||||
seq_out_str += list(seq_np[mask.cpu().numpy()])
|
||||
seq_out_str += [args.fasta_seq_separation]
|
||||
seq_out_str = "".join(seq_out_str)[:-1]
|
||||
|
||||
output_fasta = base_folder + "/seqs/" + name + args.file_ending + ".fa"
|
||||
output_backbones = base_folder + "/backbones/"
|
||||
output_packed = base_folder + "/packed/"
|
||||
output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt"
|
||||
|
||||
out_dict = {}
|
||||
out_dict["generated_sequences"] = S_stack.cpu()
|
||||
out_dict["sampling_probs"] = sampling_probs_stack.cpu()
|
||||
out_dict["log_probs"] = log_probs_stack.cpu()
|
||||
out_dict["decoding_order"] = decoding_order_stack.cpu()
|
||||
out_dict["native_sequence"] = feature_dict["S"][0].cpu()
|
||||
out_dict["mask"] = feature_dict["mask"][0].cpu()
|
||||
out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
|
||||
out_dict["seed"] = seed
|
||||
out_dict["temperature"] = args.temperature
|
||||
if args.save_stats:
|
||||
torch.save(out_dict, output_stats_path)
|
||||
|
||||
if args.pack_side_chains:
|
||||
if args.verbose:
|
||||
print("Packing side chains...")
|
||||
feature_dict_ = featurize(
|
||||
protein_dict,
|
||||
cutoff_for_score=8.0,
|
||||
use_atom_context=args.pack_with_ligand_context,
|
||||
number_of_ligand_atoms=16,
|
||||
model_type="ligand_mpnn",
|
||||
)
|
||||
sc_feature_dict = copy.deepcopy(feature_dict_)
|
||||
B = args.batch_size
|
||||
for k, v in sc_feature_dict.items():
|
||||
if k != "S":
|
||||
try:
|
||||
num_dim = len(v.shape)
|
||||
if num_dim == 2:
|
||||
sc_feature_dict[k] = v.repeat(B, 1)
|
||||
elif num_dim == 3:
|
||||
sc_feature_dict[k] = v.repeat(B, 1, 1)
|
||||
elif num_dim == 4:
|
||||
sc_feature_dict[k] = v.repeat(B, 1, 1, 1)
|
||||
elif num_dim == 5:
|
||||
sc_feature_dict[k] = v.repeat(B, 1, 1, 1, 1)
|
||||
except:
|
||||
pass
|
||||
X_stack_list = []
|
||||
X_m_stack_list = []
|
||||
b_factor_stack_list = []
|
||||
for _ in range(args.number_of_packs_per_design):
|
||||
X_list = []
|
||||
X_m_list = []
|
||||
b_factor_list = []
|
||||
for c in range(args.number_of_batches):
|
||||
sc_feature_dict["S"] = S_list[c]
|
||||
sc_dict = pack_side_chains(
|
||||
sc_feature_dict,
|
||||
model_sc,
|
||||
args.sc_num_denoising_steps,
|
||||
args.sc_num_samples,
|
||||
args.repack_everything,
|
||||
)
|
||||
X_list.append(sc_dict["X"])
|
||||
X_m_list.append(sc_dict["X_m"])
|
||||
b_factor_list.append(sc_dict["b_factors"])
|
||||
|
||||
X_stack = torch.cat(X_list, 0)
|
||||
X_m_stack = torch.cat(X_m_list, 0)
|
||||
b_factor_stack = torch.cat(b_factor_list, 0)
|
||||
|
||||
X_stack_list.append(X_stack)
|
||||
X_m_stack_list.append(X_m_stack)
|
||||
b_factor_stack_list.append(b_factor_stack)
|
||||
|
||||
with open(output_fasta, "w") as f:
|
||||
f.write(
|
||||
">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format(
|
||||
name,
|
||||
args.temperature,
|
||||
seed,
|
||||
torch.sum(rec_mask).cpu().numpy(),
|
||||
torch.sum(combined_mask[:1]).cpu().numpy(),
|
||||
bool(args.ligand_mpnn_use_atom_context),
|
||||
float(args.ligand_mpnn_cutoff_for_score),
|
||||
args.batch_size,
|
||||
args.number_of_batches,
|
||||
checkpoint_path,
|
||||
seq_out_str,
|
||||
)
|
||||
)
|
||||
for ix in range(S_stack.shape[0]):
|
||||
ix_suffix = ix
|
||||
if not args.zero_indexed:
|
||||
ix_suffix += 1
|
||||
seq_rec_print = np.format_float_positional(
|
||||
rec_stack[ix].cpu().numpy(), unique=False, precision=4
|
||||
)
|
||||
loss_np = np.format_float_positional(
|
||||
np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4
|
||||
)
|
||||
loss_XY_np = np.format_float_positional(
|
||||
np.exp(-loss_XY_stack[ix].cpu().numpy()),
|
||||
unique=False,
|
||||
precision=4,
|
||||
)
|
||||
seq = "".join(
|
||||
[restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()]
|
||||
)
|
||||
|
||||
# write new sequences into PDB with backbone coordinates
|
||||
seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[
|
||||
None,
|
||||
].repeat(4, 1)
|
||||
bfactor_prody = (
|
||||
loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1)
|
||||
)
|
||||
backbone.setResnames(seq_prody)
|
||||
backbone.setBetas(
|
||||
np.exp(-bfactor_prody)
|
||||
* (bfactor_prody > 0.01).astype(np.float32)
|
||||
)
|
||||
if other_atoms:
|
||||
writePDB(
|
||||
output_backbones
|
||||
+ name
|
||||
+ "_"
|
||||
+ str(ix_suffix)
|
||||
+ args.file_ending
|
||||
+ ".pdb",
|
||||
backbone + other_atoms,
|
||||
)
|
||||
else:
|
||||
writePDB(
|
||||
output_backbones
|
||||
+ name
|
||||
+ "_"
|
||||
+ str(ix_suffix)
|
||||
+ args.file_ending
|
||||
+ ".pdb",
|
||||
backbone,
|
||||
)
|
||||
|
||||
# write full PDB files
|
||||
if args.pack_side_chains:
|
||||
for c_pack in range(args.number_of_packs_per_design):
|
||||
X_stack = X_stack_list[c_pack]
|
||||
X_m_stack = X_m_stack_list[c_pack]
|
||||
b_factor_stack = b_factor_stack_list[c_pack]
|
||||
write_full_PDB(
|
||||
output_packed
|
||||
+ name
|
||||
+ args.packed_suffix
|
||||
+ "_"
|
||||
+ str(ix_suffix)
|
||||
+ "_"
|
||||
+ str(c_pack + 1)
|
||||
+ args.file_ending
|
||||
+ ".pdb",
|
||||
X_stack[ix].cpu().numpy(),
|
||||
X_m_stack[ix].cpu().numpy(),
|
||||
b_factor_stack[ix].cpu().numpy(),
|
||||
feature_dict["R_idx_original"][0].cpu().numpy(),
|
||||
protein_dict["chain_letters"],
|
||||
S_stack[ix].cpu().numpy(),
|
||||
other_atoms=other_atoms,
|
||||
icodes=icodes,
|
||||
force_hetatm=args.force_hetatm,
|
||||
)
|
||||
# -----
|
||||
|
||||
# write fasta lines
|
||||
seq_np = np.array(list(seq))
|
||||
seq_out_str = []
|
||||
for mask in protein_dict["mask_c"]:
|
||||
seq_out_str += list(seq_np[mask.cpu().numpy()])
|
||||
seq_out_str += [args.fasta_seq_separation]
|
||||
seq_out_str = "".join(seq_out_str)[:-1]
|
||||
if ix == S_stack.shape[0] - 1:
|
||||
# final 2 lines
|
||||
f.write(
|
||||
">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format(
|
||||
name,
|
||||
ix_suffix,
|
||||
args.temperature,
|
||||
seed,
|
||||
loss_np,
|
||||
loss_XY_np,
|
||||
seq_rec_print,
|
||||
seq_out_str,
|
||||
)
|
||||
)
|
||||
else:
|
||||
f.write(
|
||||
">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}\n".format(
|
||||
name,
|
||||
ix_suffix,
|
||||
args.temperature,
|
||||
seed,
|
||||
loss_np,
|
||||
loss_XY_np,
|
||||
seq_rec_print,
|
||||
seq_out_str,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default="protein_mpnn",
|
||||
help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
|
||||
)
|
||||
# protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
|
||||
# ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
|
||||
# per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
|
||||
# global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
|
||||
# soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
|
||||
argparser.add_argument(
|
||||
"--checkpoint_protein_mpnn",
|
||||
type=str,
|
||||
default="./model_params/proteinmpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_ligand_mpnn",
|
||||
type=str,
|
||||
default="./model_params/ligandmpnn_v_32_010_25.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_per_residue_label_membrane_mpnn",
|
||||
type=str,
|
||||
default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_global_label_membrane_mpnn",
|
||||
type=str,
|
||||
default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_soluble_mpnn",
|
||||
type=str,
|
||||
default="./model_params/solublempnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--fasta_seq_separation",
|
||||
type=str,
|
||||
default=":",
|
||||
help="Symbol to use between sequences from different chains",
|
||||
)
|
||||
argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
|
||||
|
||||
argparser.add_argument(
|
||||
"--pdb_path", type=str, default="", help="Path to the input PDB."
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--pdb_path_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--fixed_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide fixed residues, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--fixed_residues_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--redesigned_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--redesigned_residues_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--bias_AA",
|
||||
type=str,
|
||||
default="",
|
||||
help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--bias_AA_per_residue",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--bias_AA_per_residue_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--omit_AA",
|
||||
type=str,
|
||||
default="",
|
||||
help="Bias generation of amino acids, e.g. 'ACG'",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--omit_AA_per_residue",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of bias {'A12': 'APQ', 'A13': 'QST'}",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--omit_AA_per_residue_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of bias {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--symmetry_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--symmetry_weights",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--homo_oligomer",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--out_folder",
|
||||
type=str,
|
||||
help="Path to a folder to output sequences, e.g. /home/out/",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--file_ending", type=str, default="", help="adding_string_to_the_end"
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--zero_indexed",
|
||||
type=str,
|
||||
default=0,
|
||||
help="1 - to start output PDB numbering with 0",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set seed for torch, numpy, and python random.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of sequence to generate per one pass.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--number_of_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of times to design sequence using a chosen batch size.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Temperature to sample sequences.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--save_stats", type=int, default=0, help="Save output statistics"
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_use_atom_context",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 - use atom context, 0 - do not use atom context.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_cutoff_for_score",
|
||||
type=float,
|
||||
default=8.0,
|
||||
help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_use_side_chain_context",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Flag to use side chain atoms as ligand context for the fixed residues",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--chains_to_design",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--parse_these_chains_only",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide chains letters for parsing backbones, 'A,B,C,F'",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--transmembrane_buried",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--transmembrane_interface",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--global_transmembrane_label",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--parse_atoms_with_zero_occupancy",
|
||||
type=int,
|
||||
default=0,
|
||||
help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--pack_side_chains",
|
||||
type=int,
|
||||
default=0,
|
||||
help="1 - to run side chain packer, 0 - do not run it",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--checkpoint_path_sc",
|
||||
type=str,
|
||||
default="./model_params/ligandmpnn_sc_v_32_002_16.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--number_of_packs_per_design",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of independent side chain packing samples to return per design",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--sc_num_denoising_steps",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of denoising/recycling steps to make for side chain packing",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--sc_num_samples",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--repack_everything",
|
||||
type=int,
|
||||
default=0,
|
||||
help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--force_hetatm",
|
||||
type=int,
|
||||
default=0,
|
||||
help="To force ligand atoms to be written as HETATM to PDB file after packing.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--packed_suffix",
|
||||
type=str,
|
||||
default="_packed",
|
||||
help="Suffix for packed PDB paths",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--pack_with_ligand_context",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1-pack side chains using ligand context, 0 - do not use it.",
|
||||
)
|
||||
|
||||
args = argparser.parse_args()
|
||||
main(args)
|
||||
244
run_examples.sh
Normal file
244
run_examples.sh
Normal file
@@ -0,0 +1,244 @@
|
||||
#!/bin/bash
|
||||
|
||||
#1
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/default"
|
||||
#2
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--temperature 0.05 \
|
||||
--out_folder "./outputs/temperature"
|
||||
|
||||
#3
|
||||
python run.py \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/random_seed"
|
||||
|
||||
#4
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--verbose 0 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/verbose"
|
||||
|
||||
#5
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/save_stats" \
|
||||
--save_stats 1
|
||||
|
||||
#6
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/fix_residues" \
|
||||
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
||||
--bias_AA "A:10.0"
|
||||
|
||||
#7
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/redesign_residues" \
|
||||
--redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
||||
--bias_AA "A:10.0"
|
||||
|
||||
#8
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/batch_size" \
|
||||
--batch_size 3 \
|
||||
--number_of_batches 5
|
||||
|
||||
#9
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
|
||||
--out_folder "./outputs/global_bias"
|
||||
|
||||
#10
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
|
||||
--out_folder "./outputs/per_residue_bias"
|
||||
|
||||
#11
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--omit_AA "CDFGHILMNPQRSTVWY" \
|
||||
--out_folder "./outputs/global_omit"
|
||||
|
||||
#12
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
|
||||
--out_folder "./outputs/per_residue_omit"
|
||||
|
||||
#13
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/symmetry" \
|
||||
--symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
|
||||
--symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"
|
||||
|
||||
#14
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/homooligomer" \
|
||||
--homo_oligomer 1 \
|
||||
--number_of_batches 2
|
||||
|
||||
#15
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/file_ending" \
|
||||
--file_ending "_xyz"
|
||||
|
||||
#16
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/zero_indexed" \
|
||||
--zero_indexed 1 \
|
||||
--number_of_batches 2
|
||||
|
||||
#17
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/chains_to_design" \
|
||||
--chains_to_design "A,B"
|
||||
|
||||
#18
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/4GYT.pdb" \
|
||||
--out_folder "./outputs/parse_these_chains_only" \
|
||||
--parse_these_chains_only "A,B"
|
||||
|
||||
#19
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_default"
|
||||
|
||||
#20
|
||||
python run.py \
|
||||
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_v_32_005_25"
|
||||
|
||||
#21
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_no_context" \
|
||||
--ligand_mpnn_use_atom_context 0
|
||||
|
||||
#22
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
|
||||
--ligand_mpnn_use_side_chain_context 1 \
|
||||
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"
|
||||
|
||||
#23
|
||||
python run.py \
|
||||
--model_type "soluble_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/soluble_mpnn_default"
|
||||
|
||||
#24
|
||||
python run.py \
|
||||
--model_type "global_label_membrane_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/global_label_membrane_mpnn_0" \
|
||||
--global_transmembrane_label 0
|
||||
|
||||
#25
|
||||
python run.py \
|
||||
--model_type "per_residue_label_membrane_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
|
||||
--transmembrane_buried "C1 C2 C3 C11" \
|
||||
--transmembrane_interface "C4 C5 C6 C22"
|
||||
|
||||
#26
|
||||
python run.py \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/fasta_seq_separation" \
|
||||
--fasta_seq_separation ":"
|
||||
|
||||
#27
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--out_folder "./outputs/pdb_path_multi" \
|
||||
--seed 111
|
||||
|
||||
#28
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--fixed_residues_multi "./inputs/fix_residues_multi.json" \
|
||||
--out_folder "./outputs/fixed_residues_multi" \
|
||||
--seed 111
|
||||
|
||||
#29
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
|
||||
--out_folder "./outputs/redesigned_residues_multi" \
|
||||
--seed 111
|
||||
|
||||
#30
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
|
||||
--out_folder "./outputs/omit_AA_per_residue_multi" \
|
||||
--seed 111
|
||||
|
||||
#31
|
||||
python run.py \
|
||||
--pdb_path_multi "./inputs/pdb_ids.json" \
|
||||
--bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
|
||||
--out_folder "./outputs/bias_AA_per_residue_multi" \
|
||||
--seed 111
|
||||
|
||||
#32
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--ligand_mpnn_cutoff_for_score "6.0" \
|
||||
--out_folder "./outputs/ligand_mpnn_cutoff_for_score"
|
||||
|
||||
#33
|
||||
python run.py \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/2GFB.pdb" \
|
||||
--out_folder "./outputs/insertion_code" \
|
||||
--redesigned_residues "B82 B82A B82B B82C" \
|
||||
--parse_these_chains_only "B"
|
||||
55
sc_examples.sh
Normal file
55
sc_examples.sh
Normal file
@@ -0,0 +1,55 @@
|
||||
#1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_default_fast" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 0 \
|
||||
--pack_with_ligand_context 1
|
||||
|
||||
#2 design a new sequence and pack side chains (return 4 side chain packing samples)
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_default" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1
|
||||
|
||||
|
||||
#3 fix specific residues for design and packing
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_fixed_residues" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1 \
|
||||
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
||||
--repack_everything 0
|
||||
|
||||
#4 fix specific residues for sequence design but repack everything
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_fixed_residues_full_repack" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 1 \
|
||||
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
||||
--repack_everything 1
|
||||
|
||||
|
||||
#5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
|
||||
python run.py \
|
||||
--model_type "ligand_mpnn" \
|
||||
--seed 111 \
|
||||
--pdb_path "./inputs/1BC8.pdb" \
|
||||
--out_folder "./outputs/sc_no_context" \
|
||||
--pack_side_chains 1 \
|
||||
--number_of_packs_per_design 4 \
|
||||
--pack_with_ligand_context 0
|
||||
1158
sc_utils.py
Normal file
1158
sc_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
549
score.py
Normal file
549
score.py
Normal file
@@ -0,0 +1,549 @@
|
||||
import argparse
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from data_utils import (
|
||||
element_dict_rev,
|
||||
alphabet,
|
||||
restype_int_to_str,
|
||||
featurize,
|
||||
parse_PDB,
|
||||
)
|
||||
from model_utils import ProteinMPNN
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
"""
|
||||
Inference function
|
||||
"""
|
||||
if args.seed:
|
||||
seed = args.seed
|
||||
else:
|
||||
seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
|
||||
folder_for_outputs = args.out_folder
|
||||
base_folder = folder_for_outputs
|
||||
if base_folder[-1] != "/":
|
||||
base_folder = base_folder + "/"
|
||||
if not os.path.exists(base_folder):
|
||||
os.makedirs(base_folder, exist_ok=True)
|
||||
if args.model_type == "protein_mpnn":
|
||||
checkpoint_path = args.checkpoint_protein_mpnn
|
||||
elif args.model_type == "ligand_mpnn":
|
||||
checkpoint_path = args.checkpoint_ligand_mpnn
|
||||
elif args.model_type == "per_residue_label_membrane_mpnn":
|
||||
checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
|
||||
elif args.model_type == "global_label_membrane_mpnn":
|
||||
checkpoint_path = args.checkpoint_global_label_membrane_mpnn
|
||||
elif args.model_type == "soluble_mpnn":
|
||||
checkpoint_path = args.checkpoint_soluble_mpnn
|
||||
else:
|
||||
print("Choose one of the available models")
|
||||
sys.exit()
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
if args.model_type == "ligand_mpnn":
|
||||
atom_context_num = checkpoint["atom_context_num"]
|
||||
ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
|
||||
k_neighbors = checkpoint["num_edges"]
|
||||
else:
|
||||
atom_context_num = 1
|
||||
ligand_mpnn_use_side_chain_context = 0
|
||||
k_neighbors = checkpoint["num_edges"]
|
||||
|
||||
model = ProteinMPNN(
|
||||
node_features=128,
|
||||
edge_features=128,
|
||||
hidden_dim=128,
|
||||
num_encoder_layers=3,
|
||||
num_decoder_layers=3,
|
||||
k_neighbors=k_neighbors,
|
||||
device=device,
|
||||
atom_context_num=atom_context_num,
|
||||
model_type=args.model_type,
|
||||
ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
|
||||
)
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.pdb_path_multi:
|
||||
with open(args.pdb_path_multi, "r") as fh:
|
||||
pdb_paths = list(json.load(fh))
|
||||
else:
|
||||
pdb_paths = [args.pdb_path]
|
||||
|
||||
if args.fixed_residues_multi:
|
||||
with open(args.fixed_residues_multi, "r") as fh:
|
||||
fixed_residues_multi = json.load(fh)
|
||||
else:
|
||||
fixed_residues = [item for item in args.fixed_residues.split()]
|
||||
fixed_residues_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
fixed_residues_multi[pdb] = fixed_residues
|
||||
|
||||
if args.redesigned_residues_multi:
|
||||
with open(args.redesigned_residues_multi, "r") as fh:
|
||||
redesigned_residues_multi = json.load(fh)
|
||||
else:
|
||||
redesigned_residues = [item for item in args.redesigned_residues.split()]
|
||||
redesigned_residues_multi = {}
|
||||
for pdb in pdb_paths:
|
||||
redesigned_residues_multi[pdb] = redesigned_residues
|
||||
|
||||
# loop over PDB paths
|
||||
for pdb in pdb_paths:
|
||||
if args.verbose:
|
||||
print("Designing protein from this path:", pdb)
|
||||
fixed_residues = fixed_residues_multi[pdb]
|
||||
redesigned_residues = redesigned_residues_multi[pdb]
|
||||
protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
|
||||
pdb,
|
||||
device=device,
|
||||
chains=args.parse_these_chains_only,
|
||||
parse_all_atoms=args.ligand_mpnn_use_side_chain_context,
|
||||
parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy
|
||||
)
|
||||
# make chain_letter + residue_idx + insertion_code mapping to integers
|
||||
R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices
|
||||
chain_letters_list = list(protein_dict["chain_letters"]) # chain letters
|
||||
encoded_residues = []
|
||||
for i, R_idx_item in enumerate(R_idx_list):
|
||||
tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
|
||||
encoded_residues.append(tmp)
|
||||
encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
|
||||
encoded_residue_dict_rev = dict(
|
||||
zip(list(range(len(encoded_residues))), encoded_residues)
|
||||
)
|
||||
|
||||
fixed_positions = torch.tensor(
|
||||
[int(item not in fixed_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
redesigned_positions = torch.tensor(
|
||||
[int(item not in redesigned_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
|
||||
# specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
|
||||
if args.transmembrane_buried:
|
||||
buried_residues = [item for item in args.transmembrane_buried.split()]
|
||||
buried_positions = torch.tensor(
|
||||
[int(item in buried_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
buried_positions = torch.zeros_like(fixed_positions)
|
||||
|
||||
if args.transmembrane_interface:
|
||||
interface_residues = [item for item in args.transmembrane_interface.split()]
|
||||
interface_positions = torch.tensor(
|
||||
[int(item in interface_residues) for item in encoded_residues],
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
interface_positions = torch.zeros_like(fixed_positions)
|
||||
protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
|
||||
1 - interface_positions
|
||||
) + 1 * interface_positions * (1 - buried_positions)
|
||||
|
||||
if args.model_type == "global_label_membrane_mpnn":
|
||||
protein_dict["membrane_per_residue_labels"] = (
|
||||
args.global_transmembrane_label + 0 * fixed_positions
|
||||
)
|
||||
if type(args.chains_to_design) == str:
|
||||
chains_to_design_list = args.chains_to_design.split(",")
|
||||
else:
|
||||
chains_to_design_list = protein_dict["chain_letters"]
|
||||
chain_mask = torch.tensor(
|
||||
np.array(
|
||||
[
|
||||
item in chains_to_design_list
|
||||
for item in protein_dict["chain_letters"]
|
||||
],
|
||||
dtype=np.int32,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
# create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
|
||||
if redesigned_residues:
|
||||
protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
|
||||
elif fixed_residues:
|
||||
protein_dict["chain_mask"] = chain_mask * fixed_positions
|
||||
else:
|
||||
protein_dict["chain_mask"] = chain_mask
|
||||
|
||||
if args.verbose:
|
||||
PDB_residues_to_be_redesigned = [
|
||||
encoded_residue_dict_rev[item]
|
||||
for item in range(protein_dict["chain_mask"].shape[0])
|
||||
if protein_dict["chain_mask"][item] == 1
|
||||
]
|
||||
PDB_residues_to_be_fixed = [
|
||||
encoded_residue_dict_rev[item]
|
||||
for item in range(protein_dict["chain_mask"].shape[0])
|
||||
if protein_dict["chain_mask"][item] == 0
|
||||
]
|
||||
print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
|
||||
print("These residues will be fixed: ", PDB_residues_to_be_fixed)
|
||||
|
||||
# specify which residues are linked
|
||||
if args.symmetry_residues:
|
||||
symmetry_residues_list_of_lists = [
|
||||
x.split(",") for x in args.symmetry_residues.split("|")
|
||||
]
|
||||
remapped_symmetry_residues = []
|
||||
for t_list in symmetry_residues_list_of_lists:
|
||||
tmp_list = []
|
||||
for t in t_list:
|
||||
tmp_list.append(encoded_residue_dict[t])
|
||||
remapped_symmetry_residues.append(tmp_list)
|
||||
else:
|
||||
remapped_symmetry_residues = [[]]
|
||||
|
||||
if args.homo_oligomer:
|
||||
if args.verbose:
|
||||
print("Designing HOMO-OLIGOMER")
|
||||
chain_letters_set = list(set(chain_letters_list))
|
||||
reference_chain = chain_letters_set[0]
|
||||
lc = len(reference_chain)
|
||||
residue_indices = [
|
||||
item[lc:] for item in encoded_residues if item[:lc] == reference_chain
|
||||
]
|
||||
remapped_symmetry_residues = []
|
||||
for res in residue_indices:
|
||||
tmp_list = []
|
||||
tmp_w_list = []
|
||||
for chain in chain_letters_set:
|
||||
name = chain + res
|
||||
tmp_list.append(encoded_residue_dict[name])
|
||||
tmp_w_list.append(1 / len(chain_letters_set))
|
||||
remapped_symmetry_residues.append(tmp_list)
|
||||
|
||||
# set other atom bfactors to 0.0
|
||||
if other_atoms:
|
||||
other_bfactors = other_atoms.getBetas()
|
||||
other_atoms.setBetas(other_bfactors * 0.0)
|
||||
|
||||
# adjust input PDB name by dropping .pdb if it does exist
|
||||
name = pdb[pdb.rfind("/") + 1 :]
|
||||
if name[-4:] == ".pdb":
|
||||
name = name[:-4]
|
||||
|
||||
with torch.no_grad():
|
||||
# run featurize to remap R_idx and add batch dimension
|
||||
if args.verbose:
|
||||
if "Y" in list(protein_dict):
|
||||
atom_coords = protein_dict["Y"].cpu().numpy()
|
||||
atom_types = list(protein_dict["Y_t"].cpu().numpy())
|
||||
atom_mask = list(protein_dict["Y_m"].cpu().numpy())
|
||||
number_of_atoms_parsed = np.sum(atom_mask)
|
||||
else:
|
||||
print("No ligand atoms parsed")
|
||||
number_of_atoms_parsed = 0
|
||||
atom_types = ""
|
||||
atom_coords = []
|
||||
if number_of_atoms_parsed == 0:
|
||||
print("No ligand atoms parsed")
|
||||
elif args.model_type == "ligand_mpnn":
|
||||
print(
|
||||
f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
|
||||
)
|
||||
for i, atom_type in enumerate(atom_types):
|
||||
print(
|
||||
f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
|
||||
)
|
||||
feature_dict = featurize(
|
||||
protein_dict,
|
||||
cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
|
||||
use_atom_context=args.ligand_mpnn_use_atom_context,
|
||||
number_of_ligand_atoms=atom_context_num,
|
||||
model_type=args.model_type,
|
||||
)
|
||||
feature_dict["batch_size"] = args.batch_size
|
||||
B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now.
|
||||
# add additional keys to the feature dictionary
|
||||
feature_dict["symmetry_residues"] = remapped_symmetry_residues
|
||||
|
||||
logits_list = []
|
||||
probs_list = []
|
||||
log_probs_list = []
|
||||
decoding_order_list = []
|
||||
for _ in range(args.number_of_batches):
|
||||
feature_dict["randn"] = torch.randn(
|
||||
[feature_dict["batch_size"], feature_dict["mask"].shape[1]],
|
||||
device=device,
|
||||
)
|
||||
if args.autoregressive_score:
|
||||
score_dict = model.score(feature_dict, use_sequence=args.use_sequence)
|
||||
elif args.single_aa_score:
|
||||
score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence)
|
||||
else:
|
||||
print("Set either autoregressive_score or single_aa_score to True")
|
||||
sys.exit()
|
||||
logits_list.append(score_dict["logits"])
|
||||
log_probs_list.append(score_dict["log_probs"])
|
||||
probs_list.append(torch.exp(score_dict["log_probs"]))
|
||||
decoding_order_list.append(score_dict["decoding_order"])
|
||||
log_probs_stack = torch.cat(log_probs_list, 0)
|
||||
logits_stack = torch.cat(logits_list, 0)
|
||||
probs_stack = torch.cat(probs_list, 0)
|
||||
decoding_order_stack = torch.cat(decoding_order_list, 0)
|
||||
|
||||
output_stats_path = base_folder + name + args.file_ending + ".pt"
|
||||
out_dict = {}
|
||||
out_dict["logits"] = logits_stack.cpu().numpy()
|
||||
out_dict["probs"] = probs_stack.cpu().numpy()
|
||||
out_dict["log_probs"] = log_probs_stack.cpu().numpy()
|
||||
out_dict["decoding_order"] = decoding_order_stack.cpu().numpy()
|
||||
out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy()
|
||||
out_dict["mask"] = feature_dict["mask"][0].cpu().numpy()
|
||||
out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy() #this affects decoding order
|
||||
out_dict["seed"] = seed
|
||||
out_dict["alphabet"] = alphabet
|
||||
out_dict["residue_names"] = encoded_residue_dict_rev
|
||||
|
||||
mean_probs = np.mean(out_dict["probs"], 0)
|
||||
std_probs = np.std(out_dict["probs"], 0)
|
||||
sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]]
|
||||
mean_dict = {}
|
||||
std_dict = {}
|
||||
for residue in range(L):
|
||||
mean_dict_ = dict(zip(alphabet, mean_probs[residue]))
|
||||
mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_
|
||||
std_dict_ = dict(zip(alphabet, std_probs[residue]))
|
||||
std_dict[encoded_residue_dict_rev[residue]] = std_dict_
|
||||
|
||||
out_dict["sequence"] = sequence
|
||||
out_dict["mean_of_probs"] = mean_dict
|
||||
out_dict["std_of_probs"] = std_dict
|
||||
torch.save(out_dict, output_stats_path)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default="protein_mpnn",
|
||||
help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
|
||||
)
|
||||
# protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
|
||||
# ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
|
||||
# per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
|
||||
# global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
|
||||
# soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
|
||||
argparser.add_argument(
|
||||
"--checkpoint_protein_mpnn",
|
||||
type=str,
|
||||
default="./model_params/proteinmpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_ligand_mpnn",
|
||||
type=str,
|
||||
default="./model_params/ligandmpnn_v_32_010_25.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_per_residue_label_membrane_mpnn",
|
||||
type=str,
|
||||
default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_global_label_membrane_mpnn",
|
||||
type=str,
|
||||
default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--checkpoint_soluble_mpnn",
|
||||
type=str,
|
||||
default="./model_params/solublempnn_v_48_020.pt",
|
||||
help="Path to model weights.",
|
||||
)
|
||||
|
||||
argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
|
||||
|
||||
argparser.add_argument(
|
||||
"--pdb_path", type=str, default="", help="Path to the input PDB."
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--pdb_path_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--fixed_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide fixed residues, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--fixed_residues_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--redesigned_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--redesigned_residues_multi",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--symmetry_residues",
|
||||
type=str,
|
||||
default="",
|
||||
help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--homo_oligomer",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--out_folder",
|
||||
type=str,
|
||||
help="Path to a folder to output scores, e.g. /home/out/",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--file_ending", type=str, default="", help="adding_string_to_the_end"
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--zero_indexed",
|
||||
type=str,
|
||||
default=0,
|
||||
help="1 - to start output PDB numbering with 0",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set seed for torch, numpy, and python random.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of sequence to generate per one pass.",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--number_of_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of times to design sequence using a chosen batch size.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_use_atom_context",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 - use atom context, 0 - do not use atom context.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_use_side_chain_context",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Flag to use side chain atoms as ligand context for the fixed residues",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--ligand_mpnn_cutoff_for_score",
|
||||
type=float,
|
||||
default=8.0,
|
||||
help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--chains_to_design",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify which chains to redesign, all others will be kept fixed.",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--parse_these_chains_only",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide chains letters for parsing backbones, 'ABCF'",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--transmembrane_buried",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
||||
)
|
||||
argparser.add_argument(
|
||||
"--transmembrane_interface",
|
||||
type=str,
|
||||
default="",
|
||||
help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--global_transmembrane_label",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--parse_atoms_with_zero_occupancy",
|
||||
type=int,
|
||||
default=0,
|
||||
help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--use_sequence",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--autoregressive_score",
|
||||
type=int,
|
||||
default=0,
|
||||
help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False",
|
||||
)
|
||||
|
||||
argparser.add_argument(
|
||||
"--single_aa_score",
|
||||
type=int,
|
||||
default=1,
|
||||
help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False",
|
||||
)
|
||||
|
||||
args = argparser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user