Initial commit: FlowDock pipeline configured for WES execution

Commit a3ffec6a07 (2026-03-16 15:23:29 +01:00)
116 changed files with 16139 additions and 0 deletions

.env.example (new file, 6 lines)
# Example file for storing private and user-specific environment variables, such as keys or system paths.
# Rename it to ".env" (excluded from version control by default).
# .env is loaded by train.py automatically.
# Hydra allows you to reference environment variables in .yaml configs with the special syntax ${oc.env:MY_VAR}.
PLINDER_MOUNT="$(pwd)/data/PLINDER"
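Projects built on the lightning-hydra template typically load this file with `python-dotenv` at the top of `train.py`. A minimal pure-Python sketch of what that loading does (illustrative only; the real loader also handles `export` prefixes and quoting edge cases, and this example uses a literal path rather than the shell-expanded `$(pwd)` value above):

```python
import os

def load_dotenv_minimal(text: str) -> dict:
    """Parse KEY=VALUE lines from a .env-style string, skipping comments and blanks."""
    env = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # strip surrounding quotes, as dotenv loaders do
        env[key.strip()] = value.strip().strip('"').strip("'")
    return env

example = '# comment line\nPLINDER_MOUNT="/data/PLINDER"\n'
parsed = load_dotenv_minimal(example)
os.environ.update(parsed)  # make the variables visible to the process
```

After this, `${oc.env:PLINDER_MOUNT}` in a Hydra config would resolve against the process environment.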

.github/PULL_REQUEST_TEMPLATE.md (new file, 22 lines)
## What does this PR do?

<!--
Please include a summary of the change and which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.
List all the breaking changes introduced by this pull request.
-->

Fixes #\<issue_number>

## Before submitting

- [ ] Did you make sure the **title is self-explanatory** and **the description concisely explains the PR**?
- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
- [ ] Did you list all the **breaking changes** introduced by this pull request?
- [ ] Did you **test your PR locally** with the `pytest` command?
- [ ] Did you **run pre-commit hooks** with the `pre-commit run -a` command?

## Did you have fun?

Make sure you had fun coding 🙃

.github/codecov.yml (new file, 15 lines)
coverage:
  status:
    # measures overall project coverage
    project:
      default:
        threshold: 100% # maximum allowed coverage drop before the check fails
    # measures PR or single-commit coverage
    patch:
      default:
        threshold: 100% # maximum allowed coverage drop before the check fails
# project: off
# patch: off
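A `threshold` here is the number of percentage points that coverage may drop relative to the base commit before the status fails, so `100%` makes both checks effectively informational. A simplified model of that decision (not Codecov's actual code):

```python
def coverage_status(base_cov: float, head_cov: float, threshold: float) -> str:
    """Pass the check as long as coverage does not drop by more than
    `threshold` percentage points relative to the base commit."""
    drop = base_cov - head_cov
    return "success" if drop <= threshold else "failure"

# With threshold: 100%, even a huge drop still passes:
assert coverage_status(92.0, 5.0, 100.0) == "success"
# A stricter threshold of 1% would fail a 2-point drop:
assert coverage_status(92.0, 90.0, 1.0) == "failure"
```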

.github/dependabot.yml (new file, 16 lines)
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
  - package-ecosystem: "pip" # See documentation for possible values
    directory: "/" # Location of package manifests
    schedule:
      interval: "daily"
    ignore:
      - dependency-name: "pytorch-lightning"
        update-types: ["version-update:semver-patch"]
      - dependency-name: "torchmetrics"
        update-types: ["version-update:semver-patch"]
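The `ignore` rules above skip patch-level bumps for the two pinned ML libraries. Roughly, Dependabot classifies an update like this (simplified sketch for plain `MAJOR.MINOR.PATCH` versions; real version parsing also handles pre-releases and build metadata):

```python
def update_type(current: str, candidate: str) -> str:
    """Classify a version bump the way Dependabot's update-types do
    (simplified to plain MAJOR.MINOR.PATCH strings)."""
    cur = [int(x) for x in current.split(".")]
    new = [int(x) for x in candidate.split(".")]
    if new[0] != cur[0]:
        return "version-update:semver-major"
    if new[1] != cur[1]:
        return "version-update:semver-minor"
    return "version-update:semver-patch"

# The config above would suppress a PR like this for pytorch-lightning:
assert update_type("2.2.0", "2.2.1") == "version-update:semver-patch"
# ...but still open one for a minor bump:
assert update_type("2.2.0", "2.3.0") == "version-update:semver-minor"
```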

.github/release-drafter.yml (new file, 44 lines)
name-template: "v$RESOLVED_VERSION"
tag-template: "v$RESOLVED_VERSION"
categories:
  - title: "🚀 Features"
    labels:
      - "feature"
      - "enhancement"
  - title: "🐛 Bug Fixes"
    labels:
      - "fix"
      - "bugfix"
      - "bug"
  - title: "🧹 Maintenance"
    labels:
      - "maintenance"
      - "dependencies"
      - "refactoring"
      - "cosmetic"
      - "chore"
  - title: "📝️ Documentation"
    labels:
      - "documentation"
      - "docs"
change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
version-resolver:
  major:
    labels:
      - "major"
  minor:
    labels:
      - "minor"
  patch:
    labels:
      - "patch"
  default: patch
template: |
  ## Changes

  $CHANGES
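The `version-resolver` block above decides how the next tag is computed from PR labels, with `default: patch` applied when no version label is present. A minimal sketch of that resolution logic (illustrative only; release-drafter's actual implementation aggregates labels across all merged PRs and honors label priority the same way):

```python
def resolve_next_version(current: str, pr_labels: list) -> str:
    """Bump the version according to the highest-priority label seen,
    falling back to a patch bump (default: patch)."""
    major, minor, patch = (int(x) for x in current.lstrip("v").split("."))
    if "major" in pr_labels:
        return f"v{major + 1}.0.0"
    if "minor" in pr_labels:
        return f"v{major}.{minor + 1}.0"
    return f"v{major}.{minor}.{patch + 1}"  # "patch" label or no label

assert resolve_next_version("v1.4.2", ["bug", "minor"]) == "v1.5.0"
assert resolve_next_version("v1.4.2", []) == "v1.4.3"
```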

.github/workflows/… (new file, 22 lines)
# Same as `code-quality-pr.yaml` but triggered on commit to main branch
# and runs on all files (instead of only the changed ones)
name: Code Quality Main

on:
  push:
    branches: [main]

jobs:
  code-quality:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
      - name: Run pre-commits
        uses: pre-commit/action@v2.0.3

.github/workflows/code-quality-pr.yaml (new file, 36 lines)
# This workflow finds which files were changed, prints them,
# and runs `pre-commit` on those files.
# Inspired by the sktime library:
# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml
name: Code Quality PR

on:
  pull_request:
    branches: [main, "release/*", "dev"]

jobs:
  code-quality:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
      - name: Find modified files
        id: file_changes
        uses: trilom/file-changes-action@v1.2.4
        with:
          output: " "
      - name: List modified files
        run: echo '${{ steps.file_changes.outputs.files }}'
      - name: Run pre-commits
        uses: pre-commit/action@v2.0.3
        with:
          extra_args: --files ${{ steps.file_changes.outputs.files }}

.github/workflows/release-drafter.yml (new file, 27 lines)
name: Release Drafter

on:
  push:
    # branches to consider in the event; optional, defaults to all
    branches:
      - main

permissions:
  contents: read

jobs:
  update_release_draft:
    permissions:
      # write permission is required to create a GitHub release
      contents: write
      # write permission is required for the autolabeler;
      # otherwise, read permission is required at least
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      # Drafts your next release notes as pull requests are merged into "main"
      - uses: release-drafter/release-drafter@v5
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/test.yml (new file, 139 lines)
name: Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main, "release/*", "dev"]

jobs:
  run_tests_ubuntu:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-latest"]
        python-version: ["3.8", "3.9", "3.10"]
    timeout-minutes: 20
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          conda env create -f environment.yml
          python -m pip install --upgrade pip
          pip install pytest
          pip install sh
      - name: List dependencies
        run: |
          python -m pip list
      - name: Run pytest
        run: |
          pytest -v

  run_tests_macos:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["macos-latest"]
        python-version: ["3.8", "3.9", "3.10"]
    timeout-minutes: 20
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          conda env create -f environment.yml
          python -m pip install --upgrade pip
          pip install pytest
          pip install sh
      - name: List dependencies
        run: |
          python -m pip list
      - name: Run pytest
        run: |
          pytest -v

  run_tests_windows:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: ["windows-latest"]
        python-version: ["3.8", "3.9", "3.10"]
    timeout-minutes: 20
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          conda env create -f environment.yml
          python -m pip install --upgrade pip
          pip install pytest
      - name: List dependencies
        run: |
          python -m pip list
      - name: Run pytest
        run: |
          pytest -v

  # upload code coverage report
  code-coverage:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Set up Python 3.10
        uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          conda env create -f environment.yml
          python -m pip install --upgrade pip
          pip install pytest
          pip install "pytest-cov[toml]"
          pip install sh
      - name: Run tests and collect coverage
        run: pytest --cov flowdock # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3

.gitignore (new file, 11 lines)
work/
.nextflow/
.nextflow.log*
*.log.*
results/
__pycache__/
*.pyc
.vscode/
.idea/
*.tmp
*.swp

.pre-commit-config.yaml (new file, 150 lines)
default_language_version:
  python: python3

exclude: "^forks/"

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      # list of supported hooks: https://pre-commit.com/hooks.html
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-docstring-first
      - id: check-yaml
      - id: debug-statements
      - id: detect-private-key
      - id: check-executables-have-shebangs
      - id: check-toml
      - id: check-case-conflict
      - id: check-added-large-files
        args: ["--maxkb=20000"]

  # python code formatting
  - repo: https://github.com/psf/black
    rev: 25.1.0
    hooks:
      - id: black
        args: [--line-length, "99"]

  # python import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 6.0.1
    hooks:
      - id: isort
        args: ["--profile", "black", "--filter-files"]

  # python upgrading syntax to newer version
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.19.1
    hooks:
      - id: pyupgrade
        args: [--py38-plus]

  # python docstring formatting
  - repo: https://github.com/myint/docformatter
    rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 # Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed
    hooks:
      - id: docformatter
        args:
          [
            --in-place,
            --wrap-summaries=99,
            --wrap-descriptions=99,
            --style=sphinx,
            --black,
          ]

  # python docstring coverage checking
  - repo: https://github.com/econchick/interrogate
    rev: 1.7.0 # or master if you're bold
    hooks:
      - id: interrogate
        args:
          [
            --verbose,
            --fail-under=80,
            --ignore-init-module,
            --ignore-init-method,
            --ignore-module,
            --ignore-nested-functions,
            -vv,
          ]

  # python check (PEP8), programming errors and code complexity
  - repo: https://github.com/PyCQA/flake8
    rev: 7.1.2
    hooks:
      - id: flake8
        args:
          [
            "--extend-ignore",
            "E203,E402,E501,F401,F841,RST2,RST301",
            "--exclude",
            "logs/*,data/*",
          ]
        additional_dependencies: [flake8-rst-docstrings==0.3.0]

  # python security linter
  - repo: https://github.com/PyCQA/bandit
    rev: "1.8.3"
    hooks:
      - id: bandit
        args: ["-s", "B101"]

  # yaml formatting
  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v4.0.0-alpha.8
    hooks:
      - id: prettier
        types: [yaml]
        exclude: "environment.yaml"

  # shell scripts linter
  - repo: https://github.com/shellcheck-py/shellcheck-py
    rev: v0.10.0.1
    hooks:
      - id: shellcheck

  # md formatting
  - repo: https://github.com/executablebooks/mdformat
    rev: 0.7.22
    hooks:
      - id: mdformat
        args: ["--number"]
        additional_dependencies:
          - mdformat-gfm
          - mdformat-tables
          - mdformat_frontmatter
          # - mdformat-toc
          # - mdformat-black

  # word spelling linter
  - repo: https://github.com/codespell-project/codespell
    rev: v2.4.1
    hooks:
      - id: codespell
        args:
          - --skip=logs/**,data/**,*.ipynb,flowdock/data/components/constants.py,flowdock/data/components/process_mols.py,flowdock/data/components/residue_constants.py,flowdock/data/components/uff_parameters.csv,flowdock/data/components/chemical/*,flowdock/utils/data_utils.py
          # - --ignore-words-list=abc,def

  # jupyter notebook cell output clearing
  - repo: https://github.com/kynan/nbstripout
    rev: 0.8.1
    hooks:
      - id: nbstripout

  # jupyter notebook linting
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 1.9.1
    hooks:
      - id: nbqa-black
        args: ["--line-length=99"]
      - id: nbqa-isort
        args: ["--profile=black"]
      - id: nbqa-flake8
        args:
          [
            "--extend-ignore=E203,E402,E501,F401,F841",
            "--exclude=logs/*,data/*",
          ]
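Among the hooks above, `interrogate --fail-under=80` fails the commit when docstring coverage drops below 80%. What it measures can be approximated with the standard-library `ast` module (illustrative sketch; the real tool also honors the `--ignore-*` flags configured above):

```python
import ast

def docstring_coverage(source: str) -> float:
    """Return the percentage of module/class/function definitions
    that carry a docstring, roughly what interrogate reports."""
    tree = ast.parse(source)
    nodes = [tree] + [
        n for n in ast.walk(tree)
        if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
    ]
    documented = sum(1 for n in nodes if ast.get_docstring(n) is not None)
    return 100.0 * documented / len(nodes)

code = (
    '"""Module docstring."""\n'
    "def documented():\n"
    '    """Doc."""\n'
    "def bare():\n"
    "    pass\n"
)
# module + documented() have docstrings, bare() does not: 2 of 3 definitions
assert round(docstring_coverage(code), 1) == 66.7
```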

.project-root (new file, 2 lines)
# this file is required for inferring the project root directory
# do not delete

Dockerfile (new file, 49 lines)
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-runtime

LABEL authors="BioinfoMachineLearning"

# Install system requirements
RUN apt-get update && \
    apt-get install -y --reinstall ca-certificates && \
    apt-get install -y --no-install-recommends \
    git \
    wget \
    libxml2 \
    libgl-dev \
    libgl1 \
    gcc \
    g++ \
    procps && \
    rm -rf /var/lib/apt/lists/*

# Set working directory
RUN mkdir -p /software/flowdock
WORKDIR /software/flowdock

# Clone FlowDock repository
RUN git clone https://github.com/BioinfoMachineLearning/FlowDock /software/flowdock

# Create conda environment
RUN conda env create -f /software/flowdock/environments/flowdock_environment.yaml

# Install local package and ProDy
RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \
    conda activate FlowDock && \
    pip install --no-cache-dir -e /software/flowdock && \
    pip install --no-cache-dir --no-dependencies prody==2.4.1"

# Create checkpoints directory
RUN mkdir -p /software/flowdock/checkpoints

# Download pretrained weights
RUN wget -q https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz && \
    tar -xzf flowdock_checkpoints.tar.gz && \
    rm flowdock_checkpoints.tar.gz

# Activate conda environment by default
RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate FlowDock" >> ~/.bashrc

# Default shell
SHELL ["/bin/bash", "-l", "-c"]
CMD ["/bin/bash"]

LICENSE (new file, 21 lines)
MIT License

Copyright (c) 2024 BioinfoMachineLearning

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Makefile (new file, 30 lines)
help: ## Show help
	@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

clean: ## Clean autogenerated files
	rm -rf dist
	find . -type f -name "*.DS_Store" -ls -delete
	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
	find . | grep -E ".pytest_cache" | xargs rm -rf
	find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
	rm -f .coverage

clean-logs: ## Clean logs
	rm -rf logs/**

format: ## Run pre-commit hooks
	pre-commit run -a

sync: ## Merge changes from the main branch into your current branch
	git pull
	git pull origin main

test: ## Run fast (not slow) tests
	pytest -k "not slow"

test-full: ## Run all tests
	pytest

train: ## Train the model
	python flowdock/train.py
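The `help` target's grep/awk pipeline extracts every `target: ## description` pair to make the Makefile self-documenting. The same parsing can be sketched in Python (a hypothetical helper, handy for checking that the pattern matches what you expect):

```python
import re

def make_help(makefile_text: str) -> list:
    """Collect (target, description) pairs from `target: ## description`
    lines, mirroring the grep/awk pipeline in the `help` target."""
    pairs = []
    for line in makefile_text.splitlines():
        m = re.match(r"^([.a-zA-Z_-]+):.*?## (.*)$", line)
        if m:
            pairs.append((m.group(1), m.group(2)))
    return pairs

sample = "clean: ## Clean autogenerated files\n\trm -rf dist\ntrain: ## Train the model\n"
# Recipe lines (tab-indented) never match, only target lines do:
assert make_help(sample) == [("clean", "Clean autogenerated files"), ("train", "Train the model")]
```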

README.md (new file, 471 lines)
<div align="center">

# FlowDock

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
<!-- <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br> -->
[![Paper](http://img.shields.io/badge/paper-arxiv.2412.10966-B31B1B.svg)](https://arxiv.org/abs/2412.10966)
[![Conference](http://img.shields.io/badge/ISMB-2025-4b44ce.svg)](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)
[![Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15066450.svg)](https://doi.org/10.5281/zenodo.15066450)
<img src="./img/FlowDock.png" width="600">
</div>
## Description
This is the official codebase of the paper
**FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction**
\[[arXiv](https://arxiv.org/abs/2412.10966)\] \[[ISMB](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)\] \[[Neurosnap](https://neurosnap.ai/service/FlowDock)\] \[[Tamarind Bio](https://app.tamarind.bio/tools/flowdock)\]
<div align="center">
![Animation of a flow model-predicted 3D protein-ligand complex structure visualized successively](img/6I67.gif)
![Animation of a flow model-predicted 3D protein-multi-ligand complex structure visualized successively](img/T1152.gif)
</div>
## Contents
- [FlowDock](#flowdock)
- [Description](#description)
- [Contents](#contents)
- [Installation](#installation)
- [How to prepare data for `FlowDock`](#how-to-prepare-data-for-flowdock)
- [Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)](#generating-esm2-embeddings-for-each-protein-optional-cached-input-data-available-on-sharepoint)
- [Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)](#predicting-apo-protein-structures-using-esmfold-optional-cached-data-available-on-zenodo)
- [How to train `FlowDock`](#how-to-train-flowdock)
- [How to evaluate `FlowDock`](#how-to-evaluate-flowdock)
- [How to create comparative plots of benchmarking results](#how-to-create-comparative-plots-of-benchmarking-results)
- [How to predict new protein-ligand complex structures and their affinities using `FlowDock`](#how-to-predict-new-protein-ligand-complex-structures-and-their-affinities-using-flowdock)
- [For developers](#for-developers)
- [Docker](#docker)
- [Acknowledgements](#acknowledgements)
- [License](#license)
- [Citing this work](#citing-this-work)
## Installation
<details>
Install Mamba
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh # accept all terms and install to the default location
rm Miniforge3-$(uname)-$(uname -m).sh # (optionally) remove installer after using it
source ~/.bashrc # alternatively, one can restart their shell session to achieve the same result
```
Install dependencies
```bash
# clone project
git clone https://github.com/BioinfoMachineLearning/FlowDock
cd FlowDock
# create conda environment
mamba env create -f environments/flowdock_environment.yaml
conda activate FlowDock # NOTE: one still needs to use `conda` to (de)activate environments
pip3 install -e . # install local project as package
pip3 install prody==2.4.1 --no-dependencies # install ProDy without NumPy dependency
```
Download checkpoints
```bash
# pretrained NeuralPLexer weights
cd checkpoints/
wget https://zenodo.org/records/10373581/files/neuralplexermodels_downstream_datasets_predictions.zip
unzip neuralplexermodels_downstream_datasets_predictions.zip
rm neuralplexermodels_downstream_datasets_predictions.zip
cd ../
```
```bash
# pretrained FlowDock weights
wget https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz
tar -xzf flowdock_checkpoints.tar.gz
rm flowdock_checkpoints.tar.gz
```
Download preprocessed datasets
```bash
# cached input data for training/validation/testing
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ER1hctIBhDVFjM7YepOI6WcBXNBm4_e6EBjFEHAM1A3y5g?download=1"
tar -xzf flowdock_data_cache.tar.gz
rm flowdock_data_cache.tar.gz
# cached data for PDBBind, Binding MOAD, DockGen, and the PDB-based van der Mers (vdM) dataset
wget https://zenodo.org/records/15066450/files/flowdock_pdbbind_data.tar.gz
tar -xzf flowdock_pdbbind_data.tar.gz
rm flowdock_pdbbind_data.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_moad_data.tar.gz
tar -xzf flowdock_moad_data.tar.gz
rm flowdock_moad_data.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_dockgen_data.tar.gz
tar -xzf flowdock_dockgen_data.tar.gz
rm flowdock_dockgen_data.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_pdbsidechain_data.tar.gz
tar -xzf flowdock_pdbsidechain_data.tar.gz
rm flowdock_pdbsidechain_data.tar.gz
```
</details>
## How to prepare data for `FlowDock`
<details>
**NOTE:** The following steps (besides downloading PDBBind and Binding MOAD's PDB files) are only necessary if one wants to fully process each of the following datasets manually.
Otherwise, preprocessed versions of each dataset can be found on [Zenodo](https://zenodo.org/records/15066450).
Download data
```bash
# fetch preprocessed PDBBind and Binding MOAD (as well as the optional DockGen and vdM datasets)
cd data/
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1"
wget https://zenodo.org/records/10656052/files/BindingMOAD_2020_processed.tar
wget https://zenodo.org/records/10656052/files/DockGen.tar
wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz
mv EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1 PDBBind.tar.gz
tar -xzf PDBBind.tar.gz
tar -xf BindingMOAD_2020_processed.tar
tar -xf DockGen.tar
tar -xzf pdb_2021aug02.tar.gz
rm PDBBind.tar.gz BindingMOAD_2020_processed.tar DockGen.tar pdb_2021aug02.tar.gz
mkdir pdbbind/ moad/ pdbsidechain/
mv PDBBind_processed/ pdbbind/
mv BindingMOAD_2020_processed/ moad/
mv pdb_2021aug02/ pdbsidechain/
cd ../
```
Lastly, to finetune `FlowDock` using the `PLINDER` dataset, one must first prepare this data for training
```bash
# fetch PLINDER data (NOTE: requires ~1 hour to download and ~750 GB of storage)
export PLINDER_MOUNT="$(pwd)/data/PLINDER"
mkdir -p "$PLINDER_MOUNT" # create the directory if it doesn't exist
plinder_download -y
```
### Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)
To generate the ESM2 embeddings for the protein inputs,
first create all the corresponding FASTA files for each protein sequence
```bash
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbbind --data_dir data/pdbbind/PDBBind_processed/ --out_file data/pdbbind/pdbbind_sequences.fasta
python flowdock/data/components/esm_embedding_preparation.py --dataset moad --data_dir data/moad/BindingMOAD_2020_processed/pdb_protein/ --out_file data/moad/moad_sequences.fasta
python flowdock/data/components/esm_embedding_preparation.py --dataset dockgen --data_dir data/DockGen/processed_files/ --out_file data/DockGen/dockgen_sequences.fasta
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbsidechain --data_dir data/pdbsidechain/pdb_2021aug02/pdb/ --out_file data/pdbsidechain/pdbsidechain_sequences.fasta
```
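Each of these preparation commands writes a combined FASTA file of protein sequences for the downstream embedding step. The output format can be sketched with an illustrative `write_fasta` helper (a hypothetical name; the actual scripts also extract the sequences from the structure files):

```python
def write_fasta(records: dict, path: str) -> None:
    """Write {identifier: sequence} pairs in FASTA format,
    one record per protein chain, as the ESM embedding scripts expect."""
    with open(path, "w") as handle:
        for name, seq in records.items():
            handle.write(f">{name}\n{seq}\n")

# toy example: one chain of a PDBBind target
write_fasta({"6i67_chain_A": "YNKIVHLLVAEPEK"}, "example_sequences.fasta")
```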
Then, generate all ESM2 embeddings in batch using the ESM repository's helper script
```bash
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbbind/pdbbind_sequences.fasta data/pdbbind/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/moad/moad_sequences.fasta data/moad/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/DockGen/dockgen_sequences.fasta data/DockGen/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbsidechain/pdbsidechain_sequences.fasta data/pdbsidechain/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
```
### Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)
To generate the apo version of each protein structure,
first create ESMFold-ready versions of the combined FASTA files
prepared above by the script `esm_embedding_preparation.py`
for the PDBBind, Binding MOAD, DockGen, and PDBSidechain datasets, respectively
```bash
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbbind
python flowdock/data/components/esmfold_sequence_preparation.py dataset=moad
python flowdock/data/components/esmfold_sequence_preparation.py dataset=dockgen
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbsidechain
```
Then, predict each apo protein structure using ESMFold's batch
inference script
```bash
# Note: Having a CUDA-enabled device available when running this script is highly recommended
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbbind/pdbbind_esmfold_sequences.fasta -o data/pdbbind/pdbbind_esmfold_structures --cuda-device-index 0 --skip-existing
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/moad/moad_esmfold_sequences.fasta -o data/moad/moad_esmfold_structures --cuda-device-index 0 --skip-existing
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/DockGen/dockgen_esmfold_sequences.fasta -o data/DockGen/dockgen_esmfold_structures --cuda-device-index 0 --skip-existing
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbsidechain/pdbsidechain_esmfold_sequences.fasta -o data/pdbsidechain/pdbsidechain_esmfold_structures --cuda-device-index 0 --skip-existing
```
Align each apo protein structure to its corresponding
holo protein structure counterpart in PDBBind, Binding MOAD, and PDBSidechain,
taking ligand conformations into account during each alignment
```bash
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbbind num_workers=1
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=moad num_workers=1
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=dockgen num_workers=1
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbsidechain num_workers=1
```
Lastly, assess the apo-to-holo alignments in terms of statistics and structural metrics
to enable runtime-dynamic dataset filtering using such information
```bash
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbbind usalign_exec_path=$MY_USALIGN_EXEC_PATH
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=moad usalign_exec_path=$MY_USALIGN_EXEC_PATH
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=dockgen usalign_exec_path=$MY_USALIGN_EXEC_PATH
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbsidechain usalign_exec_path=$MY_USALIGN_EXEC_PATH
```
</details>
## How to train `FlowDock`
<details>
Train model with default configuration
```bash
# train on CPU
python flowdock/train.py trainer=cpu
# train on GPU
python flowdock/train.py trainer=gpu
```
Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
```bash
python flowdock/train.py experiment=experiment_name.yaml
```
For example, reproduce `FlowDock`'s default model training run
```bash
python flowdock/train.py experiment=flowdock_fm
```
**Note:** You can override any parameter from command line like this
```bash
python flowdock/train.py experiment=flowdock_fm trainer.max_epochs=20 data.batch_size=8
```
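Hydra overrides such as `trainer.max_epochs=20` address nested config keys with dots. A toy sketch of that mechanic on a plain dict (Hydra itself does far more: type checking, config groups, interpolation):

```python
def apply_override(cfg: dict, override: str) -> dict:
    """Apply a Hydra-style `a.b.c=value` override to a nested dict
    (simplified: only int/float coercion, no config groups)."""
    keys, _, raw = override.partition("=")
    value = raw
    for cast in (int, float):  # coerce numeric-looking values
        try:
            value = cast(raw)
            break
        except ValueError:
            continue
    node = cfg
    *parents, leaf = keys.split(".")
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return cfg

cfg = {"trainer": {"max_epochs": 10}, "data": {"batch_size": 4}}
apply_override(cfg, "trainer.max_epochs=20")
apply_override(cfg, "data.batch_size=8")
assert cfg == {"trainer": {"max_epochs": 20}, "data": {"batch_size": 8}}
```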
For example, override parameters to finetune `FlowDock`'s pretrained weights using a new dataset such as [PLINDER](https://www.plinder.sh/)
```bash
python flowdock/train.py experiment=flowdock_fm data=plinder ckpt_path=checkpoints/esmfold_prior_paper_weights.ckpt
```
</details>
## How to evaluate `FlowDock`
<details>
To reproduce `FlowDock`'s evaluation results for structure prediction, please refer to its documentation in version `0.6.0-FlowDock` of the [PoseBench](https://github.com/BioinfoMachineLearning/PoseBench/tree/0.6.0-FlowDock?tab=readme-ov-file#how-to-run-inference-with-flowdock) GitHub repository.
To reproduce `FlowDock`'s evaluation results for binding affinity prediction using the PDBBind dataset
```bash
python flowdock/eval.py data.test_datasets=[pdbbind] ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt trainer=gpu
... # re-run two more times to gather triplicate results
```
</details>
## How to create comparative plots of benchmarking results
<details>
Download baseline method predictions and results
```bash
# cached predictions and evaluation metrics for reproducing structure prediction paper results
wget https://zenodo.org/records/15066450/files/alphafold3_baseline_method_predictions.tar.gz
tar -xzf alphafold3_baseline_method_predictions.tar.gz
rm alphafold3_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/chai_baseline_method_predictions.tar.gz
tar -xzf chai_baseline_method_predictions.tar.gz
rm chai_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/diffdock_baseline_method_predictions.tar.gz
tar -xzf diffdock_baseline_method_predictions.tar.gz
rm diffdock_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/dynamicbind_baseline_method_predictions.tar.gz
tar -xzf dynamicbind_baseline_method_predictions.tar.gz
rm dynamicbind_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_baseline_method_predictions.tar.gz
tar -xzf flowdock_baseline_method_predictions.tar.gz
rm flowdock_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_aft_baseline_method_predictions.tar.gz
tar -xzf flowdock_aft_baseline_method_predictions.tar.gz
rm flowdock_aft_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_pft_baseline_method_predictions.tar.gz
tar -xzf flowdock_pft_baseline_method_predictions.tar.gz
rm flowdock_pft_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_esmfold_baseline_method_predictions.tar.gz
tar -xzf flowdock_esmfold_baseline_method_predictions.tar.gz
rm flowdock_esmfold_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_chai_baseline_method_predictions.tar.gz
tar -xzf flowdock_chai_baseline_method_predictions.tar.gz
rm flowdock_chai_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/flowdock_hp_baseline_method_predictions.tar.gz
tar -xzf flowdock_hp_baseline_method_predictions.tar.gz
rm flowdock_hp_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/neuralplexer_baseline_method_predictions.tar.gz
tar -xzf neuralplexer_baseline_method_predictions.tar.gz
rm neuralplexer_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/vina_p2rank_baseline_method_predictions.tar.gz
tar -xzf vina_p2rank_baseline_method_predictions.tar.gz
rm vina_p2rank_baseline_method_predictions.tar.gz
wget https://zenodo.org/records/15066450/files/rfaa_baseline_method_predictions.tar.gz
tar -xzf rfaa_baseline_method_predictions.tar.gz
rm rfaa_baseline_method_predictions.tar.gz
```
Reproduce paper result figures
```bash
jupyter notebook notebooks/casp16_binding_affinity_prediction_results_plotting.ipynb
jupyter notebook notebooks/casp16_flowdock_vs_multicom_ligand_structure_prediction_results_plotting.ipynb
jupyter notebook notebooks/dockgen_structure_prediction_results_plotting.ipynb
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_chemical_similarity_analysis.ipynb
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_results_plotting.ipynb
```
</details>
## How to predict new protein-ligand complex structures and their affinities using `FlowDock`
<details>
For example, generate new protein-ligand complexes for a protein sequence and ligand SMILES string pair, such as those of the PDBBind 2020 test target `6i67`
```bash
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' input_ligand='"c1cc2c(cc1O)CCCC2"' input_template=data/pdbbind/pdbbind_holo_aligned_esmfold_structures/6i67_holo_aligned_esmfold_protein.pdb sample_id='6i67' out_path='./6i67_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
```
Or, for example, generate new protein-ligand complexes for pairs of protein sequences and (multi-)ligand SMILES strings (delimited by `|`), such as those of the CASP15 target `T1152`
```bash
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIPN' input_ligand='"CC(=O)NC1C(O)OC(CO)C(OC2OC(CO)C(OC3OC(CO)C(O)C(O)C3NC(C)=O)C(O)C2NC(C)=O)C1O"' input_template=data/test_cases/predicted_structures/T1152.pdb sample_id='T1152' out_path='./T1152_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
```
If you do not already have a template protein structure available for your target of interest, set `input_template=null` to have the sampling script first predict the ESMFold structure of your provided `input_receptor` sequence before running the sampling pipeline. For more information regarding the input arguments available for sampling, please refer to the config at `configs/sample.yaml`.
**NOTE:** To optimize prediction runtimes, a `csv_path` can be specified instead of the `input_receptor`, `input_ligand`, and `input_template` CLI arguments to perform *batched* prediction for a collection of protein-ligand sequence pairs, each represented as a CSV row containing column values for `id`, `input_receptor`, `input_ligand`, and `input_template`. Additionally, disabling `visualize_sample_trajectories` may reduce storage requirements when predicting a large batch of inputs.
For instance, one can perform batched prediction as follows:
```bash
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling csv_path='./data/test_cases/prediction_inputs/flowdock_batched_inputs.csv' out_path='./T1152_batch_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=false auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
```
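The batched-input CSV referenced above uses the four columns described in the note. A hypothetical two-row example follows; the IDs, sequences, SMILES strings, and template paths are illustrative placeholders, not real inputs:

```bash
# Write a minimal batched-input CSV with the expected header columns.
# All row values below are illustrative placeholders.
cat > flowdock_batched_inputs_example.csv <<'EOF'
id,input_receptor,input_ligand,input_template
example_target_1,MYTVKPGDTMWKIAVKYQIG,CC(=O)NC1C(O)OC(CO)C1O,data/templates/example_target_1.pdb
example_target_2,YNKIVHLLVAEPEKIYAMPD,c1cc2c(cc1O)CCCC2,data/templates/example_target_2.pdb
EOF
```

A file like this would then be passed to the sampling script via `csv_path='./flowdock_batched_inputs_example.csv'`.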
</details>
## For developers
<details>
Set up `pre-commit` (one time only) for automatic code linting and formatting upon each `git commit`
```bash
pre-commit install
```
Manually reformat all files in the project, as desired
```bash
pre-commit run -a
```
Update dependencies in a `*_environment.yaml` file
```bash
mamba env export > env.yaml # e.g., run this after installing new dependencies locally
diff environments/flowdock_environment.yaml env.yaml # note the differences and copy accepted changes back into e.g., `environments/flowdock_environment.yaml`
rm env.yaml # clean up temporary environment file
```
</details>
## Docker
<details>
Given that this tool has a number of dependencies, it may be easier to run it in a Docker container.
Pull from [Docker Hub](https://hub.docker.com/repository/docker/cford38/flowdock): `docker pull cford38/flowdock:latest`
Alternatively, build the Docker image locally:
```bash
docker build --platform linux/amd64 -t flowdock .
```
Then, run the Docker container (and mount your local `checkpoints/` directory)
```bash
docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it flowdock /bin/bash
# docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it cford38/flowdock:latest /bin/bash
```
</details>
## Acknowledgements
`FlowDock` builds upon the source code and data from the following projects:
- [DiffDock](https://github.com/gcorso/DiffDock)
- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
- [NeuralPLexer](https://github.com/zrqiao/NeuralPLexer)
We thank all their contributors and maintainers!
## License
This project is covered under the **MIT License**.
## Citing this work
If you use the code or data associated with this package or otherwise find this work useful, please cite:
```bibtex
@inproceedings{morehead2025flowdock,
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
author={Alex Morehead and Jianlin Cheng},
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
year=2025,
}
```
citation.bib
@@ -0,0 +1,6 @@
@inproceedings{morehead2025flowdock,
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
author={Alex Morehead and Jianlin Cheng},
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
year=2025,
}
configs/__init__.py
@@ -0,0 +1 @@
# this file is needed here to include configs when building project as a package
@@ -0,0 +1,21 @@
defaults:
- ema
- last_model_checkpoint
- learning_rate_monitor
- model_checkpoint
- model_summary
- rich_progress_bar
- _self_
last_model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "last"
monitor: null
verbose: True
auto_insert_metric_name: False
every_n_epochs: 1
save_on_train_epoch_end: True
enable_version_counter: False
model_summary:
max_depth: -1
@@ -0,0 +1,15 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
early_stopping:
_target_: lightning.pytorch.callbacks.EarlyStopping
monitor: ??? # quantity to be monitored, must be specified !!!
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
patience: 3 # number of checks with no improvement after which training will be stopped
verbose: False # verbosity mode
mode: "min" # "max" means higher metric value is better, can be also "min"
strict: True # whether to crash the training if monitor is not found in the validation metrics
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
# log_rank_zero_only: False # this keyword argument isn't available in stable version
@@ -0,0 +1,10 @@
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
# Maintains an exponential moving average (EMA) of model weights.
# Look at the above link for more detailed information regarding the original implementation.
ema:
_target_: flowdock.models.components.callbacks.ema.EMA
decay: 0.999
validate_original_weights: false
every_n_steps: 4
cpu_offload: false
@@ -0,0 +1,21 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
last_model_checkpoint:
# NOTE: this is a direct copy of `model_checkpoint`,
# which is necessary to make to work around the
# key-duplication limitations of YAML config files
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
dirpath: null # directory to save the model file
filename: null # checkpoint filename
monitor: null # name of the logged metric which determines when model is improving
verbose: False # verbosity mode
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 1 # save k best models (determined by above metric)
mode: "min" # "max" means higher metric value is better, can be also "min"
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
save_weights_only: False # if True, then only the model's weights will be saved
every_n_train_steps: null # number of training steps between checkpoints
train_time_interval: null # checkpoints are monitored at the specified time interval
every_n_epochs: null # number of epochs between checkpoints
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
enable_version_counter: True # enables versioning for checkpoint names
@@ -0,0 +1,7 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html
learning_rate_monitor:
_target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: null # set to `epoch` or `step` to log learning rate of all optimizers at the same interval, or set to `null` to log at individual interval according to the interval key of each scheduler
log_momentum: false # whether to also log the momentum values of the optimizer, if the optimizer has the `momentum` or `betas` attribute
log_weight_decay: false # whether to also log the weight decay values of the optimizer, if the optimizer has the `weight_decay` attribute
@@ -0,0 +1,18 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
model_checkpoint:
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
dirpath: null # directory to save the model file
filename: "best" # checkpoint filename
monitor: val_sampling/ligand_hit_score_2A_epoch # name of the logged metric which determines when model is improving
verbose: True # verbosity mode
save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt
save_top_k: 1 # save k best models (determined by above metric)
mode: "max" # "max" means higher metric value is better, can be also "min"
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
save_weights_only: False # if True, then only the model's weights will be saved
every_n_train_steps: null # number of training steps between checkpoints
train_time_interval: null # checkpoints are monitored at the specified time interval
every_n_epochs: null # number of epochs between checkpoints
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
enable_version_counter: False # enables versioning for checkpoint names
@@ -0,0 +1,5 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
model_summary:
_target_: lightning.pytorch.callbacks.RichModelSummary
max_depth: 1 # the maximum depth of layer nesting that the summary will include
@@ -0,0 +1,4 @@
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
rich_progress_bar:
_target_: lightning.pytorch.callbacks.RichProgressBar
@@ -0,0 +1,35 @@
# @package _global_
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
# disable callbacks and loggers during debugging
callbacks: null
logger: null
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use this to also set hydra loggers to 'DEBUG'
# verbose: True
trainer:
max_epochs: 1
accelerator: cpu # debuggers don't like gpus
devices: 1 # debuggers don't like multiprocessing
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
data:
num_workers: 0 # debuggers don't like multiprocessing
pin_memory: False # disable gpu memory pin
configs/debug/fdr.yaml
@@ -0,0 +1,9 @@
# @package _global_
# runs 1 train, 1 validation and 1 test step
defaults:
- default
trainer:
fast_dev_run: true
configs/debug/limit.yaml
@@ -0,0 +1,12 @@
# @package _global_
# uses only 1% of the training data and 5% of validation/test data
defaults:
- default
trainer:
max_epochs: 3
limit_train_batches: 0.01
limit_val_batches: 0.05
limit_test_batches: 0.05
@@ -0,0 +1,13 @@
# @package _global_
# overfits to 3 batches
defaults:
- default
trainer:
max_epochs: 20
overfit_batches: 3
# model ckpt and early stopping need to be disabled during overfitting
callbacks: null
@@ -0,0 +1,12 @@
# @package _global_
# runs with execution time profiling
defaults:
- default
trainer:
max_epochs: 1
profiler: "simple"
# profiler: "advanced"
# profiler: "pytorch"
@@ -0,0 +1,2 @@
defaults:
- _self_
@@ -0,0 +1 @@
_target_: lightning.fabric.plugins.environments.LightningEnvironment
@@ -0,0 +1,3 @@
_target_: lightning.fabric.plugins.environments.SLURMEnvironment
auto_requeue: true
requeue_signal: null
configs/eval.yaml
@@ -0,0 +1,49 @@
# @package _global_
defaults:
- data: combined # choose datamodule with `test_dataloader()` for evaluation
- model: flowdock_fm
- logger: null
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
- _self_
task_name: "eval"
tags: ["eval", "combined", "flowdock_fm"]
# passing checkpoint path is necessary for evaluation
ckpt_path: ???
# seed for random number generators in pytorch, numpy and python.random
seed: null
# model arguments
model:
cfg:
mol_encoder:
from_pretrained: false
protein_encoder:
from_pretrained: false
relational_reasoning:
from_pretrained: false
contact_predictor:
from_pretrained: false
score_head:
from_pretrained: false
confidence:
from_pretrained: false
affinity:
from_pretrained: false
task:
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: false
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false
@@ -0,0 +1,35 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=flowdock_fm
defaults:
- override /data: combined
- override /model: flowdock_fm
- override /callbacks: default
- override /trainer: default
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["flowdock_fm", "combined_dataset"]
seed: 496
trainer:
max_epochs: 300
check_val_every_n_epoch: 5 # NOTE: we increase this since validation steps involve full model sampling and evaluation
reload_dataloaders_every_n_epochs: 1
model:
optimizer:
lr: 2e-4
compile: false
data:
batch_size: 16
logger:
wandb:
tags: ${tags}
group: "FlowDock-FM"
@@ -0,0 +1,8 @@
# disable python warnings if they annoy you
ignore_warnings: False
# ask user for tags if none are provided in the config
enforce_tags: True
# pretty print config tree at the start of the run using Rich library
print_config: True
@@ -0,0 +1,50 @@
# @package _global_
# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example
defaults:
- override /hydra/sweeper: optuna
# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/loss"
# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
# storage URL to persist optimization results
# for example, you can use SQLite if you set 'sqlite:///example.db'
storage: null
# name of the study to persist optimization results
study_name: null
# number of parallel workers
n_jobs: 1
# 'minimize' or 'maximize' the objective
direction: minimize
# total number of runs that will be executed
n_trials: 20
# choose Optuna hyperparameter sampler
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
sampler:
_target_: optuna.samplers.TPESampler
seed: 1234
n_startup_trials: 10 # number of random sampling runs before optimization starts
# define hyperparameter search space
params:
model.optimizer.lr: interval(0.0001, 0.1)
data.batch_size: choice(2, 4, 8, 16)
model.net.hidden_dim: choice(64, 128, 256)
@@ -0,0 +1,19 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging
defaults:
- override hydra_logging: colorlog
- override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
subdir: ${hydra.job.num}
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/${task_name}.log
configs/local/.gitkeep
configs/logger/aim.yaml
@@ -0,0 +1,28 @@
# https://aimstack.io/
# example usage in lightning module:
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
# `aim up`
aim:
_target_: aim.pytorch_lightning.AimLogger
repo: ${paths.root_dir} # .aim folder will be created here
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
# aim allows to group runs under experiment name
experiment: null # any string, set to "default" if not specified
train_metric_prefix: "train/"
val_metric_prefix: "val/"
test_metric_prefix: "test/"
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
system_tracking_interval: 10 # set to null to disable system metrics tracking
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
log_system_params: true
# enable/disable tracking console logs (default value is true)
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
configs/logger/comet.yaml
@@ -0,0 +1,12 @@
# https://www.comet.ml
comet:
_target_: lightning.pytorch.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
save_dir: "${paths.output_dir}"
project_name: "FlowDock_FM"
rest_api_key: null
# experiment_name: ""
experiment_key: null # set to resume experiment
offline: False
prefix: ""
configs/logger/csv.yaml
@@ -0,0 +1,7 @@
# csv logger built in lightning
csv:
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "${paths.output_dir}"
name: "csv/"
prefix: ""
@@ -0,0 +1,9 @@
# train with many loggers at once
defaults:
# - comet
- csv
# - mlflow
# - neptune
- tensorboard
- wandb
@@ -0,0 +1,12 @@
# https://mlflow.org
mlflow:
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
# experiment_name: ""
# run_name: ""
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
tags: null
# save_dir: "./mlruns"
prefix: ""
artifact_location: null
# run_id: ""
@@ -0,0 +1,9 @@
# https://neptune.ai
neptune:
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
project: username/FlowDock_FM
# name: ""
log_model_checkpoints: True
prefix: ""
@@ -0,0 +1,10 @@
# https://www.tensorflow.org/tensorboard/
tensorboard:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.output_dir}/tensorboard/"
name: null
log_graph: False
default_hp_metric: True
prefix: ""
# version: ""
configs/logger/wandb.yaml
@@ -0,0 +1,16 @@
# https://wandb.ai
wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "FlowDock_FM"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
entity: "bml-lab" # set to name of your wandb team
group: ""
tags: []
job_type: ""
@@ -0,0 +1,148 @@
_target_: flowdock.models.flowdock_fm_module.FlowDockFMLitModule
net:
_target_: flowdock.models.components.flowdock.FlowDock
_partial_: true
optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 2e-4
weight_decay: 0.0
scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
_partial_: true
T_0: ${int_divide:${trainer.max_epochs},15}
T_mult: 2
eta_min: 1e-8
verbose: true
# compile model for faster training with pytorch 2.0
compile: false
# model arguments
cfg:
mol_encoder:
node_channels: 512
pair_channels: 64
n_atom_encodings: 23
n_bond_encodings: 4
n_atom_pos_encodings: 6
n_stereo_encodings: 14
n_attention_heads: 8
attention_head_dim: 8
hidden_dim: 2048
max_path_integral_length: 6
n_transformer_stacks: 8
n_heads: 8
n_patches: ${data.n_lig_patches}
checkpoint_file: ${oc.env:PROJECT_ROOT}/checkpoints/neuralplexermodels_downstream_datasets_predictions/models/complex_structure_prediction.ckpt
megamolbart: null
from_pretrained: true
protein_encoder:
use_esm_embedding: true
esm_version: esm2_t33_650M_UR50D
esm_repr_layer: 33
residue_dim: 512
plm_embed_dim: 1280
n_aa_types: 21
atom_padding_dim: 37
n_atom_types: 4 # [C, N, O, S]
n_patches: ${data.n_protein_patches}
n_attention_heads: 8
scalar_dim: 16
point_dim: 4
pair_dim: 64
n_heads: 8
head_dim: 8
max_residue_degree: 32
n_encoder_stacks: 2
from_pretrained: true
relational_reasoning:
from_pretrained: true
contact_predictor:
n_stacks: 4
dropout: 0.01
from_pretrained: true
score_head:
fiber_dim: 64
hidden_dim: 512
n_stacks: 4
max_atom_degree: 8
from_pretrained: true
confidence:
enabled: true # whether the confidence prediction head is to be used e.g., during inference
fiber_dim: ${..score_head.fiber_dim}
hidden_dim: ${..score_head.hidden_dim}
n_stacks: ${..score_head.n_stacks}
from_pretrained: true
affinity:
enabled: true # whether the affinity prediction head is to be used e.g., during inference
fiber_dim: ${..score_head.fiber_dim}
hidden_dim: ${..score_head.hidden_dim}
n_stacks: ${..score_head.n_stacks}
ligand_pooling: sum # NOTE: must be a value in (`sum`, `mean`)
dropout: 0.01
from_pretrained: false
latent_model: default
prior_type: esmfold # NOTE: must be a value in (`gaussian`, `harmonic`, `esmfold`)
task:
pretrained: null
ligands: true
epoch_frac: ${data.epoch_frac}
label_name: null
sequence_crop_size: 1600
edge_crop_size: ${data.edge_crop_size} # NOTE: for dynamic batching via `max_n_edges`
max_masking_rate: 0.0
n_modes: 8
dropout: 0.01
# pretraining: true
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: true
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false
use_template: true
use_plddt: false
block_contact_decoding_scheme: "beam"
frozen_ligand_backbone: false
frozen_protein_backbone: false
single_protein_batch: true
contact_loss_weight: 0.2
global_score_loss_weight: 0.2
ligand_score_loss_weight: 0.1
clash_loss_weight: 10.0
local_distgeom_loss_weight: 10.0
drmsd_loss_weight: 2.0
distogram_loss_weight: 0.05
plddt_loss_weight: 1.0
affinity_loss_weight: 0.1
aux_batch_freq: 10 # NOTE: e.g., `10` means that auxiliary estimation losses will be calculated every 10th batch
global_max_sigma: 5.0
internal_max_sigma: 2.0
detect_covalent: true
# runtime configs
float32_matmul_precision: highest
# sampling configs
constrained_inpainting: false
visualize_generated_samples: true
# testing configs
loss_mode: auxiliary_estimation # NOTE: must be one of (`structure_prediction`, `auxiliary_estimation`, `auxiliary_estimation_without_structure_prediction`)
num_steps: 20
sampler: VDODE # NOTE: must be one of (`ODE`, `VDODE`)
sampler_eta: 1.0 # NOTE: this corresponds to the variance diminishing factor for the `VDODE` sampler, which offers a trade-off between exploration (1.0) and exploitation (> 1.0)
start_time: 1.0
eval_structure_prediction: false # whether to evaluate structure prediction performance (`true`) or instead only binding affinity performance (`false`) when running a test epoch
# overfitting configs
overfitting_example_name: ${data.overfitting_example_name}
@@ -0,0 +1,21 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}
# path to data directory
data_dir: ${paths.root_dir}/data/
# path to logging directory
log_dir: ${paths.root_dir}/logs/
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# path to working directory
work_dir: ${hydra:runtime.cwd}
# path to the directory containing the model checkpoints
ckpt_dir: ${paths.root_dir}/checkpoints/
configs/sample.yaml
@@ -0,0 +1,78 @@
# @package _global_
defaults:
- data: combined # NOTE: this will not be referenced during sampling
- model: flowdock_fm
- logger: null
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
- _self_
task_name: "sample"
tags: ["sample", "combined", "flowdock_fm"]
# passing checkpoint path is necessary for sampling
ckpt_path: ???
# seed for random number generators in pytorch, numpy and python.random
seed: null
# sampling arguments
sampling_task: batched_structure_sampling # NOTE: must be one of (`batched_structure_sampling`)
sample_id: null # optional identifier for the sampling run
input_receptor: null # NOTE: must be either a protein sequence string (with chains separated by `|`) or a path to a PDB file (from which protein chain sequences will be parsed)
input_ligand: null # NOTE: must be either a ligand SMILES string (with chains/fragments separated by `|`) or a path to a ligand SDF file (from which ligand SMILES will be parsed)
input_template: null # path to a protein PDB file to use as a starting protein template for sampling (with an ESMFold prior model)
out_path: ??? # path to which to save the output PDB and SDF files
n_samples: 5 # number of structures to sample
chunk_size: 5 # number of structures to concurrently sample within each batch segment - NOTE: `n_samples` should be evenly divisible by `chunk_size` to produce the expected number of outputs
num_steps: 40 # number of sampling steps to perform
latent_model: null # if provided, the type of latent model to use
sampler: VDODE # sampling algorithm to use - NOTE: must be one of (`ODE`, `VDODE`)
sampler_eta: 1.0 # the variance diminishing factor for the `VDODE` sampler - NOTE: offers a trade-off between exploration (1.0) and exploitation (> 1.0)
start_time: "1.0" # time at which to start sampling
max_chain_encoding_k: -1 # maximum number of chains to encode in the chain encoding
exact_prior: false # whether to use the "ground-truth" binding site for sampling, if available
prior_type: esmfold # the type of prior to use for sampling - NOTE: must be one of (`gaussian`, `harmonic`, `esmfold`)
discard_ligand: false # whether to discard a given input ligand during sampling
discard_sdf_coords: true # whether to discard the input ligand's 3D structure during sampling, if available
detect_covalent: false # whether to detect covalent bonds between the input receptor and ligand
use_template: true # whether to use the input protein template for sampling if one is provided
separate_pdb: true # whether to save separate PDB files for each sampled structure instead of simply a single PDB file
rank_outputs_by_confidence: true # whether to rank the sampled structures by estimated confidence
plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`)
visualize_sample_trajectories: false # whether to visualize the generated samples' trajectories
auxiliary_estimation_only: false # whether to only estimate auxiliary outputs (e.g., confidence, affinity) for the input (generated) samples (potentially derived from external sources)
csv_path: null # if provided, the CSV file (with columns `id`, `input_receptor`, `input_ligand`, and `input_template`) from which to parse input receptors and ligands for sampling, overriding the `input_receptor` and `input_ligand` arguments in the process and ignoring the `input_template` for now
esmfold_chunk_size: null # chunks axial attention computation to reduce memory usage from O(L^2) to O(L); equivalent to running a for loop over chunks of each dimension; lower values will result in lower memory usage at the cost of speed; recommended values: 128, 64, 32
# model arguments
model:
cfg:
mol_encoder:
from_pretrained: false
protein_encoder:
from_pretrained: false
relational_reasoning:
from_pretrained: false
contact_predictor:
from_pretrained: false
score_head:
from_pretrained: false
confidence:
from_pretrained: false
affinity:
from_pretrained: false
task:
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: false
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false
@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: false
gradient_as_bucket_view: false
find_unused_parameters: true
@@ -0,0 +1,5 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: false
gradient_as_bucket_view: false
find_unused_parameters: true
start_method: spawn
@@ -0,0 +1,5 @@
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
stage: 2
offload_optimizer: false
allgather_bucket_size: 200_000_000
reduce_bucket_size: 200_000_000
@@ -0,0 +1,2 @@
defaults:
- _self_
@@ -0,0 +1,12 @@
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
cpu_offload: null
activation_checkpointing: null
mixed_precision:
_target_: torch.distributed.fsdp.MixedPrecision
param_dtype: null
reduce_dtype: null
buffer_dtype: null
keep_low_precision_grads: false
cast_forward_inputs: false
cast_root_forward_inputs: true
@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: true
gradient_as_bucket_view: true
find_unused_parameters: false
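The `_target_` keys in these strategy configs are resolved by Hydra at instantiation time. As a minimal stdlib-only sketch of that pattern (the real `hydra.utils.instantiate` also handles nested configs, `_partial_`, and interpolation, which this deliberately omits):

```python
import importlib

def instantiate(cfg: dict):
    """Minimal sketch of Hydra-style `_target_` instantiation."""
    cfg = dict(cfg)  # avoid mutating the caller's config
    module_name, cls_name = cfg.pop("_target_").rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**cfg)  # remaining keys become constructor kwargs
```

For example, `instantiate({"_target_": "collections.Counter", "a": 2})` imports `collections`, looks up `Counter`, and calls it with `a=2` as a keyword argument.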

configs/train.yaml Normal file

@@ -0,0 +1,51 @@
# @package _global_
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: combined
- model: flowdock_fm
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# config for hyperparameter optimization
- hparams_search: null
# optional local config for machine/user specific settings
# it's optional since it doesn't need to exist and is excluded from version control
- optional local: default
# debugging config (enable through command line, e.g. `python train.py debug=default`)
- debug: null
# task name, determines output directory path
task_name: "train"
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["train", "combined", "flowdock_fm"]
# set False to skip model training
train: True
# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: False
# provide a checkpoint path to resume training
ckpt_path: null
# seed for random number generators in pytorch, numpy and python.random
seed: null

configs/trainer/cpu.yaml Normal file

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: cpu
devices: 1

configs/trainer/ddp.yaml Normal file

@@ -0,0 +1,9 @@
defaults:
- default
strategy: ddp
accelerator: gpu
devices: 4
num_nodes: 1
sync_batchnorm: True

@@ -0,0 +1,7 @@
defaults:
- default
# simulate DDP on CPU, useful for debugging
accelerator: cpu
devices: 2
strategy: ddp_spawn

@@ -0,0 +1,9 @@
defaults:
- default
strategy: ddp_spawn
accelerator: gpu
devices: 4
num_nodes: 1
sync_batchnorm: True

@@ -0,0 +1,29 @@
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
min_epochs: 1 # prevents early stopping
max_epochs: 10
accelerator: cpu
devices: 1
# mixed precision for extra speed-up
# precision: 16
# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
# how often (in epochs) to reload the dataloaders
reload_dataloaders_every_n_epochs: 1
# if `gradient_clip_val` is not `null`, gradients will be norm-clipped during training
gradient_clip_algorithm: norm
gradient_clip_val: 1.0
# if `num_sanity_val_steps` > 0, that many validation steps are run before training starts in the first call to `trainer.fit`
num_sanity_val_steps: 0
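With `gradient_clip_algorithm: norm` and `gradient_clip_val: 1.0`, gradients whose global L2 norm exceeds 1.0 are rescaled before the optimizer step. The operation can be sketched on plain floats as:

```python
import math

def clip_grad_norm(grads, max_norm=1.0):
    """Rescale gradients so their global L2 norm is at most `max_norm`."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads]
```

Gradients already within the threshold are left untouched, since the scale factor clamps at 1.0.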

configs/trainer/gpu.yaml Normal file

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: gpu
devices: 1

configs/trainer/mps.yaml Normal file

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: mps
devices: 1

@@ -0,0 +1,540 @@
name: FlowDock
channels:
- pytorch
- pyg
- senyan.dev
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=3_kmp_llvm
- aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
- aiohttp=3.11.13=py310h89163eb_0
- aiosignal=1.3.2=pyhd8ed1ab_0
- alsa-lib=1.2.13=hb9d3cd8_0
- ambertools=24.8=cuda_None_nompi_py310h834fefc_101
- annotated-types=0.7.0=pyhd8ed1ab_1
- anyio=4.8.0=pyhd8ed1ab_0
- aom=3.9.1=hac33072_0
- argon2-cffi=23.1.0=pyhd8ed1ab_1
- argon2-cffi-bindings=21.2.0=py310ha75aee5_5
- arpack=3.9.1=nompi_hf03ea27_102
- arrow=1.3.0=pyhd8ed1ab_1
- asttokens=3.0.0=pyhd8ed1ab_1
- async-lru=2.0.4=pyhd8ed1ab_1
- async-timeout=5.0.1=pyhd8ed1ab_1
- attr=2.5.1=h166bdaf_1
- attrs=25.3.0=pyh71513ae_0
- babel=2.17.0=pyhd8ed1ab_0
- beautifulsoup4=4.13.3=pyha770c72_0
- blas=2.116=mkl
- blas-devel=3.9.0=16_linux64_mkl
- bleach=6.2.0=pyh29332c3_4
- bleach-with-css=6.2.0=h82add2a_4
- blosc=1.21.6=he440d0b_1
- brotli=1.1.0=hb9d3cd8_2
- brotli-bin=1.1.0=hb9d3cd8_2
- brotli-python=1.1.0=py310hf71b8c6_2
- bson=0.5.10=pyhd8ed1ab_0
- bzip2=1.0.8=h4bc722e_7
- c-ares=1.34.4=hb9d3cd8_0
- c-blosc2=2.15.2=h3122c55_1
- ca-certificates=2025.1.31=hbcca054_0
- cached-property=1.5.2=hd8ed1ab_1
- cached_property=1.5.2=pyha770c72_1
- cachetools=5.5.2=pyhd8ed1ab_0
- cairo=1.18.4=h3394656_0
- certifi=2025.1.31=pyhd8ed1ab_0
- cffi=1.17.1=py310h8deb56e_0
- chardet=5.2.0=pyhd8ed1ab_3
- charset-normalizer=3.4.1=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_1
- comm=0.2.2=pyhd8ed1ab_1
- contourpy=1.3.1=py310h3788b33_0
- cpython=3.10.16=py310hd8ed1ab_1
- cuda-cudart=11.8.89=0
- cuda-cupti=11.8.87=0
- cuda-libraries=11.8.0=0
- cuda-nvrtc=11.8.89=0
- cuda-nvtx=11.8.86=0
- cuda-runtime=11.8.0=0
- cuda-version=11.8=h70ddcb2_3
- cudatoolkit=11.8.0=h4ba93d1_13
- cudatoolkit-dev=11.8.0=h1fa729e_6
- cycler=0.12.1=pyhd8ed1ab_1
- cyrus-sasl=2.1.27=h54b06d7_7
- dav1d=1.2.1=hd590300_0
- dbus=1.13.6=h5008d03_3
- debugpy=1.8.13=py310hf71b8c6_0
- decorator=5.2.1=pyhd8ed1ab_0
- defusedxml=0.7.1=pyhd8ed1ab_0
- deprecated=1.2.18=pyhd8ed1ab_0
- exceptiongroup=1.2.2=pyhd8ed1ab_1
- expat=2.6.4=h5888daf_0
- ffmpeg=7.1.1=gpl_h24e5c1d_701
- fftw=3.3.10=nompi_hf1063bd_110
- filelock=3.18.0=pyhd8ed1ab_0
- flexcache=0.3=pyhd8ed1ab_1
- flexparser=0.4=pyhd8ed1ab_1
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_3
- fontconfig=2.15.0=h7e30c49_1
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- fonttools=4.56.0=py310h89163eb_0
- fqdn=1.5.1=pyhd8ed1ab_1
- freetype=2.13.3=h48d6fc4_0
- freetype-py=2.3.0=pyhd8ed1ab_0
- fribidi=1.0.10=h36c2ea0_0
- frozenlist=1.5.0=py310h89163eb_1
- fsspec=2025.3.0=pyhd8ed1ab_0
- gdk-pixbuf=2.42.12=hb9ae30d_0
- gettext=0.23.1=h5888daf_0
- gettext-tools=0.23.1=h5888daf_0
- gmp=6.3.0=hac33072_2
- gmpy2=2.1.5=py310he8512ff_3
- graphite2=1.3.13=h59595ed_1003
- greenlet=3.1.1=py310hf71b8c6_1
- h11=0.14.0=pyhd8ed1ab_1
- h2=4.2.0=pyhd8ed1ab_0
- harfbuzz=10.4.0=h76408a6_0
- hdf4=4.2.15=h2a13503_7
- hdf5=1.14.4=nompi_h2d575fe_105
- hpack=4.1.0=pyhd8ed1ab_0
- httpcore=1.0.7=pyh29332c3_1
- httpx=0.28.1=pyhd8ed1ab_0
- hyperframe=6.1.0=pyhd8ed1ab_0
- icu=75.1=he02047a_0
- idna=3.10=pyhd8ed1ab_1
- importlib-metadata=8.6.1=pyha770c72_0
- importlib_resources=6.5.2=pyhd8ed1ab_0
- ipykernel=6.29.5=pyh3099207_0
- ipython=8.34.0=pyh907856f_0
- isoduration=20.11.0=pyhd8ed1ab_1
- jack=1.9.22=h7c63dc7_2
- jedi=0.19.2=pyhd8ed1ab_1
- jinja2=3.1.6=pyhd8ed1ab_0
- joblib=1.4.2=pyhd8ed1ab_1
- json5=0.10.0=pyhd8ed1ab_1
- jsonpointer=3.0.0=py310hff52083_1
- jsonschema=4.23.0=pyhd8ed1ab_1
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
- jupyter-lsp=2.2.5=pyhd8ed1ab_1
- jupyter_client=8.6.3=pyhd8ed1ab_1
- jupyter_core=5.7.2=pyh31011fe_1
- jupyter_events=0.12.0=pyh29332c3_0
- jupyter_server=2.15.0=pyhd8ed1ab_0
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
- jupyterlab=4.3.6=pyhd8ed1ab_0
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
- jupyterlab_server=2.27.3=pyhd8ed1ab_1
- jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
- kernel-headers_linux-64=3.10.0=he073ed8_18
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.7=py310h3788b33_0
- krb5=1.21.3=h659f571_0
- lame=3.100=h166bdaf_1003
- lcms2=2.17=h717163a_0
- ld_impl_linux-64=2.43=h712a8e2_4
- lerc=4.0.0=h27087fc_0
- level-zero=1.21.2=h84d6215_0
- libabseil=20250127.0=cxx17_hbbce691_0
- libaec=1.1.3=h59595ed_0
- libasprintf=0.23.1=h8e693c7_0
- libasprintf-devel=0.23.1=h8e693c7_0
- libass=0.17.3=hba53ac1_1
- libblas=3.9.0=16_linux64_mkl
- libboost=1.86.0=h6c02f8c_3
- libboost-python=1.86.0=py310ha2bacc8_3
- libbrotlicommon=1.1.0=hb9d3cd8_2
- libbrotlidec=1.1.0=hb9d3cd8_2
- libbrotlienc=1.1.0=hb9d3cd8_2
- libcap=2.75=h39aace5_0
- libcblas=3.9.0=16_linux64_mkl
- libcublas=11.11.3.6=0
- libcufft=10.9.0.58=0
- libcufile=1.9.1.3=0
- libcurand=10.3.5.147=0
- libcurl=8.12.1=h332b0f4_0
- libcusolver=11.4.1.48=0
- libcusparse=11.7.5.86=0
- libdb=6.2.32=h9c3ff4c_0
- libdeflate=1.23=h4ddbbb0_0
- libdrm=2.4.124=hb9d3cd8_0
- libedit=3.1.20250104=pl5321h7949ede_0
- libegl=1.7.0=ha4b6fd6_2
- libev=4.33=hd590300_2
- libexpat=2.6.4=h5888daf_0
- libffi=3.4.6=h2dba641_0
- libflac=1.4.3=h59595ed_0
- libgcc=14.2.0=h767d61c_2
- libgcc-ng=14.2.0=h69a702a_2
- libgcrypt-lib=1.11.0=hb9d3cd8_2
- libgettextpo=0.23.1=h5888daf_0
- libgettextpo-devel=0.23.1=h5888daf_0
- libgfortran=14.2.0=h69a702a_2
- libgfortran-ng=14.2.0=h69a702a_2
- libgfortran5=14.2.0=hf1ad2bd_2
- libgl=1.7.0=ha4b6fd6_2
- libglib=2.82.2=h2ff4ddf_1
- libglvnd=1.7.0=ha4b6fd6_2
- libglx=1.7.0=ha4b6fd6_2
- libgomp=14.2.0=h767d61c_2
- libgpg-error=1.51=hbd13f7d_1
- libhwloc=2.11.2=default_h0d58e46_1001
- libiconv=1.18=h4ce23a2_1
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=16_linux64_mkl
- liblapacke=3.9.0=16_linux64_mkl
- liblzma=5.6.4=hb9d3cd8_0
- libnetcdf=4.9.2=nompi_h5ddbaa4_116
- libnghttp2=1.64.0=h161d5f1_0
- libnpp=11.8.0.86=0
- libnsl=2.0.1=hd590300_0
- libntlm=1.8=hb9d3cd8_0
- libnvjpeg=11.9.0.86=0
- libogg=1.3.5=h4ab18f5_0
- libopenvino=2025.0.0=hdc3f47d_3
- libopenvino-auto-batch-plugin=2025.0.0=h4d9b6c2_3
- libopenvino-auto-plugin=2025.0.0=h4d9b6c2_3
- libopenvino-hetero-plugin=2025.0.0=h981d57b_3
- libopenvino-intel-cpu-plugin=2025.0.0=hdc3f47d_3
- libopenvino-intel-gpu-plugin=2025.0.0=hdc3f47d_3
- libopenvino-intel-npu-plugin=2025.0.0=hdc3f47d_3
- libopenvino-ir-frontend=2025.0.0=h981d57b_3
- libopenvino-onnx-frontend=2025.0.0=h0e684df_3
- libopenvino-paddle-frontend=2025.0.0=h0e684df_3
- libopenvino-pytorch-frontend=2025.0.0=h5888daf_3
- libopenvino-tensorflow-frontend=2025.0.0=h684f15b_3
- libopenvino-tensorflow-lite-frontend=2025.0.0=h5888daf_3
- libopus=1.3.1=h7f98852_1
- libpciaccess=0.18=hd590300_0
- libpng=1.6.47=h943b412_0
- libpq=17.4=h27ae623_0
- libprotobuf=5.29.3=h501fc15_0
- librdkit=2024.09.6=h84b0b3c_0
- librsvg=2.58.4=h49af25d_2
- libsndfile=1.2.2=hc60ed4a_1
- libsodium=1.0.20=h4ab18f5_0
- libsqlite=3.49.1=hee588c1_1
- libssh2=1.11.1=hf672d98_0
- libstdcxx=14.2.0=h8f9b012_2
- libstdcxx-ng=14.2.0=h4852527_2
- libsystemd0=257.4=h4e0b6ca_1
- libtiff=4.7.0=hd9ff511_3
- libudev1=257.4=hbe16f8c_1
- libunwind=1.6.2=h9c3ff4c_0
- liburing=2.9=h84d6215_0
- libusb=1.0.27=hb9d3cd8_101
- libuuid=2.38.1=h0b41bf4_0
- libva=2.22.0=h4f16b4b_2
- libvorbis=1.3.7=h9c3ff4c_0
- libvpx=1.14.1=hac33072_0
- libwebp-base=1.5.0=h851e524_0
- libxcb=1.17.0=h8a09558_0
- libxcrypt=4.4.36=hd590300_1
- libxkbcommon=1.8.1=hc4a0caf_0
- libxml2=2.13.6=h8d12d68_0
- libxslt=1.1.39=h76b75d6_0
- libzip=1.11.2=h6991a6a_0
- libzlib=1.3.1=hb9d3cd8_2
- llvm-openmp=15.0.7=h0cdce71_0
- lxml=5.3.1=py310h6ee67d5_0
- lz4-c=1.10.0=h5888daf_1
- markupsafe=3.0.2=py310h89163eb_1
- matplotlib-base=3.10.1=py310h68603db_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
- mda-xdrlib=0.2.0=pyhd8ed1ab_1
- mdtraj=1.10.3=py310h4cdbd58_0
- mendeleev=0.20.1=pymin39_ha308f57_3
- mistune=3.1.2=pyhd8ed1ab_0
- mkl=2022.1.0=h84fe81f_915
- mkl-devel=2022.1.0=ha770c72_916
- mkl-include=2022.1.0=h84fe81f_915
- mpc=1.3.1=h24ddda3_1
- mpfr=4.2.1=h90cbb55_3
- mpg123=1.32.9=hc50e24c_0
- mpmath=1.3.0=pyhd8ed1ab_1
- multidict=6.1.0=py310h89163eb_2
- munkres=1.1.4=pyh9f0ad1d_0
- nbclient=0.10.2=pyhd8ed1ab_0
- nbconvert-core=7.16.6=pyh29332c3_0
- nbformat=5.10.4=pyhd8ed1ab_1
- ncurses=6.5=h2d0b736_3
- nest-asyncio=1.6.0=pyhd8ed1ab_1
- netcdf-fortran=4.6.1=nompi_ha5d1325_108
- networkx=3.4.2=pyh267e887_2
- notebook=7.3.3=pyhd8ed1ab_0
- notebook-shim=0.2.4=pyhd8ed1ab_1
- numexpr=2.7.3=py310hb5077e9_1
- ocl-icd=2.3.2=hb9d3cd8_2
- ocl-icd-system=1.0.0=1
- opencl-headers=2024.10.24=h5888daf_0
- openff-amber-ff-ports=0.0.4=pyhca7485f_0
- openff-forcefields=2024.09.0=pyhff2d567_0
- openff-interchange=0.4.2=pyhd8ed1ab_2
- openff-interchange-base=0.4.2=pyhd8ed1ab_2
- openff-toolkit=0.16.8=pyhd8ed1ab_2
- openff-toolkit-base=0.16.8=pyhd8ed1ab_2
- openff-units=0.3.0=pyhd8ed1ab_1
- openff-utilities=0.1.15=pyhd8ed1ab_0
- openh264=2.6.0=hc22cd8d_0
- openjpeg=2.5.3=h5fbd93e_0
- openldap=2.6.9=he970967_0
- openmm=8.2.0=py310h30bdd6a_2
- openmmforcefields=0.14.2=pyhd8ed1ab_0
- openssl=3.4.1=h7b32b05_0
- overrides=7.7.0=pyhd8ed1ab_1
- packaging=24.2=pyhd8ed1ab_2
- panedr=0.8.0=pyhd8ed1ab_1
- pango=1.56.2=h861ebed_0
- parmed=4.3.0=py310h78e4988_1
- parso=0.8.4=pyhd8ed1ab_1
- pcre2=10.44=hba22ea6_2
- pdbfixer=1.11=pyhd8ed1ab_0
- perl=5.32.1=7_hd590300_perl5
- pexpect=4.9.0=pyhd8ed1ab_1
- pickleshare=0.7.5=pyhd8ed1ab_1004
- pillow=11.1.0=py310h7e6dc6c_0
- pint=0.24.4=pyhd8ed1ab_1
- pip=25.0.1=pyh8b19718_0
- pixman=0.44.2=h29eaf8c_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
- platformdirs=4.3.6=pyhd8ed1ab_1
- prometheus_client=0.21.1=pyhd8ed1ab_0
- prompt-toolkit=3.0.50=pyha770c72_0
- psutil=7.0.0=py310ha75aee5_0
- pthread-stubs=0.4=hb9d3cd8_1002
- ptyprocess=0.7.0=pyhd8ed1ab_1
- pugixml=1.15=h3f63f65_0
- pulseaudio-client=17.0=hb77b528_0
- pure_eval=0.2.3=pyhd8ed1ab_1
- py-cpuinfo=9.0.0=pyhd8ed1ab_1
- pycairo=1.27.0=py310h25ff670_0
- pycparser=2.22=pyh29332c3_1
- pydantic=2.10.6=pyh3cfb1c2_0
- pydantic-core=2.27.2=py310h505e2c1_0
- pyedr=0.8.0=pyhd8ed1ab_1
- pyfiglet=0.8.post1=py_0
- pyg=2.5.2=py310_torch_2.2.0_cu118
- pygments=2.19.1=pyhd8ed1ab_0
- pyparsing=3.2.1=pyhd8ed1ab_0
- pysocks=1.7.1=pyha55dd90_7
- pytables=3.10.1=py310h431dcdc_4
- python=3.10.16=he725a3c_1_cpython
- python-constraint=1.4.0=pyhff2d567_1
- python-dateutil=2.9.0.post0=pyhff2d567_1
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
- python-tzdata=2025.1=pyhd8ed1ab_0
- python_abi=3.10=5_cp310
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
- pytorch-cuda=11.8=h7e8668a_6
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
- pytz=2025.1=pyhd8ed1ab_0
- pyyaml=6.0.2=py310h89163eb_2
- pyzmq=26.3.0=py310h71f11fc_0
- qhull=2020.2=h434a139_5
- rdkit=2024.09.6=py310hcd13295_0
- readline=8.2=h8c095d6_2
- referencing=0.36.2=pyh29332c3_0
- reportlab=4.3.1=py310ha75aee5_0
- requests=2.32.3=pyhd8ed1ab_1
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
- rlpycairo=0.2.0=pyhd8ed1ab_0
- rpds-py=0.23.1=py310hc1293b2_0
- scikit-learn=1.6.1=py310h27f47ee_0
- scipy=1.15.2=py310h1d65ade_0
- sdl2=2.32.50=h9b8e6db_1
- sdl3=3.2.8=h3083f51_0
- send2trash=1.8.3=pyh0d859eb_1
- setuptools=75.8.2=pyhff2d567_0
- six=1.17.0=pyhd8ed1ab_0
- smirnoff99frosst=1.1.0=pyh44b312d_0
- snappy=1.2.1=h8bd8927_1
- sniffio=1.3.1=pyhd8ed1ab_1
- sqlalchemy=2.0.39=py310ha75aee5_1
- stack_data=0.6.3=pyhd8ed1ab_1
- svt-av1=3.0.1=h5888daf_0
- sympy=1.13.3=pyh2585a3b_105
- sysroot_linux-64=2.17=h0157908_18
- tbb=2021.13.0=hceb3a55_1
- terminado=0.18.1=pyh0d859eb_0
- threadpoolctl=3.6.0=pyhecae5ae_0
- tinycss2=1.4.0=pyhd8ed1ab_0
- tinydb=4.8.2=pyhd8ed1ab_1
- tk=8.6.13=noxft_h4845f30_101
- tomli=2.2.1=pyhd8ed1ab_1
- torchaudio=2.2.1=py310_cu118
- torchtriton=2.2.0=py310
- torchvision=0.17.1=py310_cu118
- tornado=6.4.2=py310ha75aee5_0
- tqdm=4.67.1=pyhd8ed1ab_1
- traitlets=5.14.3=pyhd8ed1ab_1
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
- typing-extensions=4.12.2=hd8ed1ab_1
- typing_extensions=4.12.2=pyha770c72_1
- typing_utils=0.1.0=pyhd8ed1ab_1
- tzdata=2025a=h78e105d_0
- unicodedata2=16.0.0=py310ha75aee5_0
- uri-template=1.3.0=pyhd8ed1ab_1
- urllib3=2.3.0=pyhd8ed1ab_0
- validators=0.34.0=pyhd8ed1ab_1
- wayland=1.23.1=h3e06ad9_0
- wayland-protocols=1.41=hd8ed1ab_0
- wcwidth=0.2.13=pyhd8ed1ab_1
- webcolors=24.11.1=pyhd8ed1ab_0
- webencodings=0.5.1=pyhd8ed1ab_3
- websocket-client=1.8.0=pyhd8ed1ab_1
- wheel=0.45.1=pyhd8ed1ab_1
- wrapt=1.17.2=py310ha75aee5_0
- x264=1!164.3095=h166bdaf_2
- x265=3.5=h924138e_3
- xkeyboard-config=2.43=hb9d3cd8_0
- xmltodict=0.14.2=pyhd8ed1ab_1
- xorg-libice=1.1.2=hb9d3cd8_0
- xorg-libsm=1.2.6=he73a12e_0
- xorg-libx11=1.8.12=h4f16b4b_0
- xorg-libxau=1.0.12=hb9d3cd8_0
- xorg-libxcursor=1.2.3=hb9d3cd8_0
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
- xorg-libxext=1.3.6=hb9d3cd8_0
- xorg-libxfixes=6.0.1=hb9d3cd8_0
- xorg-libxrender=0.9.12=hb9d3cd8_0
- xorg-libxscrnsaver=1.2.4=hb9d3cd8_0
- xorg-libxt=1.3.1=hb9d3cd8_0
- yaml=0.2.5=h7f98852_2
- yarl=1.18.3=py310h89163eb_1
- zeromq=4.3.5=h3b0a872_7
- zipp=3.21.0=pyhd8ed1ab_1
- zlib=1.3.1=hb9d3cd8_2
- zlib-ng=2.2.4=h7955e40_0
- zstandard=0.23.0=py310ha75aee5_1
- zstd=1.5.7=hb8e6e7a_1
- pip:
- absl-py==2.1.0
- alembic==1.15.1
- amberutils==21.0
- antlr4-python3-runtime==4.9.3
- autopage==0.5.2
- beartype==0.20.0
- biopandas==0.5.1
- biopython==1.79
- biotite==1.1.0
- biotraj==1.2.2
- cfgv==3.4.0
- cftime==1.6.4.post1
- click==8.1.8
- cliff==4.9.1
- cloudpathlib==0.21.0
- cmaes==0.11.1
- cmd2==2.5.11
- colorlog==6.9.0
- distlib==0.3.9
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
- dm-tree==0.1.9
- docker-pycreds==0.4.0
- duckdb==1.2.1
- edgembar==3.0
- einops==0.8.1
- eval-type-backport==0.2.2
- executing==2.2.0
- fair-esm==2.0.0
- fairscale==0.4.13
- fastcore==1.7.29
- future==1.0.0
- fvcore==0.1.5.post20221221
- gcsfs==2025.3.0
- gemmi==0.7.0
- gitdb==4.0.12
- gitpython==3.1.44
- google-api-core==2.24.2
- google-auth==2.38.0
- google-auth-oauthlib==1.2.1
- google-cloud-core==2.4.3
- google-cloud-storage==3.1.0
- google-crc32c==1.6.0
- google-resumable-media==2.7.2
- googleapis-common-protos==1.69.1
- hydra-colorlog==1.2.0
- hydra-core==1.3.2
- hydra-optuna-sweeper==1.2.0
- identify==2.6.9
- iniconfig==2.0.0
- iopath==0.1.10
- ipython-genutils==0.2.0
- ipywidgets==7.8.5
- jupyterlab-widgets==1.1.11
- lightning==2.5.0.post0
- lightning-utilities==0.14.1
- looseversion==1.1.2
- lovely-numpy==0.2.13
- lovely-tensors==0.1.18
- mako==1.3.9
- markdown-it-py==3.0.0
- mdurl==0.1.2
- ml-collections==1.0.0
- mmcif==0.91.0
- mmpbsa-py==16.0
- mmtf-python==1.1.3
- mols2grid==2.0.0
- msgpack==1.1.0
- msgpack-numpy==0.4.8
- narwhals==1.30.0
- netcdf4==1.7.2
- nodeenv==1.9.1
- numpy==1.26.4
- oauthlib==3.2.2
- omegaconf==2.3.0
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
- optuna==2.10.1
- packmol-memgen==2025.1.29
- pandas==2.2.3
- pandocfilters==1.5.1
- pbr==6.1.1
- pdb4amber==22.0
- plinder==0.2.24
- plotly==6.0.0
- pluggy==1.5.0
- portalocker==3.1.1
- posebusters==0.2.13
- git+https://git@github.com/zrqiao/power_spherical.git@290b1630c5f84e3bb0d61711046edcf6e47200d4
- pre-commit==4.1.0
- prettytable==3.15.1
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
- propcache==0.3.0
- proto-plus==1.26.1
- protobuf==5.29.3
- pyarrow==19.0.1
- pyasn1==0.6.1
- pyasn1-modules==0.4.1
- pymsmt==22.0
- pyperclip==1.9.0
- pytest==8.3.5
- python-dotenv==1.0.1
- python-json-logger==3.3.0
- pytorch-lightning==2.5.0.post0
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
- pytraj==2.0.6
- requests-oauthlib==2.0.0
- rich==13.9.4
- rootutils==1.0.7
- rsa==4.9
- sander==22.0
- seaborn==0.13.2
- sentry-sdk==2.22.0
- setproctitle==1.3.5
- smmap==5.0.2
- soupsieve==2.6
- stevedore==5.4.1
- tabulate==0.9.0
- termcolor==2.5.0
- torchmetrics==1.6.3
- virtualenv==20.29.3
- wandb==0.19.8
- widgetsnbextension==3.6.10
- yacs==0.1.8

@@ -0,0 +1,58 @@
name: flowdock
channels:
- pyg
- pytorch
- nvidia
- defaults
- conda-forge
dependencies:
- mendeleev=0.20.1=pymin39_ha308f57_3
- networkx=3.4.2=pyh267e887_2
- python=3.10.16=he725a3c_1_cpython
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
- pytorch-cuda=11.8=h7e8668a_6
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
- rdkit=2024.09.6=py310hcd13295_0
- scikit-learn=1.6.1=py310h27f47ee_0
- scipy=1.15.2=py310h1d65ade_0
- torchaudio=2.2.1=py310_cu118
- torchtriton=2.2.0=py310
- torchvision=0.17.1=py310_cu118
- tqdm=4.67.1=pyhd8ed1ab_1
- pip:
- beartype==0.20.0
- biopandas==0.5.1
- biopython==1.79
- biotite==1.1.0
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
- dm-tree==0.1.9
- einops==0.8.1
- fair-esm==2.0.0
- fairscale==0.4.13
- gemmi==0.7.0
- hydra-colorlog==1.2.0
- hydra-core==1.3.2
- hydra-optuna-sweeper==1.2.0
- lightning==2.5.0.post0
- lightning-utilities==0.14.1
- lovely-numpy==0.2.13
- lovely-tensors==0.1.18
- ml-collections==1.0.0
- msgpack==1.1.0
- msgpack-numpy==0.4.8
- numpy==1.26.4
- omegaconf==2.3.0
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
- pandas==2.2.3
- plinder==0.2.24
- plotly==6.0.0
- posebusters==0.2.13
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
- pytorch-lightning==2.5.0.post0
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
- rich==13.9.4
- rootutils==1.0.7
- seaborn==0.13.2
- torchmetrics==1.6.3
- wandb==0.19.8

flowdock/__init__.py Normal file

@@ -0,0 +1,120 @@
import importlib
import os
from beartype.typing import Any
from omegaconf import OmegaConf
METHOD_TITLE_MAPPING = {
"diffdock": "DiffDock",
"flowdock": "FlowDock",
"neuralplexer": "NeuralPLexer",
}
STANDARDIZED_DIR_METHODS = ["diffdock"]
def resolve_omegaconf_variable(variable_path: str) -> Any:
"""Resolve an OmegaConf variable path (e.g. `torch.distributed.fsdp.ShardingStrategy.FULL_SHARD`) to its value.
    :param variable_path: Dotted attribute path to resolve.
    :return: The resolved attribute value.
    """
# split the string into parts using the dot separator
parts = variable_path.rsplit(".", 1)
# get the module name from the first part of the path
module_name = parts[0]
# dynamically import the module using the module name
try:
module = importlib.import_module(module_name)
# use the imported module to get the requested attribute value
attribute = getattr(module, parts[1])
except Exception:
# fall back to treating the last path component as a nested attribute (e.g. an enum member)
module = importlib.import_module(".".join(module_name.split(".")[:-1]))
inner_module = ".".join(module_name.split(".")[-1:])
# use the imported module to get the requested attribute value
attribute = getattr(getattr(module, inner_module), parts[1])
return attribute
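For illustration, the same rsplit-then-getattr pattern applied to a standard-library path. This mirrors only the single-level case (the resolver above additionally falls back to a nested attribute lookup, e.g. for enum members):

```python
import importlib
import os.path  # binds `os` so the assertion target exists

def resolve_dotted_path(path: str):
    """Resolve `pkg.module.attr` to the attribute object, one level deep."""
    module_name, attr_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)
```

For example, `resolve_dotted_path("os.path.join")` imports `os.path` and returns the `join` function object itself.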
def resolve_dataset_path_dirname(dataset: str) -> str:
"""Resolve the dataset path directory name based on the dataset's name.
:param dataset: Name of the dataset.
:return: Directory name for the dataset path.
"""
return "DockGen" if dataset == "dockgen" else dataset
def resolve_method_input_csv_path(method: str, dataset: str) -> str:
"""Resolve the input CSV path for a given method.
:param method: The method name.
:param dataset: The dataset name.
:return: The input CSV path for the given method.
"""
if method in STANDARDIZED_DIR_METHODS or method in ["flowdock", "neuralplexer"]:
return os.path.join(
"forks",
METHOD_TITLE_MAPPING.get(method, method),
"inference",
f"{method}_{dataset}_inputs.csv",
)
else:
raise ValueError(f"Invalid method: {method}")
def resolve_method_title(method: str) -> str:
"""Resolve the method title for a given method.
:param method: The method name.
:return: The method title for the given method.
"""
return METHOD_TITLE_MAPPING.get(method, method)
def resolve_method_output_dir(
method: str,
dataset: str,
repeat_index: int,
) -> str:
"""Resolve the output directory for a given method.
:param method: The method name.
:param dataset: The dataset name.
:param repeat_index: The repeat index for the method.
:return: The output directory for the given method.
"""
if method in STANDARDIZED_DIR_METHODS or method in ["flowdock", "neuralplexer"]:
return os.path.join(
"forks",
METHOD_TITLE_MAPPING.get(method, method),
"inference",
f"{method}_{dataset}_output{'s' if method in ['flowdock', 'neuralplexer'] else ''}_{repeat_index}",
)
else:
raise ValueError(f"Invalid method: {method}")
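As a quick illustration of the naming scheme, here is a self-contained replica of the mapping above (the table and function are duplicated so the snippet runs standalone; it is not a call into the module):

```python
import os

METHOD_TITLES = {"diffdock": "DiffDock", "flowdock": "FlowDock", "neuralplexer": "NeuralPLexer"}

def method_output_dir(method: str, dataset: str, repeat_index: int) -> str:
    """Replicates resolve_method_output_dir for the supported methods."""
    # "flowdock" and "neuralplexer" pluralize "output"; other methods do not
    suffix = "s" if method in ("flowdock", "neuralplexer") else ""
    return os.path.join(
        "forks",
        METHOD_TITLES.get(method, method),
        "inference",
        f"{method}_{dataset}_output{suffix}_{repeat_index}",
    )
```

On POSIX paths, `method_output_dir("flowdock", "dockgen", 0)` yields `forks/FlowDock/inference/flowdock_dockgen_outputs_0`.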
def register_custom_omegaconf_resolvers():
"""Register custom OmegaConf resolvers."""
OmegaConf.register_new_resolver(
"resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path)
)
OmegaConf.register_new_resolver(
"resolve_dataset_path_dirname", lambda dataset: resolve_dataset_path_dirname(dataset)
)
OmegaConf.register_new_resolver(
"resolve_method_input_csv_path",
lambda method, dataset: resolve_method_input_csv_path(method, dataset),
)
OmegaConf.register_new_resolver(
"resolve_method_title", lambda method: resolve_method_title(method)
)
OmegaConf.register_new_resolver(
"resolve_method_output_dir",
lambda method, dataset, repeat_index: resolve_method_output_dir(
method, dataset, repeat_index
),
)
OmegaConf.register_new_resolver(
"int_divide", lambda dividend, divisor: int(dividend) // int(divisor)
)

flowdock/eval.py Normal file

@@ -0,0 +1,165 @@
import os
import hydra
import lightning as L
import lovely_tensors as lt
import rootutils
import torch
from beartype.typing import Any, Dict, List, Tuple
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig, open_dict
lt.monkey_patch()
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from flowdock import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from flowdock.utils import (
RankedLogger,
extras,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
log = RankedLogger(__name__, rank_zero_only=True)
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in an optional @task_wrapper decorator that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""
assert cfg.ckpt_path, "Please provide a checkpoint path to evaluate!"
assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(
f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
)
torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="test")
# Establish model input arguments
with open_dict(cfg):
if cfg.model.cfg.task.start_time == "auto":
cfg.model.cfg.task.start_time = 1.0
else:
cfg.model.cfg.task.start_time = float(cfg.model.cfg.task.start_time)
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
plugins = None
if "_target_" in cfg.environment:
log.info(f"Instantiating environment <{cfg.environment._target_}>")
plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
strategy = getattr(cfg.trainer, "strategy", None)
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if (
"mixed_precision" in strategy.__dict__
and getattr(strategy, "mixed_precision", None) is not None
):
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = (
hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
strategy=strategy,
)
if strategy is not None
else hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
)
)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
metric_dict = trainer.callback_metrics
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
evaluate(cfg)
if __name__ == "__main__":
register_custom_omegaconf_resolvers()
main()


@@ -0,0 +1,452 @@
# Adapted from https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import os
import threading
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
import lightning.pytorch as pl
import torch
from lightning.pytorch import Callback
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info
class EMA(Callback):
"""Implements Exponential Moving Averaging (EMA).
When training a model, this callback will maintain moving averages of the trained parameters.
When evaluating, we use the moving-average copy of the trained parameters.
When saving, we save an additional set of parameters with the prefix `ema`.
Args:
decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
validate_original_weights: Validate the original weights, as opposed to the EMA weights.
every_n_steps: Apply EMA every N steps.
cpu_offload: Offload weights to CPU.
"""
def __init__(
self,
decay: float,
validate_original_weights: bool = False,
every_n_steps: int = 1,
cpu_offload: bool = False,
):
if not (0 <= decay <= 1):
raise MisconfigurationException("EMA decay value must be between 0 and 1")
self.decay = decay
self.validate_original_weights = validate_original_weights
self.every_n_steps = every_n_steps
self.cpu_offload = cpu_offload
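The core update applied every `every_n_steps` is the standard exponential moving average. As a minimal sketch on plain floats rather than tensors (`ema_update` is a hypothetical helper, not part of this file):

```python
def ema_update(ema_params, params, decay):
    """ema <- decay * ema + (1 - decay) * param, applied elementwise."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]
```

A higher decay yields a slower-moving average that smooths over more optimizer steps, which is why the value must lie in [0, 1].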
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Add the EMA optimizer to the trainer."""
device = pl_module.device if not self.cpu_offload else torch.device("cpu")
trainer.optimizers = [
EMAOptimizer(
optim,
device=device,
decay=self.decay,
every_n_steps=self.every_n_steps,
current_step=trainer.global_step,
)
for optim in trainer.optimizers
if not isinstance(optim, EMAOptimizer)
]
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Swap the model weights with the EMA weights."""
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Swap the model weights back to the original weights."""
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Swap the model weights with the EMA weights."""
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Swap the model weights back to the original weights."""
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
"""Check if the EMA weights should be validated."""
return not self.validate_original_weights and self._ema_initialized(trainer)
def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
"""Check if the EMA weights have been initialized."""
return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)
def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
"""Swaps the model weights with the EMA weights."""
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.switch_main_parameter_weights(saving_ema_model)
@contextlib.contextmanager
def save_ema_model(self, trainer: "pl.Trainer"):
"""Saves an EMA copy of the model + EMA optimizer states for resume."""
self.swap_model_weights(trainer, saving_ema_model=True)
try:
yield
finally:
self.swap_model_weights(trainer, saving_ema_model=False)
@contextlib.contextmanager
def save_original_optimizer_state(self, trainer: "pl.Trainer"):
"""Save the original optimizer state."""
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.save_original_optimizer_state = True
try:
yield
finally:
for optimizer in trainer.optimizers:
optimizer.save_original_optimizer_state = False
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
"""Load the EMA state from the checkpoint if it exists."""
        checkpoint_callback = trainer.checkpoint_callback
        # Use the public `trainer.ckpt_path` instead of reaching into
        # lightning's protected `trainer._checkpoint_connector._ckpt_path`.
        ckpt_path = trainer.ckpt_path
if (
ckpt_path
and checkpoint_callback is not None
and "EMA" in type(checkpoint_callback).__name__
):
ext = checkpoint_callback.FILE_EXTENSION
if ckpt_path.endswith(f"-EMA{ext}"):
rank_zero_info(
"loading EMA based weights. "
"The callback will treat the loaded EMA weights as the main weights"
" and create a new EMA copy when training."
)
return
ema_path = ckpt_path.replace(ext, f"-EMA{ext}")
if os.path.exists(ema_path):
ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu"))
checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"]
del ema_state_dict
rank_zero_info("EMA state has been restored.")
            else:
                raise MisconfigurationException(
                    "Unable to find the associated EMA weights when re-loading. "
                    f"Expected them to be at: {ema_path}",
                )
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
"""Update the EMA model with the current model."""
torch._foreach_mul_(ema_model_tuple, decay)
torch._foreach_add_(
ema_model_tuple,
current_model_tuple,
alpha=(1.0 - decay),
)
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
"""Run EMA update on CPU."""
if pre_sync_stream is not None:
pre_sync_stream.synchronize()
ema_update(ema_model_tuple, current_model_tuple, decay)
class EMAOptimizer(torch.optim.Optimizer):
r"""EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average
of parameters registered in the optimizer.
EMA parameters are automatically updated after every step of the optimizer
with the following formula:
ema_weight = decay * ema_weight + (1 - decay) * training_weight
To access EMA parameters, use ``swap_ema_weights()`` context manager to
perform a temporary in-place swap of regular parameters with EMA
parameters.
Notes:
- EMAOptimizer is not compatible with APEX AMP O2.
Args:
optimizer (torch.optim.Optimizer): optimizer to wrap
device (torch.device): device for EMA parameters
decay (float): decay factor
Returns:
returns an instance of torch.optim.Optimizer that computes EMA of
parameters
Example:
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)
for epoch in range(epochs):
training_loop(model, opt)
regular_eval_accuracy = evaluate(model)
with opt.swap_ema_weights():
ema_eval_accuracy = evaluate(model)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
device: torch.device,
decay: float = 0.9999,
every_n_steps: int = 1,
current_step: int = 0,
):
self.optimizer = optimizer
self.decay = decay
self.device = device
self.current_step = current_step
self.every_n_steps = every_n_steps
self.save_original_optimizer_state = False
self.first_iteration = True
self.rebuild_ema_params = True
self.stream = None
self.thread = None
self.ema_params = ()
self.in_saving_ema_model_context = False
def all_parameters(self) -> Iterable[torch.Tensor]:
"""Return an iterator over all parameters in the optimizer."""
return (param for group in self.param_groups for param in group["params"])
def step(self, closure=None, grad_scaler=None, **kwargs):
"""Perform a single optimization step."""
self.join()
if self.first_iteration:
if any(p.is_cuda for p in self.all_parameters()):
self.stream = torch.cuda.Stream()
self.first_iteration = False
if self.rebuild_ema_params:
opt_params = list(self.all_parameters())
self.ema_params += tuple(
copy.deepcopy(param.data.detach()).to(self.device)
for param in opt_params[len(self.ema_params) :]
)
self.rebuild_ema_params = False
if (
getattr(self.optimizer, "_step_supports_amp_scaling", False)
and grad_scaler is not None
):
loss = self.optimizer.step(closure=closure, grad_scaler=grad_scaler)
else:
loss = self.optimizer.step(closure)
if self._should_update_at_step():
self.update()
self.current_step += 1
return loss
def _should_update_at_step(self) -> bool:
"""Check if the EMA parameters should be updated at the current step."""
return self.current_step % self.every_n_steps == 0
@torch.no_grad()
def update(self):
"""Update the EMA parameters."""
if self.stream is not None:
self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
current_model_state = tuple(
param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
)
if self.device.type == "cuda":
ema_update(self.ema_params, current_model_state, self.decay)
if self.device.type == "cpu":
self.thread = threading.Thread(
target=run_ema_update_cpu,
args=(
self.ema_params,
current_model_state,
self.decay,
self.stream,
),
)
self.thread.start()
def swap_tensors(self, tensor1, tensor2):
"""Swaps the tensors in-place."""
tmp = torch.empty_like(tensor1)
tmp.copy_(tensor1)
tensor1.copy_(tensor2)
tensor2.copy_(tmp)
def switch_main_parameter_weights(self, saving_ema_model: bool = False):
"""Switches the main parameter weights with the EMA weights."""
self.join()
self.in_saving_ema_model_context = saving_ema_model
for param, ema_param in zip(self.all_parameters(), self.ema_params):
self.swap_tensors(param.data, ema_param)
@contextlib.contextmanager
def swap_ema_weights(self, enabled: bool = True):
r"""A context manager to in-place swap regular parameters with EMA parameters. It swaps back
to the original regular parameters on context manager exit.
Args:
enabled (bool): whether the swap should be performed
"""
if enabled:
self.switch_main_parameter_weights()
try:
yield
finally:
if enabled:
self.switch_main_parameter_weights()
def __getattr__(self, name):
"""Forward all other attribute calls to the optimizer."""
return getattr(self.optimizer, name)
def join(self):
"""Wait for the update to complete."""
if self.stream is not None:
self.stream.synchronize()
if self.thread is not None:
self.thread.join()
def state_dict(self):
"""Return the state dict for the optimizer."""
self.join()
if self.save_original_optimizer_state:
return self.optimizer.state_dict()
# if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
ema_params = (
self.ema_params
if not self.in_saving_ema_model_context
else list(self.all_parameters())
)
state_dict = {
"opt": self.optimizer.state_dict(),
"ema": ema_params,
"current_step": self.current_step,
"decay": self.decay,
"every_n_steps": self.every_n_steps,
}
return state_dict
def load_state_dict(self, state_dict):
"""Load the state dict for the optimizer."""
self.join()
self.optimizer.load_state_dict(state_dict["opt"])
self.ema_params = tuple(
param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
)
self.current_step = state_dict["current_step"]
self.decay = state_dict["decay"]
self.every_n_steps = state_dict["every_n_steps"]
self.rebuild_ema_params = False
def add_param_group(self, param_group):
"""Add a param group to the optimizer."""
self.optimizer.add_param_group(param_group)
self.rebuild_ema_params = True
class EMAModelCheckpoint(ModelCheckpoint):
"""EMA version of ModelCheckpoint that saves EMA checkpoints as well."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
"""Returns the EMA callback if it exists."""
ema_callback = None
for callback in trainer.callbacks:
if isinstance(callback, EMA):
ema_callback = callback
return ema_callback
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Saves the checkpoint file and the EMA checkpoint file if it exists."""
ema_callback = self._ema_callback(trainer)
if ema_callback is not None:
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
filepath = self._ema_format_filepath(filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
else:
super()._save_checkpoint(trainer, filepath)
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Removes the checkpoint file and the EMA checkpoint file if it exists."""
super()._remove_checkpoint(trainer, filepath)
ema_callback = self._ema_callback(trainer)
if ema_callback is not None:
# remove EMA copy of the state dict as well.
filepath = self._ema_format_filepath(filepath)
super()._remove_checkpoint(trainer, filepath)
def _ema_format_filepath(self, filepath: str) -> str:
"""Appends '-EMA' to the filepath."""
return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")
def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
"""Checks if any of the checkpoints are EMA checkpoints."""
return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)
def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
"""Checks if the filepath is an EMA checkpoint."""
return str(filepath).endswith(f"-EMA{self.FILE_EXTENSION}")
@property
def _saved_checkpoint_paths(self) -> Iterable[Path]:
"""Returns all the saved checkpoint paths in the directory."""
return Path(self.dirpath).rglob("*.ckpt")
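The update rule implemented by `ema_update` above (and applied by `EMAOptimizer` after each step) can be sketched on plain Python floats; the illustrative `ema_update_scalar` below stands in for the in-place `torch._foreach_mul_`/`torch._foreach_add_` calls:

```python
# Minimal sketch of the EMA rule: ema = decay * ema + (1 - decay) * weight.
# Plain floats stand in for the per-parameter tensors; names are illustrative.
def ema_update_scalar(ema, current, decay):
    return [decay * e + (1.0 - decay) * c for e, c in zip(ema, current)]

ema = [0.0, 0.0]
for _ in range(3):
    weights = [1.0, 2.0]  # pretend these came from an optimizer step
    ema = ema_update_scalar(ema, weights, decay=0.9)
# After 3 updates toward constant weights, the EMA has closed
# 1 - 0.9**3 = 27.1% of the gap: ema ≈ [0.271, 0.542]
```

With a decay close to 1 (e.g. the 0.9999 default of `EMAOptimizer`), the average responds slowly to new weights, which is why the callback also exposes `every_n_steps` to trade smoothness against update cost.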

# Adapted from: https://github.com/zrqiao/NeuralPLexer
import random
import rootutils
import torch
from beartype.typing import Any, Dict, Optional, Tuple
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.models.components.embedding import (
GaussianFourierEncoding1D,
RelativeGeometryEncoding,
)
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
from flowdock.models.components.modules import (
BiDirectionalTriangleAttention,
TransformerLayer,
)
from flowdock.utils import RankedLogger
from flowdock.utils.frame_utils import cartesian_to_internal, get_frame_matrix
from flowdock.utils.model_utils import GELUMLP
log = RankedLogger(__name__, rank_zero_only=True)
STATE_DICT = Dict[str, Any]
class ProtFormer(torch.nn.Module):
"""Protein relational reasoning with downsampled edges."""
def __init__(
self,
dim: int,
pair_dim: int,
n_blocks: int = 4,
n_heads: int = 8,
dropout: float = 0.0,
):
"""Initialize the ProtFormer model."""
super().__init__()
self.dim = dim
self.pair_dim = pair_dim
self.n_heads = n_heads
self.n_blocks = n_blocks
self.time_encoding = GaussianFourierEncoding1D(16)
self.res_in_mlp = GELUMLP(dim + 32, dim, dropout=dropout)
self.chain_pos_encoding = GaussianFourierEncoding1D(self.pair_dim // 4)
self.rel_geom_enc = RelativeGeometryEncoding(15, self.pair_dim)
self.template_nenc = GELUMLP(64 + 37 * 3, self.dim, n_hidden_feats=128)
self.template_eenc = RelativeGeometryEncoding(15, self.pair_dim)
self.template_binding_site_enc = torch.nn.Linear(1, 64, bias=False)
self.pp_edge_embed = GELUMLP(
pair_dim + self.pair_dim // 4 * 2 + dim * 2,
self.pair_dim,
n_hidden_feats=dim,
dropout=dropout,
)
self.graph_stacks = torch.nn.ModuleList(
[
TransformerLayer(
dim,
n_heads,
head_dim=pair_dim // n_heads,
edge_channels=pair_dim,
edge_update=True,
dropout=dropout,
)
for _ in range(self.n_blocks)
]
)
self.ABab_mha = TransformerLayer(
pair_dim,
n_heads,
bidirectional=True,
)
self.triangle_stacks = torch.nn.ModuleList(
[
BiDirectionalTriangleAttention(pair_dim, pair_dim // n_heads, n_heads)
for _ in range(self.n_blocks)
]
)
self.graph_relations = [
(
"residue_to_residue",
"gather_idx_ab_a",
"gather_idx_ab_b",
"prot_res",
"prot_res",
),
(
"sampled_residue_to_sampled_residue",
"gather_idx_AB_a",
"gather_idx_AB_b",
"prot_res",
"prot_res",
),
]
def compute_chain_pe(
self,
residue_index,
res_chain_index,
src_rid,
dst_rid,
):
"""Compute chain positional encoding for a pair of residues."""
chain_disp = residue_index[src_rid] - residue_index[dst_rid]
chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
chain_disp.div(8).abs().add(1).unsqueeze(-1)
)
# Mask cross-chain entries
chain_mask = res_chain_index[src_rid] == res_chain_index[dst_rid]
chain_rope = chain_rope * chain_mask[..., None]
return chain_rope
def compute_chain_pair_pe(
self,
residue_index,
res_chain_index,
AB_broadcasted_rid,
ab_rid,
AB_broadcasted_cid,
ab_cid,
):
"""Compute chain positional encoding for a pair of residues."""
chain_disp_row = residue_index[AB_broadcasted_rid] - residue_index[ab_rid]
chain_disp_col = residue_index[AB_broadcasted_cid] - residue_index[ab_cid]
chain_disp = torch.stack([chain_disp_row, chain_disp_col], dim=-1)
chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
chain_disp.div(8).abs().add(1).unsqueeze(-1)
)
# Mask cross-chain entries
chain_mask_row = res_chain_index[AB_broadcasted_rid] == res_chain_index[ab_rid]
chain_mask_col = res_chain_index[AB_broadcasted_cid] == res_chain_index[ab_cid]
chain_mask = torch.stack([chain_mask_row, chain_mask_col], dim=-1)
chain_rope = (chain_rope * chain_mask[..., None]).flatten(-2, -1)
return chain_rope
def eval_protein_template_encodings(self, batch, edge_idx, use_plddt=False):
"""Evaluate template encodings for protein residues."""
with torch.no_grad():
template_bb_coords = batch["features"]["apo_res_atom_positions"][:, :3]
template_bb_frames = get_frame_matrix(
template_bb_coords[:, 0, :],
template_bb_coords[:, 1, :],
template_bb_coords[:, 2, :],
)
# Add template local representations & lddt
template_local_coords = cartesian_to_internal(
batch["features"]["apo_res_atom_positions"],
template_bb_frames.unsqueeze(1),
)
template_local_coords[~batch["features"]["apo_res_atom_mask"].bool()] = 0
# if use_plddt:
# template_plddt_enc = F.one_hot(
# torch.bucketize(
# batch["features"]["apo_pLDDT"],
# torch.linspace(0, 1, 65, device=template_bb_coords.device)[:-1],
# right=True,
# )
# - 1,
# num_classes=64,
# )
# else:
# template_plddt_enc = torch.zeros(
# template_local_coords.shape[0], 64, device=template_bb_coords.device
# )
if self.training:
use_sidechain_coords = random.randint(0, 1) # nosec
template_local_coords = template_local_coords * use_sidechain_coords
# use_plddt_input = random.randint(0, 1)
# template_plddt_enc = template_plddt_enc * use_plddt_input
if "binding_site_mask" in batch["features"].keys():
# Externally-specified binding residue list
binding_site_enc = self.template_binding_site_enc(
batch["features"]["binding_site_mask"][:, None].float()
)
else:
binding_site_enc = torch.zeros(
template_local_coords.shape[0], 64, device=template_bb_coords.device
)
template_nfeat = self.template_nenc(
torch.cat([template_local_coords.flatten(-2, -1), binding_site_enc], dim=-1)
)
template_efeat = self.template_eenc(template_bb_frames, edge_idx)
template_alignment_mask = batch["features"]["apo_res_alignment_mask"].float()
if self.training:
# template_alignment_mask = template_alignment_mask * use_template
nomasking_rate = random.randint(9, 10) / 10 # nosec
template_alignment_mask = template_alignment_mask * (
torch.rand_like(template_alignment_mask) < nomasking_rate
)
template_nfeat = template_nfeat * template_alignment_mask.unsqueeze(-1)
template_efeat = (
template_efeat
* template_alignment_mask[edge_idx[0]].unsqueeze(-1)
* template_alignment_mask[edge_idx[1]].unsqueeze(-1)
)
return template_nfeat, template_efeat
def forward(self, batch, **kwargs):
"""Forward pass of the ProtFormer model."""
return self.forward_prot_sample(batch, **kwargs)
def forward_prot_sample(
self,
batch,
embed_coords=True,
in_attr_suffix="",
out_attr_suffix="",
use_template=False,
use_plddt=False,
**kwargs,
):
"""Forward pass of the ProtFormer model for a single protein sample."""
features = batch["features"]
indexer = batch["indexer"]
metadata = batch["metadata"]
device = features["res_type"].device
time_encoding = self.time_encoding(features["timestep_encoding_prot"])
if not embed_coords:
time_encoding = torch.zeros_like(time_encoding)
residue_rep = (
self.res_in_mlp(
torch.cat(
[
features["res_embedding_in"],
time_encoding,
],
dim=1,
)
)
+ features["res_embedding_in"]
)
batch_size = metadata["num_structid"]
        # Prepare indexers
        # Eagerly evaluate the per-sample residue counts so that malformed
        # metadata fails loudly here rather than in a downstream gather
        max(metadata["num_a_per_sample"])
n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]
indexer["gather_idx_pid_b"] = indexer["gather_idx_pid_a"]
# Evaluate gather_idx_AB_a and gather_idx_AB_b
# Assign a to rows and b to columns
# Simple broadcasting for single-structure batches
indexer["gather_idx_AB_a"] = (
indexer["gather_idx_pid_a"]
.view(batch_size, n_protein_patches)[:, :, None]
.expand(-1, -1, n_protein_patches)
.contiguous()
.flatten()
)
indexer["gather_idx_AB_b"] = (
indexer["gather_idx_pid_b"]
.view(batch_size, n_protein_patches)[:, None, :]
.expand(-1, n_protein_patches, -1)
.contiguous()
.flatten()
)
# Handle all batch offsets here
graph_batcher = make_multi_relation_graph_batcher(self.graph_relations, indexer, metadata)
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
input_protein_coords_padded = features["input_protein_coords"]
backbone_frames = get_frame_matrix(
input_protein_coords_padded[:, 0, :],
input_protein_coords_padded[:, 1, :],
input_protein_coords_padded[:, 2, :],
)
batch["features"]["backbone_frames"] = backbone_frames
# Adding geometrical info to pair representations
chain_pe = self.compute_chain_pe(
features["residue_index"],
features["res_chain_id"],
merged_edge_idx[0],
merged_edge_idx[1],
)
geometry_pe = self.rel_geom_enc(backbone_frames, merged_edge_idx)
if not embed_coords:
geometry_pe = torch.zeros_like(geometry_pe)
merged_edge_reps = self.pp_edge_embed(
torch.cat(
[
geometry_pe,
chain_pe,
residue_rep[merged_edge_idx[0]],
residue_rep[merged_edge_idx[1]],
],
dim=-1,
)
)
if use_template and "apo_res_atom_positions" in features.keys():
(
template_res_encodings,
template_geom_encodings,
) = self.eval_protein_template_encodings(batch, merged_edge_idx, use_plddt=use_plddt)
residue_rep = residue_rep + template_res_encodings
merged_edge_reps = merged_edge_reps + template_geom_encodings
edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)
node_reps = {"prot_res": residue_rep}
gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
# Pointer: AB->AB, ab->AB
gather_idx_ab_AB = (
indexer["gather_idx_ab_structid"] * n_protein_patches**2
+ (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_a"]]
* n_protein_patches
+ (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_b"]]
)
# Intertwine graph iterations and triangle iterations
for block_id in range(self.n_blocks):
# Communicate between atomistic and patch resolutions
# Up-sampling for interface edge embeddings
rec_pair_rep = edge_reps["residue_to_residue"]
AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
# Upper-left block: intra-window visual-attention
# Cross-attention between random and grid edges
rec_pair_rep, AB_grid_attr_flat = self.ABab_mha(
rec_pair_rep,
AB_grid_attr_flat,
(
torch.arange(metadata["num_ab"], device=device),
gather_idx_ab_AB,
),
)
AB_grid_attr = AB_grid_attr_flat.view(
batch_size,
n_protein_patches,
n_protein_patches,
self.pair_dim,
)
# Inter-patch triangle attentions, refining intermolecular edges
_, AB_grid_attr = self.triangle_stacks[block_id](
AB_grid_attr,
AB_grid_attr,
AB_grid_attr.unsqueeze(-4),
)
# Transfer grid-formatted representations to edges
edge_reps["residue_to_residue"] = rec_pair_rep
edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)
merged_node_reps = graph_batcher.collate_node_attr(node_reps)
merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)
# Graph transformer iteration
_, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
merged_node_reps,
merged_node_reps,
merged_edge_idx,
merged_edge_reps,
)
node_reps = graph_batcher.offload_node_attr(merged_node_reps)
edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)
batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
"sampled_residue_to_sampled_residue"
]
batch["indexer"]["gather_idx_AB_a"] = indexer["gather_idx_AB_a"]
batch["indexer"]["gather_idx_AB_b"] = indexer["gather_idx_AB_b"]
batch["indexer"]["gather_idx_ab_AB"] = gather_idx_ab_AB
return batch
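The `gather_idx_AB_*` construction above broadcasts per-sample patch ids into a flattened n_patches x n_patches grid: row ids are repeated across columns and column ids across rows, so edge k of the flattened grid maps back to its (row, col) patch pair. A small plain-Python sketch of the same indexing, with illustrative names standing in for the `view`/`expand`/`flatten` tensor ops:

```python
# Build flattened row/column indexers for a per-sample n_patches x n_patches
# edge grid, mirroring the gather_idx_AB_a / gather_idx_AB_b construction.
def make_grid_indexers(patch_ids, n_patches):
    """patch_ids: flat list of length batch_size * n_patches."""
    batch_size = len(patch_ids) // n_patches
    rows, cols = [], []
    for b in range(batch_size):
        sample = patch_ids[b * n_patches : (b + 1) * n_patches]
        for i in range(n_patches):      # expand row ids across columns
            for j in range(n_patches):  # expand column ids across rows
                rows.append(sample[i])
                cols.append(sample[j])
    return rows, cols

rows, cols = make_grid_indexers([0, 1, 2, 3], n_patches=2)
# rows == [0, 0, 1, 1, 2, 2, 3, 3]; cols == [0, 1, 0, 1, 2, 3, 2, 3]
```

Indexing edge k of the flattened grid with `(rows[k], cols[k])` then recovers the patch pair that edge connects, which is exactly how the downstream gathers use these tensors.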
class BindingFormer(ProtFormer):
"""Edge inference on protein-ligand graphs."""
def __init__(
self,
dim: int,
pair_dim: int,
n_blocks: int = 4,
n_heads: int = 8,
n_ligand_patches: int = 16,
dropout: float = 0.0,
):
"""Initialize the BindingFormer model."""
super().__init__(
dim,
pair_dim,
n_blocks,
n_heads,
dropout,
)
self.dim = dim
self.n_heads = n_heads
self.n_blocks = n_blocks
self.n_ligand_patches = n_ligand_patches
self.pl_edge_embed = GELUMLP(dim * 2, self.pair_dim, n_hidden_feats=dim, dropout=dropout)
self.AaJ_mha = TransformerLayer(pair_dim, n_heads, bidirectional=True)
self.graph_relations = [
(
"residue_to_residue",
"gather_idx_ab_a",
"gather_idx_ab_b",
"prot_res",
"prot_res",
),
(
"sampled_residue_to_sampled_residue",
"gather_idx_AB_a",
"gather_idx_AB_b",
"prot_res",
"prot_res",
),
(
"sampled_residue_to_sampled_lig_triplet",
"gather_idx_AJ_a",
"gather_idx_AJ_J",
"prot_res",
"lig_trp",
),
(
"sampled_lig_triplet_to_sampled_residue",
"gather_idx_AJ_J",
"gather_idx_AJ_a",
"lig_trp",
"prot_res",
),
(
"residue_to_sampled_lig_triplet",
"gather_idx_aJ_a",
"gather_idx_aJ_J",
"prot_res",
"lig_trp",
),
(
"sampled_lig_triplet_to_residue",
"gather_idx_aJ_J",
"gather_idx_aJ_a",
"lig_trp",
"prot_res",
),
(
"sampled_lig_triplet_to_sampled_lig_triplet",
"gather_idx_IJ_I",
"gather_idx_IJ_J",
"lig_trp",
"lig_trp",
),
]
def forward(
self,
batch,
observed_block_contacts=None,
in_attr_suffix="",
out_attr_suffix="",
):
"""Forward pass of the BindingFormer model."""
features = batch["features"]
indexer = batch["indexer"]
metadata = batch["metadata"]
device = features["res_type"].device
# Synchronize with a language model
residue_rep = features[f"rec_res_attr{in_attr_suffix}"]
rec_pair_rep = features[f"res_res_pair_attr{in_attr_suffix}"]
# Inherit the last-layer pair representations from protein encoder
AB_grid_attr_flat = features[f"res_res_grid_attr_flat{in_attr_suffix}"]
# Prepare indexers
batch_size = metadata["num_structid"]
n_a_per_sample = max(metadata["num_a_per_sample"])
n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]
if not batch["misc"]["protein_only"]:
n_ligand_patches = max(metadata["num_I_per_sample"])
max(metadata["num_molid_per_sample"])
lig_frame_rep = features[f"lig_trp_attr{in_attr_suffix}"]
UI_grid_attr = features["lig_af_grid_attr_projected"]
IJ_grid_attr = (UI_grid_attr + UI_grid_attr.transpose(1, 2)) / 2
aJ_grid_attr = self.pl_edge_embed(
torch.cat(
[
residue_rep.view(batch_size, n_a_per_sample, self.dim)[:, :, None].expand(
-1, -1, n_ligand_patches, -1
),
lig_frame_rep.view(batch_size, n_ligand_patches, self.dim)[
:, None, :
].expand(-1, n_a_per_sample, -1, -1),
],
dim=-1,
)
)
AJ_grid_attr = IJ_grid_attr.new_zeros(
batch_size, n_protein_patches, n_ligand_patches, self.pair_dim
)
gather_idx_I_I = torch.arange(
batch_size * n_ligand_patches, device=AJ_grid_attr.device
)
gather_idx_a_a = torch.arange(batch_size * n_a_per_sample, device=AJ_grid_attr.device)
# Note: off-diagonal (AJ) blocks are zero-initialized in the prior stack
indexer["gather_idx_IJ_I"] = (
gather_idx_I_I.view(batch_size, n_ligand_patches)[:, :, None]
.expand(-1, -1, n_ligand_patches)
.contiguous()
.flatten()
)
indexer["gather_idx_IJ_J"] = (
gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
.expand(-1, n_ligand_patches, -1)
.contiguous()
.flatten()
)
indexer["gather_idx_AJ_a"] = (
indexer["gather_idx_pid_a"]
.view(batch_size, n_protein_patches)[:, :, None]
.expand(-1, -1, n_ligand_patches)
.contiguous()
.flatten()
)
indexer["gather_idx_AJ_J"] = (
gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
.expand(-1, n_protein_patches, -1)
.contiguous()
.flatten()
)
indexer["gather_idx_aJ_a"] = (
gather_idx_a_a.view(batch_size, n_a_per_sample)[:, :, None]
.expand(-1, -1, n_ligand_patches)
.contiguous()
.flatten()
)
indexer["gather_idx_aJ_J"] = (
gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
.expand(-1, n_a_per_sample, -1)
.contiguous()
.flatten()
)
batch["indexer"] = indexer
if observed_block_contacts is not None:
# Generative feedback from block one-hot sampling
# AJ_grid_attr = (
# AJ_grid_attr
# + observed_block_contacts.transpose(1, 2)
# .contiguous()
# .flatten(0, 1)[indexer["gather_idx_I_molid"]]
# .view(batch_size, n_ligand_patches, n_protein_patches, -1)
# .transpose(1, 2)
# .contiguous()
# )
AJ_grid_attr = AJ_grid_attr + observed_block_contacts
graph_batcher = make_multi_relation_graph_batcher(
self.graph_relations, indexer, metadata
)
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
node_reps = {
"prot_res": residue_rep,
"lig_trp": lig_frame_rep,
}
edge_reps = {
"residue_to_residue": rec_pair_rep,
"sampled_residue_to_sampled_residue": AB_grid_attr_flat,
"sampled_lig_triplet_to_sampled_residue": AJ_grid_attr.flatten(0, 2),
"sampled_residue_to_sampled_lig_triplet": AJ_grid_attr.flatten(0, 2),
"sampled_lig_triplet_to_residue": aJ_grid_attr.flatten(0, 2),
"residue_to_sampled_lig_triplet": aJ_grid_attr.flatten(0, 2),
"sampled_lig_triplet_to_sampled_lig_triplet": IJ_grid_attr.flatten(0, 2),
}
edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)
else:
graph_batcher = make_multi_relation_graph_batcher(
self.graph_relations[:2], indexer, metadata
)
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
node_reps = {
"prot_res": residue_rep,
}
edge_reps = {
"residue_to_residue": rec_pair_rep,
"sampled_residue_to_sampled_residue": AB_grid_attr_flat,
}
edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)
# Intertwine graph iterations and triangle iterations
gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
for block_id in range(self.n_blocks):
# Communicate between atomistic and patch resolutions
# Up-sampling for interface edge embeddings
rec_pair_rep = edge_reps["residue_to_residue"]
AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
AB_grid_attr = AB_grid_attr_flat.view(
batch_size,
n_protein_patches,
n_protein_patches,
self.pair_dim,
)
if not batch["misc"]["protein_only"]:
# Symmetrize off-diagonal blocks
AJ_grid_attr_flat_ = (
edge_reps["sampled_residue_to_sampled_lig_triplet"]
+ edge_reps["sampled_lig_triplet_to_sampled_residue"]
) / 2
AJ_grid_attr = AJ_grid_attr_flat_.contiguous().view(
batch_size, n_protein_patches, n_ligand_patches, -1
)
aJ_grid_attr_flat_ = (
edge_reps["residue_to_sampled_lig_triplet"]
+ edge_reps["sampled_lig_triplet_to_residue"]
) / 2
aJ_grid_attr = aJ_grid_attr_flat_.contiguous().view(
batch_size, n_a_per_sample, n_ligand_patches, -1
)
IJ_grid_attr = (
edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"]
.contiguous()
.view(batch_size, n_ligand_patches, n_ligand_patches, -1)
)
AJ_grid_attr_temp_, aJ_grid_attr_temp_ = self.AaJ_mha(
AJ_grid_attr.flatten(0, 1),
aJ_grid_attr.flatten(0, 1),
(
gather_idx_res_protpatch,
torch.arange(gather_idx_res_protpatch.shape[0], device=device),
),
)
AJ_grid_attr = AJ_grid_attr_temp_.contiguous().view(
batch_size, n_protein_patches, n_ligand_patches, -1
)
aJ_grid_attr = aJ_grid_attr_temp_.contiguous().view(
batch_size, n_a_per_sample, n_ligand_patches, -1
)
merged_grid_rep = torch.cat(
[
torch.cat([AB_grid_attr, AJ_grid_attr], dim=2),
torch.cat([AJ_grid_attr.transpose(1, 2), IJ_grid_attr], dim=2),
],
dim=1,
)
else:
merged_grid_rep = AB_grid_attr
# Inter-patch triangle attentions
_, merged_grid_rep = self.triangle_stacks[block_id](
merged_grid_rep,
merged_grid_rep,
merged_grid_rep.unsqueeze(-4),
)
# Dis-assemble the grid representation
AB_grid_attr = merged_grid_rep[:, :n_protein_patches, :n_protein_patches]
# Transfer grid-formatted representations to edges
edge_reps["residue_to_residue"] = rec_pair_rep
edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)
if not batch["misc"]["protein_only"]:
AJ_grid_attr = merged_grid_rep[
:, :n_protein_patches, n_protein_patches:
].contiguous()
IJ_grid_attr = merged_grid_rep[
:, n_protein_patches:, n_protein_patches:
].contiguous()
edge_reps["sampled_residue_to_sampled_lig_triplet"] = AJ_grid_attr.flatten(0, 2)
edge_reps["sampled_lig_triplet_to_sampled_residue"] = AJ_grid_attr.flatten(0, 2)
edge_reps["residue_to_sampled_lig_triplet"] = aJ_grid_attr.flatten(0, 2)
edge_reps["sampled_lig_triplet_to_residue"] = aJ_grid_attr.flatten(0, 2)
edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"] = IJ_grid_attr.flatten(
0, 2
)
merged_node_reps = graph_batcher.collate_node_attr(node_reps)
merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)
# Graph transformer iteration
_, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
merged_node_reps,
merged_node_reps,
merged_edge_idx,
merged_edge_reps,
)
node_reps = graph_batcher.offload_node_attr(merged_node_reps)
edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)
batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
"sampled_residue_to_sampled_residue"
]
if not batch["misc"]["protein_only"]:
batch["features"][f"lig_trp_attr{out_attr_suffix}"] = node_reps["lig_trp"]
batch["features"][f"res_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
"sampled_residue_to_sampled_lig_triplet"
]
batch["features"][f"res_trp_pair_attr_flat{out_attr_suffix}"] = edge_reps[
"residue_to_sampled_lig_triplet"
]
batch["features"][f"trp_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
"sampled_lig_triplet_to_sampled_lig_triplet"
]
batch["metadata"]["n_lig_patches_per_sample"] = n_ligand_patches
return batch
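The block-matrix assembly above merges the protein-protein (`AB`), protein-ligand (`AJ`), and ligand-ligand (`IJ`) patch grids into one (P + L) x (P + L) grid before triangle attention, then slices the blocks back apart. A minimal plain-Python sketch of that layout (the `AB`/`AJ`/`IJ` labels mirror the tensors above; everything else is illustrative):

```python
# Assemble a (P + L) x (P + L) block grid from three smaller blocks, as above.
P, L = 3, 2  # n_protein_patches, n_ligand_patches
AB = [["AB"] * P for _ in range(P)]  # protein-protein block
AJ = [["AJ"] * L for _ in range(P)]  # protein-ligand block
JA = [["AJ"] * P for _ in range(L)]  # transposed protein-ligand block
IJ = [["IJ"] * L for _ in range(L)]  # ligand-ligand block
merged = [AB[i] + AJ[i] for i in range(P)] + [JA[j] + IJ[j] for j in range(L)]

assert len(merged) == P + L and len(merged[0]) == P + L
# Slicing recovers the blocks, matching merged_grid_rep[:, :P, :P] etc. above.
assert merged[0][P] == "AJ" and merged[P][0] == "AJ" and merged[P][P] == "IJ"
```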
def resolve_protein_encoder(
protein_model_cfg: DictConfig,
task_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
"""Instantiates a ProtFormer model for protein encoding.
:param protein_model_cfg: Protein model configuration.
:param task_cfg: Task configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: Protein encoder model and residue input projector.
"""
node_dim = protein_model_cfg.residue_dim
model = ProtFormer(
node_dim,
protein_model_cfg.pair_dim,
n_heads=protein_model_cfg.n_heads,
n_blocks=protein_model_cfg.n_encoder_stacks,
dropout=task_cfg.dropout,
)
if protein_model_cfg.use_esm_embedding:
# protein sequence language model
res_in_projector = torch.nn.Linear(protein_model_cfg.plm_embed_dim, node_dim, bias=False)
else:
# one-hot amino acid types
res_in_projector = torch.nn.Linear(
protein_model_cfg.n_aa_types,
node_dim,
bias=False,
)
if protein_model_cfg.from_pretrained and state_dict is not None:
try:
# NOTE: we must avoid enforcing strict key matching
# due to the (new) weights `template_binding_site_enc.weight`
model.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("protein_encoder")
},
strict=False,
)
log.info("Successfully loaded pretrained protein encoder weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained protein encoder weights due to: {e}.")
try:
res_in_projector.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith(
"plm_adapter"
if protein_model_cfg.use_esm_embedding
else "res_in_projector"
)
}
)
log.info("Successfully loaded pretrained protein input projector weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained protein input projector weights due to: {e}."
)
return model, res_in_projector
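The pretrained-weight loading above repeatedly uses a prefix-stripping dict comprehension: keys like `"protein_encoder.blocks.0.weight"` are filtered by submodule prefix and the prefix is dropped. A minimal standalone sketch of that pattern (the keys below are made up for illustration):

```python
# Filter a flat state dict by submodule prefix, then strip the prefix,
# mirroring the `".".join(k.split(".")[1:])` pattern used above.
state_dict = {
    "protein_encoder.blocks.0.weight": 1,
    "plm_adapter.weight": 2,
    "other.weight": 3,
}
sub = {
    ".".join(k.split(".")[1:]): v
    for k, v in state_dict.items()
    if k.startswith("protein_encoder")
}
assert sub == {"blocks.0.weight": 1}
```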
def resolve_pl_contact_stack(
protein_model_cfg: DictConfig,
ligand_model_cfg: DictConfig,
contact_cfg: DictConfig,
task_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.nn.Module]:
"""Instantiates a BindingFormer model for protein-ligand contact prediction.
:param protein_model_cfg: Protein model configuration.
:param ligand_model_cfg: Ligand model configuration.
:param contact_cfg: Contact prediction configuration.
:param task_cfg: Task configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: Protein-ligand contact prediction model, contact code embedding, distance bins, and
distogram head.
"""
pl_contact_stack = BindingFormer(
protein_model_cfg.residue_dim,
protein_model_cfg.pair_dim,
n_heads=protein_model_cfg.n_heads,
n_blocks=contact_cfg.n_stacks,
n_ligand_patches=ligand_model_cfg.n_patches,
        dropout=contact_cfg.dropout if contact_cfg.get("dropout") is not None else task_cfg.dropout,
)
contact_code_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim)
# Distogram heads
dist_bins = torch.nn.Parameter(torch.linspace(2, 22, 32), requires_grad=False)
dgram_head = GELUMLP(
protein_model_cfg.pair_dim,
32,
n_hidden_feats=protein_model_cfg.pair_dim,
zero_init=True,
)
if contact_cfg.from_pretrained and state_dict is not None:
try:
# NOTE: we must avoid enforcing strict key matching
# due to the (new) weights `template_binding_site_enc.weight`
pl_contact_stack.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("pl_contact_stack")
},
strict=False,
)
log.info("Successfully loaded pretrained protein-ligand contact prediction weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained protein-ligand contact prediction weights due to: {e}."
)
try:
contact_code_embed.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("contact_code_embed")
}
)
log.info("Successfully loaded pretrained contact code embedding weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained contact code embedding weights due to: {e}."
)
try:
dgram_head.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("dgram_head")
}
)
log.info("Successfully loaded pretrained distogram head weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained distogram head weights due to: {e}.")
return pl_contact_stack, contact_code_embed, dist_bins, dgram_head
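The distogram setup above places 32 bin centers uniformly on [2, 22] (Angstroms, presumably), and the 32-way `dgram_head` classifies pair distances into those bins. A minimal plain-Python sketch of mapping a distance to its nearest bin (the `nearest_bin` helper is illustrative, not part of FlowDock):

```python
# 32 uniformly spaced bin centers on [2, 22], like torch.linspace(2, 22, 32).
n_bins, d_min, d_max = 32, 2.0, 22.0
bins = [d_min + i * (d_max - d_min) / (n_bins - 1) for i in range(n_bins)]

def nearest_bin(d):
    """Index of the bin center closest to distance d."""
    return min(range(n_bins), key=lambda i: abs(bins[i] - d))

assert nearest_bin(2.0) == 0
assert nearest_bin(22.0) == n_bins - 1
```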


@@ -0,0 +1,105 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
import math
import rootutils
import torch
from beartype.typing import Tuple
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils.frame_utils import RigidTransform
class GaussianFourierEncoding1D(torch.nn.Module):
"""Gaussian Fourier Encoding for 1D data."""
def __init__(
self,
n_basis: int,
eps: float = 1e-2,
):
"""Initialize Gaussian Fourier Encoding."""
super().__init__()
self.eps = eps
self.fourier_freqs = torch.nn.Parameter(torch.randn(n_basis) * math.pi)
def forward(
self,
x: torch.Tensor,
):
"""Forward pass of Gaussian Fourier Encoding."""
encodings = torch.cat(
[
torch.sin(self.fourier_freqs.mul(x)),
torch.cos(self.fourier_freqs.mul(x)),
],
dim=-1,
)
return encodings
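As a minimal standalone sketch of the encoding above (a hypothetical plain-Python re-implementation, not the module itself): each scalar input is mapped to sin/cos features at fixed random frequencies, doubling `n_basis` features to `2 * n_basis`:

```python
import math
import random

def gaussian_fourier_encode(x, freqs):
    """Concatenate sin and cos features, yielding 2 * len(freqs) values."""
    return [math.sin(w * x) for w in freqs] + [math.cos(w * x) for w in freqs]

random.seed(0)
# Analogous to torch.randn(n_basis) * math.pi with n_basis = 4.
freqs = [random.gauss(0.0, 1.0) * math.pi for _ in range(4)]
enc = gaussian_fourier_encode(0.5, freqs)
assert len(enc) == 8  # 2 * n_basis
```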
class GaussianRBFEncoding1D(torch.nn.Module):
"""Gaussian RBF Encoding for 1D data."""
def __init__(
self,
n_basis: int,
x_max: float,
sigma: float = 1.0,
):
"""Initialize Gaussian RBF Encoding."""
super().__init__()
self.sigma = sigma
self.rbf_centers = torch.nn.Parameter(
torch.linspace(0, x_max, n_basis), requires_grad=False
)
def forward(
self,
x: torch.Tensor,
):
"""Forward pass of Gaussian RBF Encoding."""
encodings = torch.exp(-((x.unsqueeze(-1) - self.rbf_centers).div(self.sigma).square()))
return encodings
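The RBF encoding above evaluates a Gaussian bump at each of `n_basis` centers spaced uniformly on `[0, x_max]`; values lie in (0, 1] and peak at the nearest center. A minimal standalone sketch (a hypothetical plain-Python re-implementation, not the module itself):

```python
import math

def rbf_encode(x, centers, sigma=1.0):
    """exp(-((x - c) / sigma)^2) for each center c."""
    return [math.exp(-(((x - c) / sigma) ** 2)) for c in centers]

# Like torch.linspace(0, 20, 16): 16 centers on [0, 20].
centers = [i * 20.0 / 15 for i in range(16)]
feats = rbf_encode(10.0, centers)
assert len(feats) == 16
assert 0.0 < max(feats) <= 1.0
```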
class RelativeGeometryEncoding(torch.nn.Module):
"Compute radial basis functions and iterresidue/pseudoresidue orientations."
def __init__(self, n_basis: int, out_dim: int, d_max: float = 20.0):
"""Initialize RelativeGeometryEncoding."""
super().__init__()
self.rbf_encoding = GaussianRBFEncoding1D(n_basis, d_max)
self.rel_geom_projector = torch.nn.Linear(n_basis + 15, out_dim, bias=False)
def forward(self, frames: RigidTransform, merged_edge_idx: Tuple[torch.Tensor, torch.Tensor]):
"""Forward pass of RelativeGeometryEncoding."""
frame_t, frame_R = frames.t, frames.R
pair_dists = torch.norm(
frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]],
dim=-1,
)
pair_directions_l = torch.matmul(
(frame_t[merged_edge_idx[1]] - frame_t[merged_edge_idx[0]]).unsqueeze(-2),
frame_R[merged_edge_idx[0]],
).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1)
pair_directions_r = torch.matmul(
(frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]]).unsqueeze(-2),
frame_R[merged_edge_idx[1]],
).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1)
pair_orientations = torch.matmul(
frame_R.transpose(-2, -1).contiguous()[merged_edge_idx[0]],
frame_R[merged_edge_idx[1]],
)
return self.rel_geom_projector(
torch.cat(
[
self.rbf_encoding(pair_dists),
pair_directions_l,
pair_directions_r,
pair_orientations.flatten(-2, -1),
],
dim=-1,
)
)
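The `n_basis + 15` input width of `rel_geom_projector` follows directly from the concatenation above: `n_basis` distance RBFs, a 3-vector direction in each endpoint's local frame, and a flattened 3x3 relative rotation. A quick bookkeeping check:

```python
# Feature-width bookkeeping for the relative-geometry concatenation above.
n_basis = 15
rbf_feats = n_basis        # self.rbf_encoding(pair_dists)
direction_feats = 3 + 3    # pair_directions_l and pair_directions_r
orientation_feats = 3 * 3  # pair_orientations.flatten(-2, -1)
total = rbf_feats + direction_feats + orientation_feats
assert total == n_basis + 15  # matches Linear(n_basis + 15, out_dim)
```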


@@ -0,0 +1,884 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
import rootutils
import torch
from beartype.typing import Any, Dict, Optional, Tuple
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.models.components.embedding import (
GaussianFourierEncoding1D,
RelativeGeometryEncoding,
)
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
from flowdock.models.components.modules import PointSetAttention
from flowdock.utils import RankedLogger
from flowdock.utils.frame_utils import RigidTransform, get_frame_matrix
from flowdock.utils.model_utils import GELUMLP, AveragePooling, SumPooling, segment_mean
STATE_DICT = Dict[str, Any]
log = RankedLogger(__name__, rank_zero_only=True)
class LocalUpdateUsingReferenceRotations(torch.nn.Module):
"""Update local geometric representations using reference rotations."""
def __init__(
self,
fiber_dim: int,
extra_feat_dim: int = 0,
eps: float = 1e-4,
dropout: float = 0.0,
hidden_dim: Optional[int] = None,
zero_init: bool = False,
):
"""Initialize the LocalUpdateUsingReferenceRotations module."""
super().__init__()
self.dim = fiber_dim * 5 + extra_feat_dim
self.fiber_dim = fiber_dim
self.mlp = GELUMLP(
self.dim,
fiber_dim * 4,
dropout=dropout,
zero_init=zero_init,
n_hidden_feats=hidden_dim,
)
self.eps = eps
def forward(
self,
x: torch.Tensor,
rotation_mats: torch.Tensor,
extra_feats=None,
):
"""Forward pass of the LocalUpdateUsingReferenceRotations module."""
# Vector norms are evaluated without applying rigid transform
vecx_local = torch.matmul(
x[:, 1:].transpose(-2, -1),
rotation_mats,
)
x1_local = torch.cat(
[
x[:, 0],
vecx_local.flatten(-2, -1),
x[:, 1:].square().sum(dim=-2).add(self.eps).sqrt(),
],
dim=-1,
)
if extra_feats is not None:
x1_local = torch.cat([x1_local, extra_feats], dim=-1)
x1_local = self.mlp(x1_local).view(-1, 4, self.fiber_dim)
vecx1_out = torch.matmul(
rotation_mats,
x1_local[:, 1:],
)
x1_out = torch.cat([x1_local[:, :1], vecx1_out], dim=-2)
return x1_out
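The MLP widths in the module above follow from the scalar/vector concatenation: the input stacks the scalar channel, the three components of each rotated l=1 channel, and the per-channel vector norms (`fiber_dim * 5` plus any extras), while the output provides one scalar row and three vector rows per channel (`fiber_dim * 4`). A quick bookkeeping check:

```python
# Width bookkeeping for LocalUpdateUsingReferenceRotations above.
fiber_dim = 8
scalar_feats = fiber_dim         # x[:, 0]
local_vec_feats = 3 * fiber_dim  # rotated l=1 channels, flattened
norm_feats = fiber_dim           # per-channel vector norms
extra_feat_dim = 0
in_dim = scalar_feats + local_vec_feats + norm_feats + extra_feat_dim
out_dim = fiber_dim * 4          # one scalar row + three vector rows per channel
assert in_dim == fiber_dim * 5 + extra_feat_dim
assert out_dim == (1 + 3) * fiber_dim
```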
class LocalUpdateUsingChannelWiseGating(torch.nn.Module):
"""Update local geometric representations using channel-wise gating."""
def __init__(
self,
fiber_dim: int,
eps: float = 1e-4,
dropout: float = 0.0,
hidden_dim: Optional[int] = None,
zero_init: bool = False,
):
"""Initialize the LocalUpdateUsingChannelWiseGating module."""
super().__init__()
self.dim = fiber_dim * 2
self.fiber_dim = fiber_dim
self.mlp = GELUMLP(
self.dim,
self.dim,
dropout=dropout,
n_hidden_feats=hidden_dim,
zero_init=zero_init,
)
self.gate = torch.nn.Sigmoid()
self.lin_out = torch.nn.Linear(fiber_dim, fiber_dim, bias=False)
if zero_init:
self.lin_out.weight.data.fill_(0.0)
self.eps = eps
def forward(
self,
x: torch.Tensor,
):
"""Forward pass of the LocalUpdateUsingChannelWiseGating module."""
x1 = torch.cat(
[
x[:, 0],
x[:, 1:].square().sum(dim=-2).add(self.eps).sqrt(),
],
dim=-1,
)
x1 = self.mlp(x1)
# Gated nonlinear operation on l=1 representations
x1_scalar, x1_gatein = torch.split(x1, self.fiber_dim, dim=-1)
x1_gate = self.gate(x1_gatein).unsqueeze(-2)
vecx1_out = self.lin_out(x[:, 1:]).mul(x1_gate)
x1_out = torch.cat([x1_scalar.unsqueeze(-2), vecx1_out], dim=-2)
return x1_out
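The gated update above scales each l=1 vector channel by a sigmoid gate computed from rotation-invariant scalars (channel values and vector norms), which preserves direction and hence equivariance. A minimal single-vector sketch (the `gate_vector` helper is illustrative, not part of FlowDock):

```python
import math

def gate_vector(vec, gate_logit):
    """Scale a 3-vector by a sigmoid gate; direction is preserved."""
    g = 1.0 / (1.0 + math.exp(-gate_logit))
    return [g * v for v in vec]

out = gate_vector([3.0, 0.0, 4.0], 0.0)  # sigmoid(0) = 0.5
assert out == [1.5, 0.0, 2.0]
```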
class EquivariantTransformerBlock(torch.nn.Module):
"""Equivariant Transformer Block module."""
def __init__(
self,
fiber_dim: int,
heads: int = 8,
point_dim: int = 4,
eps: float = 1e-4,
edge_dim: Optional[int] = None,
target_frames: bool = False,
edge_update: bool = False,
dropout: float = 0.0,
):
"""Initialize the EquivariantTransformerBlock module."""
super().__init__()
self.attn_conv = PointSetAttention(
fiber_dim,
heads=heads,
point_dim=point_dim,
edge_dim=edge_dim,
edge_update=edge_update,
)
self.fiber_dim = fiber_dim
self.target_frames = target_frames
self.eps = eps
self.edge_update = edge_update
if target_frames:
self.local_update = LocalUpdateUsingReferenceRotations(
fiber_dim, eps=eps, dropout=dropout, zero_init=True
)
else:
self.local_update = LocalUpdateUsingChannelWiseGating(
fiber_dim, eps=eps, dropout=dropout, zero_init=True
)
def forward(
self,
x: torch.Tensor,
edge_index: torch.LongTensor,
t: torch.Tensor,
R: torch.Tensor = None,
x_edge: torch.Tensor = None,
):
"""Forward pass of the EquivariantTransformerBlock module."""
if self.edge_update:
xout, edge_out = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge)
x_edge = x_edge + edge_out
else:
xout = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge)
x = x + xout
if self.target_frames:
x = self.local_update(x, R) + x
else:
x = self.local_update(x) + x
return x, x_edge
class EquivariantStructureDenoisingModule(torch.nn.Module):
"""Equivariant Structure Denoising Module."""
def __init__(
self,
fiber_dim: int,
input_dim: int,
input_pair_dim: int,
hidden_dim: int = 1024,
n_stacks: int = 4,
n_heads: int = 8,
dropout: float = 0.0,
):
"""Initialize the EquivariantStructureDenoisingModule module."""
super().__init__()
self.input_dim = input_dim
self.input_pair_dim = input_pair_dim
self.fiber_dim = fiber_dim
self.protatm_padding_dim = 37
self.n_blocks = n_stacks
self.input_node_projector = torch.nn.Linear(input_dim, fiber_dim, bias=False)
self.input_node_vec_projector = torch.nn.Linear(input_dim, fiber_dim * 3, bias=False)
self.input_pair_projector = torch.nn.Linear(input_pair_dim, fiber_dim, bias=False)
# Inherit the residue representations
self.atm_embed = GELUMLP(input_dim + 32, fiber_dim)
self.ipa_modules = torch.nn.ModuleList(
[
EquivariantTransformerBlock(
fiber_dim,
heads=n_heads,
point_dim=fiber_dim // (n_heads * 2),
edge_dim=fiber_dim,
target_frames=True,
edge_update=True,
dropout=dropout,
)
for _ in range(n_stacks)
]
)
self.res_adapters = torch.nn.ModuleList(
[
LocalUpdateUsingReferenceRotations(
fiber_dim,
extra_feat_dim=input_dim,
hidden_dim=hidden_dim,
dropout=dropout,
zero_init=True,
)
for _ in range(n_stacks)
]
)
self.protatm_type_encoding = GELUMLP(self.protatm_padding_dim + input_dim, input_pair_dim)
self.time_encoding = GaussianFourierEncoding1D(16)
self.rel_geom_enc = RelativeGeometryEncoding(15, fiber_dim)
self.rel_geom_embed = GELUMLP(fiber_dim, fiber_dim, n_hidden_feats=fiber_dim)
# [displacement, scale]
self.out_drift_res = torch.nn.ModuleList(
[torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
)
# for i in range(n_stacks):
# self.out_drift_res[i].weight.data.fill_(0.0)
self.out_scale_res = torch.nn.ModuleList(
[GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
)
self.out_drift_atm = torch.nn.ModuleList(
[torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
)
# for i in range(n_stacks):
# self.out_drift_atm[i].weight.data.fill_(0.0)
self.out_scale_atm = torch.nn.ModuleList(
[GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
)
# Pre-tabulated edges
self.graph_relations = [
(
"residue_to_residue",
"gather_idx_ab_a",
"gather_idx_ab_b",
"prot_res",
"prot_res",
),
(
"sampled_residue_to_sampled_residue",
"gather_idx_AB_a",
"gather_idx_AB_b",
"prot_res",
"prot_res",
),
(
"prot_atm_to_prot_atm_graph",
"protatm_protatm_idx_src",
"protatm_protatm_idx_dst",
"prot_atm",
"prot_atm",
),
(
"prot_atm_to_prot_atm_knn",
"knn_idx_protatm_protatm_src",
"knn_idx_protatm_protatm_dst",
"prot_atm",
"prot_atm",
),
(
"prot_atm_to_residue",
"protatm_res_idx_protatm",
"protatm_res_idx_res",
"prot_atm",
"prot_res",
),
(
"residue_to_prot_atm",
"protatm_res_idx_res",
"protatm_res_idx_protatm",
"prot_res",
"prot_atm",
),
(
"sampled_lig_triplet_to_lig_atm",
"gather_idx_UI_I",
"gather_idx_UI_u",
"lig_trp",
"lig_atm",
),
(
"lig_atm_to_sampled_lig_triplet",
"gather_idx_UI_u",
"gather_idx_UI_I",
"lig_atm",
"lig_trp",
),
(
"lig_atm_to_lig_atm_graph",
"gather_idx_uv_u",
"gather_idx_uv_v",
"lig_atm",
"lig_atm",
),
(
"sampled_residue_to_sampled_lig_triplet",
"gather_idx_AJ_a",
"gather_idx_AJ_J",
"prot_res",
"lig_trp",
),
(
"sampled_lig_triplet_to_sampled_residue",
"gather_idx_AJ_J",
"gather_idx_AJ_a",
"lig_trp",
"prot_res",
),
(
"residue_to_sampled_lig_triplet",
"gather_idx_aJ_a",
"gather_idx_aJ_J",
"prot_res",
"lig_trp",
),
(
"sampled_lig_triplet_to_residue",
"gather_idx_aJ_J",
"gather_idx_aJ_a",
"lig_trp",
"prot_res",
),
(
"sampled_lig_triplet_to_sampled_lig_triplet",
"gather_idx_IJ_I",
"gather_idx_IJ_J",
"lig_trp",
"lig_trp",
),
(
"prot_atm_to_lig_atm_knn",
"knn_idx_protatm_ligatm_src",
"knn_idx_protatm_ligatm_dst",
"prot_atm",
"lig_atm",
),
(
"lig_atm_to_prot_atm_knn",
"knn_idx_ligatm_protatm_src",
"knn_idx_ligatm_protatm_dst",
"lig_atm",
"prot_atm",
),
(
"lig_atm_to_lig_atm_knn",
"knn_idx_ligatm_ligatm_src",
"knn_idx_ligatm_ligatm_dst",
"lig_atm",
"lig_atm",
),
]
self.graph_relations_no_ligand = self.graph_relations[:6]
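The `graph_relations[:6]` slice keeps exactly the relations whose endpoints are protein node types, so the protein-only code path never references ligand nodes. A small self-check over those names (the tuples below copy the relation names and node types from the list above):

```python
# The first six relations from graph_relations above, protein endpoints only.
relations = [
    ("residue_to_residue", "prot_res", "prot_res"),
    ("sampled_residue_to_sampled_residue", "prot_res", "prot_res"),
    ("prot_atm_to_prot_atm_graph", "prot_atm", "prot_atm"),
    ("prot_atm_to_prot_atm_knn", "prot_atm", "prot_atm"),
    ("prot_atm_to_residue", "prot_atm", "prot_res"),
    ("residue_to_prot_atm", "prot_res", "prot_atm"),
]
assert len(relations) == 6
assert all(s.startswith("prot") and d.startswith("prot") for _, s, d in relations)
```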
def init_scalar_vec_rep(self, x, x_v=None, frame=None):
"""Initialize scalar and vector representations."""
if frame is None:
# Zero-initialize the vector channels
vec_shape = (*x.shape[:-1], 3, x.shape[-1])
res = torch.cat([x.unsqueeze(-2), torch.zeros(vec_shape, device=x.device)], dim=-2)
else:
x_v = x_v.view(*x.shape[:-1], 3, x.shape[-1])
x_v_glob = torch.matmul(frame.R, x_v)
res = torch.cat([x.unsqueeze(-2), x_v_glob], dim=-2)
return res
def forward(
self,
batch,
frozen_lig=False,
frozen_prot=False,
**kwargs,
):
"""Forward pass of the EquivariantStructureDenoisingModule module."""
features = batch["features"]
indexer = batch["indexer"]
metadata = batch["metadata"]
metadata["num_structid"]
max(metadata["num_a_per_sample"])
prot_res_rep_in = features["rec_res_attr_decin"]
timestep_prot = features["timestep_encoding_prot"]
device = features["res_type"].device
# Protein all-atom representation initialization
protatm_padding_mask = features["res_atom_mask"]
protatm_atom37_onehot = torch.nn.functional.one_hot(
features["protatm_to_atom37_index"], num_classes=self.protatm_padding_dim
)
protatm_res_pair_encoding = self.protatm_type_encoding(
torch.cat(
[
prot_res_rep_in[indexer["protatm_res_idx_res"]],
protatm_atom37_onehot,
],
dim=-1,
)
)
# Gathered AA features from individual graphs
prot_atm_rep_in = features["prot_atom_attr_projected"]
prot_atm_rep_int = self.atm_embed(
torch.cat(
[
prot_atm_rep_in,
self.time_encoding(timestep_prot)[indexer["protatm_res_idx_res"]],
],
dim=-1,
)
)
prot_atm_coords_padded = features["input_protein_coords"]
prot_atm_coords_flat = prot_atm_coords_padded[protatm_padding_mask]
# Embed the rigid body node representations
backbone_frames = get_frame_matrix(
prot_atm_coords_padded[:, 0],
prot_atm_coords_padded[:, 1],
prot_atm_coords_padded[:, 2],
)
prot_res_rep = self.init_scalar_vec_rep(
self.input_node_projector(prot_res_rep_in),
x_v=self.input_node_vec_projector(prot_res_rep_in),
frame=backbone_frames,
)
prot_atm_rep = self.init_scalar_vec_rep(prot_atm_rep_int)
# gather AA features from individual graphs
node_reps = {
"prot_res": prot_res_rep,
"prot_atm": prot_atm_rep,
}
# Embed pair representations
edge_reps = {
"residue_to_residue": features["res_res_pair_attr_decin"],
"prot_atm_to_prot_atm_graph": features["prot_atom_pair_attr_projected"],
"prot_atm_to_prot_atm_knn": features["knn_feat_protatm_protatm"],
"prot_atm_to_residue": protatm_res_pair_encoding,
"residue_to_prot_atm": protatm_res_pair_encoding,
"sampled_residue_to_sampled_residue": features["res_res_grid_attr_flat_decin"],
}
if not batch["misc"]["protein_only"]:
max(metadata["num_i_per_sample"])
timestep_lig = features["timestep_encoding_lig"]
lig_atm_rep_in = features["lig_atom_attr_projected"]
lig_frame_rep_in = features["lig_trp_attr_decin"]
# Ligand atom embedding. Two timescales
lig_atm_rep_int = self.atm_embed(
torch.cat(
[lig_atm_rep_in, self.time_encoding(timestep_lig)],
dim=-1,
)
)
lig_atm_rep = self.init_scalar_vec_rep(lig_atm_rep_int)
# Prepare ligand atom - sidechain atom indexers
# Initialize coordinate features
lig_atm_coords = features["input_ligand_coords"].clone()
lig_frame_atm_idx = (
indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]],
indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]],
indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]],
)
ligand_trp_frames = get_frame_matrix(
lig_atm_coords[lig_frame_atm_idx[0]],
lig_atm_coords[lig_frame_atm_idx[1]],
lig_atm_coords[lig_frame_atm_idx[2]],
)
lig_frame_rep = self.init_scalar_vec_rep(
self.input_node_projector(lig_frame_rep_in),
x_v=self.input_node_vec_projector(lig_frame_rep_in),
frame=ligand_trp_frames,
)
node_reps.update(
{
"lig_atm": lig_atm_rep,
"lig_trp": lig_frame_rep,
}
)
edge_reps.update(
{
"lig_atm_to_lig_atm_graph": features["lig_atom_pair_attr_projected"],
"sampled_lig_triplet_to_lig_atm": features["lig_af_pair_attr_projected"],
"lig_atm_to_sampled_lig_triplet": features["lig_af_pair_attr_projected"],
"sampled_residue_to_sampled_lig_triplet": features[
"res_trp_grid_attr_flat_decin"
],
"sampled_lig_triplet_to_sampled_residue": features[
"res_trp_grid_attr_flat_decin"
],
"residue_to_sampled_lig_triplet": features["res_trp_pair_attr_flat_decin"],
"sampled_lig_triplet_to_residue": features["res_trp_pair_attr_flat_decin"],
"sampled_lig_triplet_to_sampled_lig_triplet": features[
"trp_trp_grid_attr_flat_decin"
],
"prot_atm_to_lig_atm_knn": features["knn_feat_protatm_ligatm"],
"lig_atm_to_prot_atm_knn": features["knn_feat_ligatm_protatm"],
"lig_atm_to_lig_atm_knn": features["knn_feat_ligatm_ligatm"],
}
)
# Message passing
protatm_res_idx_res = indexer["protatm_res_idx_res"]
if batch["misc"]["protein_only"]:
graph_relations = self.graph_relations_no_ligand
else:
graph_relations = self.graph_relations
graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
merged_node_reps = graph_batcher.collate_node_attr(node_reps)
merged_edge_reps = graph_batcher.collate_edge_attr(
graph_batcher.zero_pad_edge_attr(edge_reps, self.input_pair_dim, device)
)
merged_edge_reps = self.input_pair_projector(merged_edge_reps)
assert merged_edge_idx[0].shape[0] == merged_edge_reps.shape[0]
assert merged_edge_idx[1].shape[0] == merged_edge_reps.shape[0]
dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None)
if not batch["misc"]["protein_only"]:
dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None)
merged_node_t = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.t,
"prot_atm": dummy_prot_atm_frames.t,
"lig_atm": dummy_lig_atm_frames.t,
"lig_trp": ligand_trp_frames.t,
}
)
merged_node_R = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.R,
"prot_atm": dummy_prot_atm_frames.R,
"lig_atm": dummy_lig_atm_frames.R,
"lig_trp": ligand_trp_frames.R,
}
)
else:
merged_node_t = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.t,
"prot_atm": dummy_prot_atm_frames.t,
}
)
merged_node_R = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.R,
"prot_atm": dummy_prot_atm_frames.R,
}
)
merged_node_frames = RigidTransform(merged_node_t, merged_node_R)
merged_edge_reps = merged_edge_reps + (
self.rel_geom_embed(
self.rel_geom_enc(merged_node_frames, merged_edge_idx) + merged_edge_reps
)
)
# No need to reassign embeddings but need to update point coordinates & frames
for block_id in range(self.n_blocks):
dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None)
if not batch["misc"]["protein_only"]:
dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None)
merged_node_t = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.t,
"prot_atm": dummy_prot_atm_frames.t,
"lig_atm": dummy_lig_atm_frames.t,
"lig_trp": ligand_trp_frames.t,
}
)
merged_node_R = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.R,
"prot_atm": dummy_prot_atm_frames.R,
"lig_atm": dummy_lig_atm_frames.R,
"lig_trp": ligand_trp_frames.R,
}
)
else:
merged_node_t = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.t,
"prot_atm": dummy_prot_atm_frames.t,
}
)
merged_node_R = graph_batcher.collate_node_attr(
{
"prot_res": backbone_frames.R,
"prot_atm": dummy_prot_atm_frames.R,
}
)
# PredictDrift iteration
merged_node_reps, merged_edge_reps = self.ipa_modules[block_id](
merged_node_reps,
merged_edge_idx,
t=merged_node_t,
R=merged_node_R,
x_edge=merged_edge_reps,
)
offloaded_node_reps = graph_batcher.offload_node_attr(merged_node_reps)
if "lig_trp" in offloaded_node_reps.keys():
lig_frame_rep = offloaded_node_reps["lig_trp"]
offloaded_node_reps["lig_trp"] = lig_frame_rep + self.res_adapters[block_id](
lig_frame_rep, ligand_trp_frames.R, extra_feats=lig_frame_rep_in
)
prot_res_rep = offloaded_node_reps["prot_res"]
offloaded_node_reps["prot_res"] = prot_res_rep + self.res_adapters[block_id](
prot_res_rep, backbone_frames.R, extra_feats=prot_res_rep_in
)
merged_node_reps = graph_batcher.collate_node_attr(offloaded_node_reps)
# Displacement vectors in the global coordinate system
if not batch["misc"]["protein_only"]:
drift_trp = (
self.out_drift_res[block_id](offloaded_node_reps["lig_trp"][:, 1:]).squeeze(-1)
* torch.sigmoid(
self.out_scale_res[block_id](offloaded_node_reps["lig_trp"][:, 0])
)
* 10
)
drift_trp_gathered = segment_mean(
drift_trp,
indexer["gather_idx_I_molid"],
metadata["num_molid"],
)[indexer["gather_idx_i_molid"]]
drift_atm = self.out_drift_atm[block_id](
offloaded_node_reps["lig_atm"][:, 1:]
).squeeze(-1) * torch.sigmoid(
self.out_scale_atm[block_id](offloaded_node_reps["lig_atm"][:, 0])
)
if not frozen_lig:
lig_atm_coords = lig_atm_coords + drift_atm + drift_trp_gathered
ligand_trp_frames = get_frame_matrix(
lig_atm_coords[lig_frame_atm_idx[0]],
lig_atm_coords[lig_frame_atm_idx[1]],
lig_atm_coords[lig_frame_atm_idx[2]],
)
drift_bb = (
self.out_drift_res[block_id](offloaded_node_reps["prot_res"][:, 1:]).squeeze(-1)
* torch.sigmoid(
self.out_scale_res[block_id](offloaded_node_reps["prot_res"][:, 0])
)
* 10
)
drift_bb_gathered = drift_bb[protatm_res_idx_res]
drift_prot_atm_int = self.out_drift_atm[block_id](
offloaded_node_reps["prot_atm"][:, 1:]
).squeeze(-1) * torch.sigmoid(
self.out_scale_atm[block_id](offloaded_node_reps["prot_atm"][:, 0])
)
if not frozen_prot:
prot_atm_coords_flat = (
prot_atm_coords_flat + drift_prot_atm_int + drift_bb_gathered
)
prot_atm_coords_padded = torch.zeros_like(features["input_protein_coords"])
prot_atm_coords_padded[protatm_padding_mask] = prot_atm_coords_flat
backbone_frames = get_frame_matrix(
prot_atm_coords_padded[:, 0],
prot_atm_coords_padded[:, 1],
prot_atm_coords_padded[:, 2],
)
ret = {
"final_embedding_prot_atom": offloaded_node_reps["prot_atm"],
"final_embedding_prot_res": offloaded_node_reps["prot_res"],
"final_coords_prot_atom": prot_atm_coords_flat,
"final_coords_prot_atom_padded": prot_atm_coords_padded,
}
if not batch["misc"]["protein_only"]:
ret["final_embedding_lig_atom"] = offloaded_node_reps["lig_atm"]
ret["final_coords_lig_atom"] = lig_atm_coords
else:
ret["final_embedding_lig_atom"] = None
ret["final_coords_lig_atom"] = None
return ret
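The coordinate updates above combine two terms per atom: its own drift plus a drift pooled over frames (e.g. a per-molecule mean) and broadcast back to the atoms via a gather index. A minimal plain-Python sketch of that two-term update (all names and numbers are illustrative):

```python
def update_coords(coords, atom_drift, group_drift, group_idx):
    """New position = old position + per-atom drift + broadcast group drift."""
    return [
        [c + a + g for c, a, g in zip(coords[i], atom_drift[i], group_drift[group_idx[i]])]
        for i in range(len(coords))
    ]

coords = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
atom_drift = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]]
group_drift = [[1.0, 0.0, 0.0]]  # e.g. a per-molecule mean of frame drifts
new = update_coords(coords, atom_drift, group_drift, [0, 0])
assert new == [[1.1, 0.0, 0.0], [2.0, 1.1, 1.0]]
```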
def resolve_score_head(
protein_model_cfg: DictConfig,
score_cfg: DictConfig,
task_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
"""Instantiates an EquivariantStructureDenoisingModule model for protein-ligand complex
structure denoising.
:param protein_model_cfg: Protein model configuration.
:param score_cfg: Score configuration.
:param task_cfg: Task configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: EquivariantStructureDenoisingModule model.
"""
model = EquivariantStructureDenoisingModule(
score_cfg.fiber_dim,
input_dim=protein_model_cfg.residue_dim,
input_pair_dim=protein_model_cfg.pair_dim,
hidden_dim=score_cfg.hidden_dim,
n_stacks=score_cfg.n_stacks,
n_heads=protein_model_cfg.n_heads,
dropout=task_cfg.dropout,
)
if score_cfg.from_pretrained and state_dict is not None:
try:
model.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("score_head")
}
)
log.info("Successfully loaded pretrained score weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained score weights due to: {e}.")
return model
def resolve_confidence_head(
protein_model_cfg: DictConfig,
confidence_cfg: DictConfig,
task_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
"""Instantiates an EquivariantStructureDenoisingModule model for confidence prediction.
:param protein_model_cfg: Protein model configuration.
:param confidence_cfg: Confidence configuration.
:param task_cfg: Task configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: EquivariantStructureDenoisingModule model and plDDT gram head weights.
"""
confidence_head = EquivariantStructureDenoisingModule(
confidence_cfg.fiber_dim,
input_dim=protein_model_cfg.residue_dim,
input_pair_dim=protein_model_cfg.pair_dim,
hidden_dim=confidence_cfg.hidden_dim,
n_stacks=confidence_cfg.n_stacks,
n_heads=protein_model_cfg.n_heads,
dropout=task_cfg.dropout,
)
plddt_gram_head = GELUMLP(
protein_model_cfg.pair_dim,
8,
n_hidden_feats=protein_model_cfg.pair_dim,
zero_init=True,
)
if confidence_cfg.from_pretrained and state_dict is not None:
try:
confidence_head.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("confidence_head")
}
)
log.info("Successfully loaded pretrained confidence weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained confidence weights due to: {e}.")
try:
plddt_gram_head.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("plddt_gram_head")
}
)
log.info("Successfully loaded pretrained pLDDT gram head weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained pLDDT gram head weights due to: {e}.")
return confidence_head, plddt_gram_head
def resolve_affinity_head(
ligand_model_cfg: DictConfig,
affinity_cfg: DictConfig,
task_cfg: DictConfig,
learnable_pooling: bool = True,
state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
"""Instantiates an EquivariantStructureDenoisingModule model for affinity prediction.
:param ligand_model_cfg: Ligand model configuration.
:param affinity_cfg: Affinity configuration.
:param task_cfg: Task configuration.
:param learnable_pooling: Whether to use learnable ligand pooling modules.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: EquivariantStructureDenoisingModule model as well as a ligand pooling module and
projection head.
"""
affinity_head = EquivariantStructureDenoisingModule(
affinity_cfg.fiber_dim,
input_dim=ligand_model_cfg.node_channels,
input_pair_dim=ligand_model_cfg.pair_channels,
hidden_dim=affinity_cfg.hidden_dim,
n_stacks=affinity_cfg.n_stacks,
n_heads=ligand_model_cfg.n_heads,
        dropout=affinity_cfg.dropout if affinity_cfg.get("dropout") is not None else task_cfg.dropout,
)
if affinity_cfg.ligand_pooling in ["sum", "add", "summation", "addition"]:
ligand_pooling = SumPooling(learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim)
elif affinity_cfg.ligand_pooling in ["mean", "avg", "average"]:
ligand_pooling = AveragePooling(
learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim
)
else:
raise NotImplementedError(
f"Unsupported ligand pooling method: {affinity_cfg.ligand_pooling}"
)
affinity_proj_head = GELUMLP(
affinity_cfg.fiber_dim,
1,
n_hidden_feats=affinity_cfg.fiber_dim,
zero_init=True,
)
if affinity_cfg.from_pretrained and state_dict is not None:
try:
affinity_head.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("affinity_head")
}
)
log.info("Successfully loaded pretrained affinity head weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained affinity head weights due to: {e}.")
if learnable_pooling:
try:
ligand_pooling.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("ligand_pooling")
}
)
log.info("Successfully loaded pretrained ligand pooling weights.")
except Exception as e:
log.warning(f"Skipping loading of pretrained ligand pooling weights due to: {e}.")
try:
affinity_proj_head.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("affinity_proj_head")
}
)
log.info("Successfully loaded pretrained affinity projection head weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained affinity projection head weights due to: {e}."
)
return affinity_head, ligand_pooling, affinity_proj_head
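The weight-loading blocks above all rely on the same idiom: filter a checkpoint's `state_dict` down to one submodule's keys and strip the leading module prefix with `".".join(k.split(".")[1:])`. A standalone sketch of that idiom (the helper name `strip_prefix` is illustrative, not from the codebase):

```python
def strip_prefix(state_dict, prefix):
    """Keep only entries whose key starts with `prefix` and drop that
    leading component, mirroring the `".".join(k.split(".")[1:])` idiom."""
    return {
        ".".join(k.split(".")[1:]): v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }

sub = strip_prefix(
    {"affinity_head.mlp.weight": 1, "ligand_pooling.scale": 2},
    "affinity_head",
)
# sub == {"mlp.weight": 1}, ready for `affinity_head.load_state_dict(sub)`
```

Note that `startswith` would also match a sibling module whose name merely shares the prefix (e.g. `affinity_head_aux`); the module names used here do not collide, so the simple filter suffices.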

File diff suppressed because it is too large


@@ -0,0 +1,166 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
from dataclasses import dataclass
import torch
from beartype.typing import Dict, List, Tuple
@dataclass(frozen=True)
class Relation:
edge_type: str
edge_rev_name: str
edge_frd_name: str
src_node_type: str
dst_node_type: str
num_edges: int
class MultiRelationGraphBatcher:
"""Collate sub-graphs of different node/edge types into a single instance.
Returned multi-relation edge indices are stored in LongTensor of shape [2, N_edges].
"""
def __init__(
self,
relation_forms: List[Relation],
graph_metadata: Dict[str, int],
):
"""Initialize the batcher."""
self._relation_forms = relation_forms
self._make_offset_dict(graph_metadata)
def _make_offset_dict(self, graph_metadata):
"""Create offset dictionaries for node and edge types."""
self._node_chunk_sizes = {}
self._edge_chunk_sizes = {}
self._offsets_lower = {}
self._offsets_upper = {}
all_node_types = set()
for relation in self._relation_forms:
assert (
f"num_{relation.src_node_type}" in graph_metadata.keys()
), f"Missing metadata: num_{relation.src_node_type}"
assert (
f"num_{relation.dst_node_type}" in graph_metadata.keys()
), f"Missing metadata: num_{relation.dst_node_type}"
all_node_types.add(relation.src_node_type)
all_node_types.add(relation.dst_node_type)
offset = 0
# Fix node type ordering
self.all_node_types = list(all_node_types)
for node_type in self.all_node_types:
self._offsets_lower[node_type] = offset
self._node_chunk_sizes[node_type] = graph_metadata[f"num_{node_type}"]
new_offset = offset + self._node_chunk_sizes[node_type]
self._offsets_upper[node_type] = new_offset
offset = new_offset
def collate_single_relation_graphs(self, indexer, node_attr_dict, edge_attr_dict):
"""Collate sub-graphs of different node/edge types into a single instance."""
return {
"node_attr": self.collate_node_attr(node_attr_dict),
"edge_attr": self.collate_edge_attr(edge_attr_dict),
"edge_index": self.collate_idx_list(indexer),
}
def collate_idx_list(
self,
indexer: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""Collate edge indices for all relations."""
ret_eidxs_rev, ret_eidxs_frd = [], []
for relation in self._relation_forms:
assert relation.edge_rev_name in indexer.keys()
assert relation.edge_frd_name in indexer.keys()
assert indexer[relation.edge_rev_name].dim() == 1
assert indexer[relation.edge_frd_name].dim() == 1
assert torch.all(
indexer[relation.edge_rev_name] < self._node_chunk_sizes[relation.src_node_type]
), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
assert torch.all(
indexer[relation.edge_frd_name] < self._node_chunk_sizes[relation.dst_node_type]
), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
ret_eidxs_rev.append(
indexer[relation.edge_rev_name] + self._offsets_lower[relation.src_node_type]
)
ret_eidxs_frd.append(
indexer[relation.edge_frd_name] + self._offsets_lower[relation.dst_node_type]
)
ret_eidxs_rev = torch.cat(ret_eidxs_rev, dim=0)
ret_eidxs_frd = torch.cat(ret_eidxs_frd, dim=0)
return torch.stack([ret_eidxs_rev, ret_eidxs_frd], dim=0)
def collate_node_attr(self, node_attr_dict: Dict[str, torch.Tensor]):
"""Collate node attributes for all node types."""
for node_type in self.all_node_types:
assert (
node_attr_dict[node_type].shape[0] == self._node_chunk_sizes[node_type]
), f"Node count mismatch: {node_type}, {node_attr_dict[node_type].shape[0]}, {self._node_chunk_sizes[node_type]}"
return torch.cat([node_attr_dict[node_type] for node_type in self.all_node_types], dim=0)
def collate_edge_attr(self, edge_attr_dict: Dict[str, torch.Tensor]):
"""Collate edge attributes for all relations."""
return torch.cat(
[edge_attr_dict[relation.edge_type] for relation in self._relation_forms],
dim=0,
)
def zero_pad_edge_attr(
self,
edge_attr_dict: Dict[str, torch.Tensor],
embedding_dim: int,
device: torch.device,
):
"""Zero pad edge attributes for all relations."""
for relation in self._relation_forms:
if edge_attr_dict[relation.edge_type] is None:
edge_attr_dict[relation.edge_type] = torch.zeros(
(relation.num_edges, embedding_dim),
device=device,
)
return edge_attr_dict
def offload_node_attr(self, cat_node_attr: torch.Tensor):
"""Offload node attributes for all node types."""
node_chunk_sizes = [self._node_chunk_sizes[node_type] for node_type in self.all_node_types]
node_attr_split = torch.split(cat_node_attr, node_chunk_sizes)
return {
self.all_node_types[i]: node_attr_split[i] for i in range(len(self.all_node_types))
}
def offload_edge_attr(self, cat_edge_attr: torch.Tensor):
"""Offload edge attributes for all relations."""
edge_chunk_sizes = [relation.num_edges for relation in self._relation_forms]
edge_attr_split = torch.split(cat_edge_attr, edge_chunk_sizes)
return {
self._relation_forms[i].edge_type: edge_attr_split[i]
for i in range(len(self._relation_forms))
}
def make_multi_relation_graph_batcher(
list_of_relations: List[Tuple[str, str, str, str, str]],
indexer,
metadata,
):
"""Make a multi-relation graph batcher."""
# Use one instantiation of the indexer to compute chunk sizes
relation_forms = [
Relation(
edge_type=rl_tuple[0],
edge_rev_name=rl_tuple[1],
edge_frd_name=rl_tuple[2],
src_node_type=rl_tuple[3],
dst_node_type=rl_tuple[4],
num_edges=indexer[rl_tuple[1]].shape[0],
)
for rl_tuple in list_of_relations
]
return MultiRelationGraphBatcher(
relation_forms,
metadata,
)

File diff suppressed because it is too large


@@ -0,0 +1,364 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
import rootutils
import torch
from beartype.typing import Any, Dict, Optional, Tuple
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
from flowdock.models.components.modules import TransformerLayer
from flowdock.utils import RankedLogger
from flowdock.utils.model_utils import GELUMLP, segment_softmax, segment_sum
MODEL_BATCH = Dict[str, Any]
STATE_DICT = Dict[str, Any]
log = RankedLogger(__name__, rank_zero_only=True)
class PathConvStack(torch.nn.Module):
"""Path integral convolution stack for ligand encoding."""
def __init__(
self,
pair_channels: int,
n_heads: int = 8,
max_pi_length: int = 8,
dropout: float = 0.0,
):
"""Initialize PathConvStack model."""
super().__init__()
self.pair_channels = pair_channels
self.max_pi_length = max_pi_length
self.n_heads = n_heads
self.prop_value_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
self.triangle_pair_kernel_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
self.prop_update_mlp = GELUMLP(
n_heads * (max_pi_length + 1), pair_channels, dropout=dropout
)
def forward(
self,
prop_attr: torch.Tensor,
stereo_attr: torch.Tensor,
indexer: Dict[str, torch.LongTensor],
metadata: Dict[str, Any],
) -> torch.Tensor:
"""Forward pass for PathConvStack model.
:param prop_attr: Atom-frame pair attributes.
:param stereo_attr: Stereochemistry attributes.
:param indexer: A dictionary of indices.
:param metadata: A dictionary of metadata.
:return: Updated atom-frame pair attributes.
"""
triangle_pair_kernel = self.triangle_pair_kernel_layer(stereo_attr)
# Segment-wise softmax, normalized by outgoing triangles
triangle_pair_alpha = segment_softmax(
triangle_pair_kernel, indexer["gather_idx_ijkl_jkl"], metadata["num_ijk"]
) # .div(self.max_pi_length)
# Uijk,ijkl->ujkl pair representation update
kernel = triangle_pair_alpha[indexer["gather_idx_Uijkl_ijkl"]]
out_prop_attr = [self.prop_value_layer(prop_attr)]
for _ in range(self.max_pi_length):
gathered_prop_attr = out_prop_attr[-1][indexer["gather_idx_Uijkl_Uijk"]]
out_prop_attr.append(
segment_sum(
kernel.mul(gathered_prop_attr),
indexer["gather_idx_Uijkl_ujkl"],
metadata["num_Uijk"],
)
)
new_prop_attr = torch.cat(out_prop_attr, dim=-1)
new_prop_attr = self.prop_update_mlp(new_prop_attr) + prop_attr
return new_prop_attr
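`PathConvStack` normalizes triangle kernels with `segment_softmax` and accumulates path contributions with `segment_sum`, both imported from `flowdock.utils.model_utils` (their implementations are not shown in this diff). A plausible stand-in for `segment_softmax`, under the assumption that it computes a softmax over all entries sharing a segment id:

```python
import torch

def segment_softmax(src, index, num_segments):
    """Softmax over rows of `src` grouped by segment ids in `index`.

    A sketch standing in for flowdock.utils.model_utils.segment_softmax;
    the real implementation may differ (e.g. per-segment max subtraction).
    """
    src = src - src.max()  # global shift for numerical stability
    ex = src.exp()
    denom = torch.zeros(num_segments, *src.shape[1:], dtype=src.dtype)
    denom = denom.index_add(0, index, ex)  # per-segment partition sums
    return ex / denom[index].clamp_min(1e-16)

logits = torch.tensor([0.0, 1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1, 1])
alpha = segment_softmax(logits, index, 2)
# each segment's attention weights sum to 1
```

`segment_sum` is the matching reduction: an `index_add` of values into `num_segments` output rows, as used in the `Uijk` update above.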
class PIFormer(torch.nn.Module):
"""PIFormer model for ligand encoding."""
def __init__(
self,
node_channels: int,
pair_channels: int,
n_atom_encodings: int,
n_bond_encodings: int,
n_atom_pos_encodings: int,
n_stereo_encodings: int,
heads: int,
head_dim: int,
max_path_length: int = 4,
n_transformer_stacks=4,
hidden_dim: Optional[int] = None,
dropout: float = 0.0,
):
"""Initialize PIFormer model."""
super().__init__()
self.node_channels = node_channels
self.pair_channels = pair_channels
self.max_pi_length = max_path_length
self.n_transformer_stacks = n_transformer_stacks
self.n_atom_encodings = n_atom_encodings
self.n_bond_encodings = n_bond_encodings
self.n_atom_pair_encodings = n_bond_encodings + 4
self.n_atom_pos_encodings = n_atom_pos_encodings
self.input_atom_layer = torch.nn.Linear(n_atom_encodings, node_channels)
self.input_pair_layer = torch.nn.Linear(self.n_atom_pair_encodings, pair_channels)
self.input_stereo_layer = torch.nn.Linear(n_stereo_encodings, pair_channels)
self.input_prop_layer = GELUMLP(
self.n_atom_pair_encodings * 3,
pair_channels,
)
self.path_integral_stacks = torch.nn.ModuleList(
[
PathConvStack(
pair_channels,
max_pi_length=max_path_length,
dropout=dropout,
)
for _ in range(n_transformer_stacks)
]
)
self.graph_transformer_stacks = torch.nn.ModuleList(
[
TransformerLayer(
node_channels,
heads,
head_dim=head_dim,
edge_channels=pair_channels,
hidden_dim=hidden_dim,
dropout=dropout,
edge_update=True,
)
for _ in range(n_transformer_stacks)
]
)
def forward(self, batch: MODEL_BATCH, masking_rate: float = 0.0) -> MODEL_BATCH:
"""Forward pass for PIFormer model.
:param batch: A batch dictionary.
:param masking_rate: Masking rate.
:return: A batch dictionary.
"""
features = batch["features"]
indexer = batch["indexer"]
metadata = batch["metadata"]
atom_attr = features["atom_encodings"]
atom_pair_attr = features["atom_pair_encodings"]
af_pair_attr = features["atom_frame_pair_encodings"]
stereo_enc = features["stereo_chemistry_encodings"]
batch["features"]["lig_atom_token"] = atom_attr.detach().clone()
batch["features"]["lig_pair_token"] = atom_pair_attr.detach().clone()
atom_mask = torch.rand(atom_attr.shape[0], device=atom_attr.device) > masking_rate
stereo_mask = torch.rand(stereo_enc.shape[0], device=stereo_enc.device) > masking_rate
atom_pair_mask = (
torch.rand(atom_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
)
af_pair_mask = (
torch.rand(af_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
)
atom_attr = atom_attr * atom_mask[:, None]
stereo_enc = stereo_enc * stereo_mask[:, None]
atom_pair_attr = atom_pair_attr * atom_pair_mask[:, None]
af_pair_attr = af_pair_attr * af_pair_mask[:, None]
# Embedding blocks
metadata["num_atom"] = metadata["num_u"]
metadata["num_frame"] = metadata["num_ijk"]
atom_attr = self.input_atom_layer(atom_attr)
atom_pair_attr = self.input_pair_layer(atom_pair_attr)
triangle_attr = atom_attr.new_zeros(metadata["num_frame"], self.node_channels)
# Initialize atom-frame pair attributes. Reusing uv indices
prop_attr = self.input_prop_layer(af_pair_attr)
stereo_attr = self.input_stereo_layer(stereo_enc)
graph_relations = [
("atom_to_atom", "gather_idx_uv_u", "gather_idx_uv_v", "atom", "atom"),
(
"atom_to_frame",
"gather_idx_Uijk_u",
"gather_idx_Uijk_ijk",
"atom",
"frame",
),
(
"frame_to_atom",
"gather_idx_Uijk_ijk",
"gather_idx_Uijk_u",
"frame",
"atom",
),
(
"frame_to_frame",
"gather_idx_ijkl_ijk",
"gather_idx_ijkl_jkl",
"frame",
"frame",
),
]
graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
node_reps = {"atom": atom_attr, "frame": triangle_attr}
edge_reps = {
"atom_to_atom": atom_pair_attr,
"atom_to_frame": prop_attr,
"frame_to_atom": prop_attr,
"frame_to_frame": stereo_attr,
}
# Graph path integral recursion
for block_id in range(self.n_transformer_stacks):
merged_node_attr = graph_batcher.collate_node_attr(node_reps)
merged_edge_attr = graph_batcher.collate_edge_attr(edge_reps)
_, merged_node_attr, merged_edge_attr = self.graph_transformer_stacks[block_id](
merged_node_attr,
merged_node_attr,
merged_edge_idx,
merged_edge_attr,
)
node_reps = graph_batcher.offload_node_attr(merged_node_attr)
edge_reps = graph_batcher.offload_edge_attr(merged_edge_attr)
prop_attr = edge_reps["atom_to_frame"]
stereo_attr = edge_reps["frame_to_frame"]
prop_attr = prop_attr + self.path_integral_stacks[block_id](
prop_attr,
stereo_attr,
indexer,
metadata,
)
edge_reps["atom_to_frame"] = prop_attr
node_reps["sampled_frame"] = node_reps["frame"][indexer["gather_idx_I_ijk"]]
batch["metadata"]["num_lig_atm"] = metadata["num_u"]
batch["metadata"]["num_lig_trp"] = metadata["num_I"]
batch["features"]["lig_atom_attr"] = node_reps["atom"]
# Downsampled ligand frames
batch["features"]["lig_trp_attr"] = node_reps["sampled_frame"]
batch["features"]["lig_atom_pair_attr"] = edge_reps["atom_to_atom"]
batch["features"]["lig_prop_attr"] = edge_reps["atom_to_frame"]
edge_reps["sampled_atom_to_sampled_frame"] = edge_reps["atom_to_frame"][
indexer["gather_idx_UI_Uijk"]
]
batch["features"]["lig_af_pair_attr"] = edge_reps["sampled_atom_to_sampled_frame"]
return batch
def resolve_ligand_encoder(
ligand_model_cfg: DictConfig,
task_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
"""Instantiates a PIFormer model for ligand encoding.
:param ligand_model_cfg: Ligand model configuration.
:param task_cfg: Task configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: Ligand encoder model.
"""
model = PIFormer(
ligand_model_cfg.node_channels,
ligand_model_cfg.pair_channels,
ligand_model_cfg.n_atom_encodings,
ligand_model_cfg.n_bond_encodings,
ligand_model_cfg.n_atom_pos_encodings,
ligand_model_cfg.n_stereo_encodings,
ligand_model_cfg.n_attention_heads,
ligand_model_cfg.attention_head_dim,
hidden_dim=ligand_model_cfg.hidden_dim,
max_path_length=ligand_model_cfg.max_path_integral_length,
n_transformer_stacks=ligand_model_cfg.n_transformer_stacks,
dropout=task_cfg.dropout,
)
if ligand_model_cfg.from_pretrained and state_dict is not None:
try:
model.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("ligand_encoder")
}
)
log.info(
"Successfully loaded pretrained ligand Molecular Heat Transformer (MHT) weights."
)
except Exception as e:
log.warning(
f"Skipping loading of pretrained ligand Molecular Heat Transformer (MHT) weights due to: {e}."
)
return model
def resolve_relational_reasoning_module(
protein_model_cfg: DictConfig,
ligand_model_cfg: DictConfig,
relational_reasoning_cfg: DictConfig,
state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
"""Instantiates relational reasoning module for ligand encoding.
:param protein_model_cfg: Protein model configuration.
:param ligand_model_cfg: Ligand model configuration.
:param relational_reasoning_cfg: Relational reasoning configuration.
:param state_dict: Optional (potentially-pretrained) state dictionary.
:return: Relational reasoning modules for ligand encoding.
"""
molgraph_single_projector = torch.nn.Linear(
ligand_model_cfg.node_channels, protein_model_cfg.residue_dim, bias=False
)
molgraph_pair_projector = torch.nn.Linear(
ligand_model_cfg.pair_channels, protein_model_cfg.pair_dim, bias=False
)
covalent_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim)
if relational_reasoning_cfg.from_pretrained and state_dict is not None:
try:
molgraph_single_projector.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("molgraph_single_projector")
}
)
log.info("Successfully loaded pretrained ligand graph single projector weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained ligand graph single projector weights due to: {e}."
)
try:
molgraph_pair_projector.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("molgraph_pair_projector")
}
)
log.info("Successfully loaded pretrained ligand graph pair projector weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained ligand graph pair projector weights due to: {e}."
)
try:
covalent_embed.load_state_dict(
{
".".join(k.split(".")[1:]): v
for k, v in state_dict.items()
if k.startswith("covalent_embed")
}
)
log.info("Successfully loaded pretrained ligand covalent embedding weights.")
except Exception as e:
log.warning(
f"Skipping loading of pretrained ligand covalent embedding weights due to: {e}."
)
return molgraph_single_projector, molgraph_pair_projector, covalent_embed


@@ -0,0 +1,423 @@
import math
import rootutils
import torch
import torch.nn.functional as F
from beartype.typing import Optional, Tuple, Union
from openfold.model.primitives import Attention
from openfold.utils.tensor_utils import permute_final_dims
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils.model_utils import GELUMLP, segment_softmax
class MultiHeadAttentionConv(torch.nn.Module):
"""Native PyTorch implementation of multi-head graph attention convolution."""
def __init__(
self,
dim: Union[int, Tuple[int, int]],
head_dim: int,
edge_dim: int = None,
n_heads: int = 1,
dropout: float = 0.0,
edge_lin: bool = True,
**kwargs,
):
"""Multi-Head Attention Convolution layer."""
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_heads = n_heads
self.dropout = dropout
self.edge_dim = edge_dim
self.edge_lin = edge_lin
self._alpha = None
if isinstance(dim, int):
dim = (dim, dim)
self.lin_key = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
self.lin_query = torch.nn.Linear(dim[1], n_heads * head_dim, bias=False)
self.lin_value = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
if edge_lin is True:
self.lin_edge = torch.nn.Linear(edge_dim, n_heads, bias=False)
else:
self.register_parameter("lin_edge", None)
self.reset_parameters()
def reset_parameters(self):
"""Reset the parameters of the layer."""
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
if self.edge_lin:
self.lin_edge.reset_parameters()
def forward(
self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
edge_index: torch.Tensor,
edge_attr: torch.Tensor = None,
return_attention_weights=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Forward pass of the Multi-Head Attention Convolution layer.
:param x (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]): The input features.
:param edge_index (torch.Tensor): The edge index tensor.
:param edge_attr (torch.Tensor, optional): The edge attribute tensor.
:param return_attention_weights (bool, optional): If set to `True`,
will additionally return the tuple
`(edge_index, attention_weights)`, holding the computed
attention weights for each edge. Default is `None`.
:return: The output features or the tuple
`(output_features, (edge_index, attention_weights))`.
"""
H, C = self.n_heads, self.head_dim
if isinstance(x, torch.Tensor):
x = (x, x)
query = self.lin_query(x[1]).view(*x[1].shape[:-1], H, C)
key = self.lin_key(x[0]).view(*x[0].shape[:-1], H, C)
value = self.lin_value(x[0]).view(*x[0].shape[:-1], H, C)
attended_values = self.message(key, query, value, edge_attr, edge_index)
out = self.aggregate(attended_values, edge_index[1], query.shape[0])
alpha = self._alpha
self._alpha = None
out = out.contiguous().view(*out.shape[:-2], H * C)
if isinstance(return_attention_weights, bool):
assert alpha is not None
return out, (edge_index, alpha)
else:
return out
def message(
self,
key: torch.Tensor,
query: torch.Tensor,
value: torch.Tensor,
edge_attr: torch.Tensor,
index: torch.Tensor,
) -> torch.Tensor:
"""Add the relative positional encodings to attention scores.
:param key (torch.Tensor): The key tensor.
:param query (torch.Tensor): The query tensor.
:param value (torch.Tensor): The value tensor.
:param edge_attr (torch.Tensor): The edge attribute tensor.
:param index (torch.Tensor): The edge index tensor.
:return: The output tensor.
"""
edge_bias = 0
if self.lin_edge is not None:
assert edge_attr is not None
edge_bias = self.lin_edge(edge_attr)
_alpha_z = (query[index[1]] * key[index[0]]).sum(dim=-1) / math.sqrt(
self.head_dim
) + edge_bias
self._alpha = _alpha_z
alpha = segment_softmax(_alpha_z, index[1], query.shape[0])
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
out = value[index[0]]
out *= alpha.unsqueeze(-1)
return out
def aggregate(
self, src: torch.Tensor, dst_idx: torch.Tensor, dst_size: int
) -> torch.Tensor:
"""Aggregate the source tensor into the destination tensor.
:param src (torch.Tensor): The source tensor.
:param dst_idx (torch.Tensor): The destination index tensor.
:param dst_size (int): The number of destination nodes.
:return: The output tensor.
"""
out = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, src)
return out
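`aggregate` is a dense scatter-sum: messages indexed by destination node are accumulated with `index_add_`. The same pattern in isolation:

```python
import torch

# Sum three edge messages into two destination nodes, using the same
# zeros(...).index_add_ pattern as MultiHeadAttentionConv.aggregate.
src = torch.tensor([[1.0, 0.0], [2.0, 0.0], [3.0, 4.0]])
dst_idx = torch.tensor([0, 1, 1])  # message i goes to node dst_idx[i]
out = torch.zeros(2, 2).index_add_(0, dst_idx, src)
# out == [[1, 0], [5, 4]]
```

Destinations that receive no message simply keep their zero row, which is the desired behavior for isolated nodes.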
class TransformerLayer(torch.nn.Module):
"""A single layer of a transformer model."""
def __init__(
self,
node_dim: int,
n_heads: int,
head_dim: Optional[int] = None,
hidden_dim: Optional[int] = None,
bidirectional: bool = False,
edge_channels: Optional[int] = None,
dropout: float = 0.0,
edge_update: bool = False,
):
"""Initialize the transformer layer."""
super().__init__()
edge_lin = edge_channels is not None
self.edge_update = edge_update
if head_dim is None:
head_dim = node_dim // n_heads
self.conv = MultiHeadAttentionConv(
node_dim,
head_dim,
edge_dim=edge_channels,
n_heads=n_heads,
edge_lin=edge_lin,
dropout=dropout,
)
self.bidirectional = bidirectional
self.projector = torch.nn.Linear(head_dim * n_heads, node_dim, bias=False)
self.norm = torch.nn.LayerNorm(node_dim)
self.mlp = GELUMLP(
node_dim,
node_dim,
n_hidden_feats=hidden_dim,
dropout=dropout,
zero_init=True,
)
if edge_update:
self.mlpe = GELUMLP(
n_heads + edge_channels, edge_channels, dropout=dropout, zero_init=True
)
def forward(
self,
x_s: torch.Tensor,
x_a: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Forward pass through the transformer layer.
:param x_s (torch.Tensor): The source node features.
:param x_a (torch.Tensor): The target node features.
:param edge_index (torch.Tensor): The edge index tensor.
:param edge_attr (torch.Tensor, optional): The edge attribute tensor.
:return: The updated source and target node features, plus the updated
edge attributes when `edge_update` is enabled.
"""
out_a, (edge_index, alpha) = self.conv(
(x_s, x_a),
edge_index,
edge_attr,
return_attention_weights=True,
)
x_a = x_a + self.projector(out_a)
x_a = self.mlp(self.norm(x_a)) + x_a
if self.bidirectional:
out_s = self.conv((x_a, x_s), (edge_index[1], edge_index[0]), edge_attr)
x_s = x_s + self.projector(out_s)
x_s = self.mlp(self.norm(x_s)) + x_s
if self.edge_update:
edge_attr = edge_attr + self.mlpe(torch.cat([alpha, edge_attr], dim=-1))
return x_s, x_a, edge_attr
else:
return x_s, x_a
class PointSetAttention(torch.nn.Module):
"""PointSetAttention module."""
def __init__(
self,
fiber_dim: int,
heads: int = 8,
point_dim: int = 4,
edge_dim: Optional[int] = None,
edge_update: bool = False,
dropout: float = 0.0,
):
"""Initialize the PointSetAttention module."""
super().__init__()
self.fiber_dim = fiber_dim
self.edge_dim = edge_dim
self.heads = heads
self.point_dim = point_dim
self.dropout = dropout
self.edge_update = edge_update
self.distance_scaling = 10 # 1 nm
# num attention contributions
num_attn_logits = 2
self.lin_query = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
self.lin_key = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
self.lin_value = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
if edge_dim is not None:
self.lin_edge = torch.nn.Linear(edge_dim, heads, bias=False)
if edge_update:
self.edge_update_mlp = GELUMLP(heads + edge_dim, edge_dim)
# qkv projection for scalar attention (normal)
self.scalar_attn_logits_scale = (num_attn_logits * point_dim) ** -0.5
# qkv projection for point attention (coordinate and orientation aware)
point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.0)) - 1.0)
self.point_weights = torch.nn.Parameter(point_weight_init_value)
self.point_attn_logits_scale = ((num_attn_logits * point_dim) * (9 / 2)) ** -0.5
# combine heads: [N, 4, heads * point_dim] -> [N, 4, fiber_dim]
self.to_out = torch.nn.Linear(heads * point_dim, fiber_dim, bias=False)
def forward(
self,
x_k: torch.Tensor,
x_q: torch.Tensor,
edge_index: torch.LongTensor,
point_centers_k: torch.Tensor,
point_centers_q: torch.Tensor,
x_edge: torch.Tensor = None,
):
"""Forward pass of the PointSetAttention module."""
H, P = self.heads, self.point_dim
q = self.lin_query(x_q)
k = self.lin_key(x_k)
v = self.lin_value(x_k)
scalar_q = q[..., 0, :].view(-1, H, P)
scalar_k = k[..., 0, :].view(-1, H, P)
scalar_v = v[..., 0, :].view(-1, H, P)
point_q_local = q[..., 1:, :].view(-1, 3, H, P)
point_k_local = k[..., 1:, :].view(-1, 3, H, P)
point_v_local = v[..., 1:, :].view(-1, 3, H, P)
point_q = point_q_local + point_centers_q[..., None, None] / self.distance_scaling
point_k = point_k_local + point_centers_k[..., None, None] / self.distance_scaling
point_v = point_v_local + point_centers_k[..., None, None] / self.distance_scaling
if self.edge_dim is not None:
edge_bias = self.lin_edge(x_edge)
else:
edge_bias = 0
attn_logits, attentions = self.compute_attention(
scalar_k, scalar_q, point_k, point_q, edge_bias, edge_index
)
res_scalar = self.aggregate(
attentions[:, :, None] * scalar_v[edge_index[0]],
edge_index[1],
scalar_q.shape[0],
)
res_points = self.aggregate(
attentions[:, None, :, None] * point_v[edge_index[0]],
edge_index[1],
point_q.shape[0],
)
res_points_local = res_points - point_centers_q[..., None, None] / self.distance_scaling
# [N, H, P], [N, 3, H, P] -> [N, 4, C]
res = torch.cat([res_scalar.unsqueeze(-3), res_points_local], dim=-3).flatten(-2, -1)
out = self.to_out(res) # [N, 4, C]
if self.edge_update:
edge_out = self.edge_update_mlp(torch.cat([attn_logits, x_edge], dim=-1))
return out, edge_out
return out
def compute_attention(self, scalar_k, scalar_q, point_k, point_q, edge_bias, index):
"""Compute the attention scores."""
scalar_q = scalar_q[index[1]]
scalar_k = scalar_k[index[0]]
point_q = point_q[index[1]]
point_k = point_k[index[0]]
scalar_logits = (scalar_q * scalar_k).sum(dim=-1) * self.scalar_attn_logits_scale
point_weights = F.softplus(self.point_weights).unsqueeze(0)
point_logits = (
torch.square(point_q - point_k).sum(dim=(-3, -1)) * self.point_attn_logits_scale
)
logits = scalar_logits - 1 / 2 * point_logits * point_weights + edge_bias
alpha = segment_softmax(logits, index[1], scalar_q.shape[0])
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return logits, alpha
@staticmethod
def aggregate(src, dst_idx, dst_size):
"""Aggregate the source tensor to the destination tensor."""
out = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, src)
return out
class BiDirectionalTriangleAttention(torch.nn.Module):
"""Bi-directional triangle attention.
Adapted from: https://github.com/aqlaboratory/openfold
Supports rectangular pair representation tensors.
"""
def __init__(self, c_in: int, c_hidden: int, no_heads: int, inf: float = 1e9):
"""Initialize the Bi-Directional Triangle Attention layer."""
super().__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.linear = torch.nn.Linear(c_in, self.no_heads, bias=False)
self.mha_1 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
self.mha_2 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
self.layer_norm = torch.nn.LayerNorm(self.c_in)
def forward(
self,
x1: torch.Tensor,
x2: torch.Tensor,
x_pair: torch.Tensor,
mask: Optional[torch.Tensor] = None,
use_lma: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of the Bi-Directional Triangle Attention layer."""
if mask is None:
# [*, I, J, K]
mask = x_pair.new_ones(
x_pair.shape[:-1],
)
# [*, I, J, C_in]
x1 = self.layer_norm(x1)
# [*, I, K, C_in]
x2 = self.layer_norm(x2)
# [*, I, 1, J, K]
mask_bias = (self.inf * (mask - 1))[..., :, None, :, :]
# [*, I, H, J, K]
triangle_bias = permute_final_dims(self.linear(x_pair), [0, 3, 1, 2])
biases_J2I = [mask_bias, triangle_bias]
x1_out = self.mha_1(q_x=x1, kv_x=x2, biases=biases_J2I, use_lma=use_lma)
x1 = x1 + x1_out
# transpose the triangle bias for I->J attention.
mask_bias_T_ = mask_bias.transpose(-2, -1).contiguous()
triangle_bias_T_ = triangle_bias.transpose(-2, -1).contiguous()
biases_I2J = [mask_bias_T_, triangle_bias_T_]
x2_out = self.mha_2(q_x=x2, kv_x=x1, biases=biases_I2J, use_lma=use_lma)
x2 = x2 + x2_out
return x1, x2


@@ -0,0 +1,443 @@
import numpy as np
import rootutils
import torch
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.models.components.transforms import LatentCoordinateConverter
from flowdock.utils import RankedLogger
from flowdock.utils.model_utils import segment_mean
MODEL_BATCH = Dict[str, Any]
log = RankedLogger(__name__, rank_zero_only=True)
class DiffusionSDE:
"""Diffusion SDE class.
Adapted from: https://github.com/HannesStark/FlowSite
"""
def __init__(self, sigma: torch.Tensor, tau_factor: float = 5.0):
"""Initialize the Diffusion SDE class."""
self.lamb = 1 / sigma**2
self.tau_factor = tau_factor
def var(self, t: torch.Tensor) -> torch.Tensor:
"""Calculate the variance of the diffusion SDE."""
return (1 - torch.exp(-self.lamb * t)) / self.lamb
def max_t(self) -> float:
"""Calculate the maximum time of the diffusion SDE."""
return self.tau_factor / self.lamb
def mu_factor(self, t: torch.Tensor) -> torch.Tensor:
"""Calculate the mu factor of the diffusion SDE."""
return torch.exp(-self.lamb * t / 2)
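`DiffusionSDE` is an Ornstein-Uhlenbeck schedule: with `lamb = 1 / sigma**2`, the variance `var(t) = (1 - exp(-lamb * t)) / lamb` grows like `t` at small times (Brownian regime) and saturates at `sigma**2`. A standalone numeric check of that formula (independent of the class):

```python
import torch

# Variance schedule of an OU process, as in DiffusionSDE.var.
sigma = torch.tensor(2.0)
lamb = 1 / sigma**2

def var(t: torch.Tensor) -> torch.Tensor:
    return (1 - torch.exp(-lamb * t)) / lamb

# small t: var(t) ~ t; large t: var(t) -> sigma**2
```

With `tau_factor = 5.0`, `max_t = tau_factor / lamb` is the time at which the variance has reached `1 - exp(-5)`, about 99.3% of its stationary value `sigma**2`.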
class HarmonicSDE:
"""Harmonic SDE class.
Adapted from: https://github.com/HannesStark/FlowSite
"""
def __init__(self, J: Optional[torch.Tensor] = None, diagonalize: bool = True):
"""Initialize the Harmonic SDE class."""
self.l_index = 1
self.use_cuda = False
if not diagonalize:
return
if J is not None:
self.D, self.P = np.linalg.eigh(J)
self.N = self.D.size
@staticmethod
def diagonalize(
N,
ptr: torch.Tensor,
edges: Optional[List[Tuple[int, int]]] = None,
antiedges: Optional[List[Tuple[int, int]]] = None,
a=1,
b=0.3,
lamb: Optional[torch.Tensor] = None,
device: Optional[Union[str, torch.device]] = None,
):
"""Diagonalize using the Harmonic SDE."""
device = device or ptr.device
J = torch.zeros((N, N), device=device)
if edges is None:
for i, j in zip(np.arange(N - 1), np.arange(1, N)):
J[i, i] += a
J[j, j] += a
J[i, j] = J[j, i] = -a
else:
for i, j in edges:
J[i, i] += a
J[j, j] += a
J[i, j] = J[j, i] = -a
if antiedges is not None:
for i, j in antiedges:
J[i, i] -= b
J[j, j] -= b
J[i, j] = J[j, i] = b
if edges is not None:
J += torch.diag(lamb)
Ds, Ps = [], []
for start, end in zip(ptr[:-1], ptr[1:]):
D, P = torch.linalg.eigh(J[start:end, start:end])
D_ = D
if edges is None:
D_inv = 1 / D
D_inv[0] = 0
D_ = D_inv
Ds.append(D_)
Ps.append(P)
return torch.cat(Ds), torch.block_diag(*Ps)
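With `edges=None`, `diagonalize` couples each consecutive pair of nodes, i.e. it builds a path-graph Laplacian. Its lowest eigenvalue is 0 (the rigid-translation mode with a constant eigenvector), which is why the code zeroes out `D_inv[0]` before returning inverse eigenvalues. A standalone check of that structure:

```python
import torch

# Path-graph Laplacian for a 4-node chain, built the same way as the
# edges=None branch of HarmonicSDE.diagonalize (coupling strength a).
N, a = 4, 1.0
J = torch.zeros(N, N)
for i in range(N - 1):
    J[i, i] += a
    J[i + 1, i + 1] += a
    J[i, i + 1] = J[i + 1, i] = -a
D, P = torch.linalg.eigh(J)
# D[0] ~ 0, and its eigenvector P[:, 0] is constant (the translation mode)
```

Because the zero mode carries no restoring force, inverting its eigenvalue would diverge; masking it to 0 effectively pins the center of mass.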
def eigens(self, t):
"""Calculate the eigenvalues of `sigma_t` using the Harmonic SDE."""
np_ = torch if self.use_cuda else np
D = 1 / self.D * (1 - np_.exp(-t * self.D))
t = torch.tensor(t, device="cuda").float() if self.use_cuda else t
return np_.where(D != 0, D, t)
def conditional(self, mask, x2):
"""Calculate the conditional distribution using the Harmonic SDE."""
J_11 = self.J[~mask][:, ~mask]
J_12 = self.J[~mask][:, mask]
h = -J_12 @ x2
mu = np.linalg.inv(J_11) @ h
D, P = np.linalg.eigh(J_11)
z = np.random.randn(*mu.shape)
return (P / D**0.5) @ z + mu
def A(self, t, invT=False):
"""Calculate the matrix `A` using the Harmonic SDE."""
D = self.eigens(t)
A = self.P * (D**0.5)
if not invT:
return A
AinvT = self.P / (D**0.5)
return A, AinvT
def Sigma_inv(self, t):
"""Calculate the inverse of the covariance matrix `Sigma` using the Harmonic SDE."""
D = 1 / self.eigens(t)
return (self.P * D) @ self.P.T
def Sigma(self, t):
"""Calculate the covariance matrix `Sigma` using the Harmonic SDE."""
D = self.eigens(t)
return (self.P * D) @ self.P.T
@property
def J(self):
"""Return the matrix `J`."""
return (self.P * self.D) @ self.P.T
def rmsd(self, t):
"""Calculate the root mean square deviation using the Harmonic SDE."""
l_index = self.l_index
D = 1 / self.D * (1 - np.exp(-t * self.D))
return np.sqrt(3 * D[l_index:].mean())
def sample(self, t, x=None, score=False, k=None, center=True, adj=False):
"""Sample from the Harmonic SDE."""
l_index = self.l_index
np_ = torch if self.use_cuda else np
if x is None:
if self.use_cuda:
x = torch.zeros((self.N, 3), device="cuda").float()
else:
x = np.zeros((self.N, 3))
if t == 0:
return x
z = (
np.random.randn(self.N, 3)
if not self.use_cuda
else torch.randn(self.N, 3, device="cuda").float()
)
D = self.eigens(t)
xx = self.P.T @ x
if center:
z[0] = 0
xx[0] = 0
if k:
z[k + l_index :] = 0
xx[k + l_index :] = 0
out = np_.exp(-t * self.D / 2)[:, None] * xx + np_.sqrt(D)[:, None] * z
if score:
score = -(1 / np_.sqrt(D))[:, None] * z
if adj:
score = score + self.D[:, None] * out
return self.P @ out, self.P @ score
return self.P @ out
def score_norm(self, t, k=None, adj=False):
"""Calculate the score norm using the Harmonic SDE."""
if k == 0:
return 0
l_index = self.l_index
np_ = torch if self.use_cuda else np
k = k or self.N - 1
D = 1 / self.eigens(t)
if adj:
D = D * np_.exp(-self.D * t)
return (D[l_index : k + l_index].sum() / self.N) ** 0.5
def inject(self, t, modes):
"""Inject noise along the given modes using the Harmonic SDE."""
z = (
np.random.randn(self.N, 3)
if not self.use_cuda
else torch.randn(self.N, 3, device="cuda").float()
)
z[~modes] = 0
A = self.A(t, invT=False)
return A @ z
def score(self, x0, xt, t):
"""Calculate the score of the diffusion kernel using the Harmonic SDE."""
Sigma_inv = self.Sigma_inv(t)
mu_t = (self.P * np.exp(-t * self.D / 2)) @ (self.P.T @ x0)
return Sigma_inv @ (mu_t - xt)
def project(self, X, k, center=False):
"""Project onto the first `k` nonzero modes using the Harmonic SDE."""
l_index = self.l_index
D = self.P.T @ X
D[k + l_index :] = 0
if center:
D[0] = 0
return self.P @ D
def unproject(self, X, mask, k, return_Pinv=False):
"""Find the vector along the first k nonzero modes whose mask is closest to `X`"""
l_index = self.l_index
PP = self.P[mask, : k + l_index]
Pinv = np.linalg.pinv(PP)
out = self.P[:, : k + l_index] @ Pinv @ X
if return_Pinv:
return out, Pinv
return out
def energy(self, X):
"""Calculate the energy using the Harmonic SDE."""
l_index = self.l_index
return (self.D[:, None] * (self.P.T @ X) ** 2).sum(-1)[l_index:] / 2
@property
def free_energy(self):
"""Calculate the free energy using the Harmonic SDE."""
l_index = self.l_index
return 3 * np.log(self.D[l_index:]).sum() / 2
def KL_H(self, t):
"""Calculate the Kullback-Leibler divergence using the Harmonic SDE."""
l_index = self.l_index
D = self.D[l_index:]
return -3 * 0.5 * (np.log(1 - np.exp(-D * t)) + np.exp(-D * t)).sum(0)
def sample_gaussian_prior(
x0: torch.Tensor,
latent_converter: LatentCoordinateConverter,
sigma: float,
x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Sample noise from a Gaussian prior distribution.
:param x0: ground-truth tensor
:param latent_converter: The latent coordinate converter
:param sigma: standard deviation of the Gaussian noise
:param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
:return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
"""
prior = torch.randn_like(x0)
x_int_0 = x0 + prior * x0_sigma # add small Gaussian noise to the ground-truth tensor
(
x1_ca_lat,
x1_cother_lat,
x1_lig_lat,
) = torch.split(
prior * sigma,
[
latent_converter._n_res_per_sample,
latent_converter._n_cother_per_sample,
latent_converter._n_ligha_per_sample,
],
dim=1,
)
x_int_1 = torch.cat(
[
x1_ca_lat,
x1_cother_lat,
x1_lig_lat,
],
dim=1,
)
return x_int_0, x_int_1
def sample_protein_harmonic_prior(
protein_ca_x0: torch.Tensor,
protein_cother_x0: torch.Tensor,
batch: MODEL_BATCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample protein noise from a harmonic prior distribution.
Adapted from: https://github.com/bjing2016/alphaflow
Note that this function represents non-Ca atoms as Gaussian noise
centered around each harmonically-noised Ca atom.
:param protein_ca_x0: ground-truth protein Ca-atom tensor
:param protein_cother_x0: ground-truth protein other-atom tensor
:param batch: A batch dictionary
:return: tuple of harmonic protein Ca atom noise and Gaussian protein other atom noise
"""
indexer = batch["indexer"]
metadata = batch["metadata"]
protein_bid = indexer["gather_idx_a_structid"]
protein_num_nodes = protein_ca_x0.size(0) * protein_ca_x0.size(1)
ptr = torch.cumsum(torch.bincount(protein_bid), dim=0)
ptr = torch.cat((torch.tensor([0], device=protein_bid.device), ptr))
try:
D_inv, P = HarmonicSDE.diagonalize(
protein_num_nodes,
ptr,
a=3 / (3.8**2),
)
except Exception as e:
log.error(
f"Failed to call HarmonicSDE.diagonalize() for protein(s) {metadata['sample_ID_per_sample']} due to: {e}"
)
raise e
noise = torch.randn((protein_num_nodes, 3), device=protein_ca_x0.device)
harmonic_ca_noise = P @ (torch.sqrt(D_inv)[:, None] * noise)
gaussian_cother_noise = (
torch.randn_like(protein_cother_x0.flatten(0, 1))
+ harmonic_ca_noise[indexer["gather_idx_a_cotherid"]]
)
return (
harmonic_ca_noise.view(protein_ca_x0.size()).contiguous(),
gaussian_cother_noise.view(protein_cother_x0.size()).contiguous(),
)
def sample_ligand_harmonic_prior(
lig_x0: torch.Tensor, protein_ca_x0: torch.Tensor, batch: MODEL_BATCH, sigma: float = 1.0
) -> torch.Tensor:
"""
Sample ligand noise from a harmonic prior distribution.
Adapted from: https://github.com/HannesStark/FlowSite
:param lig_x0: ground-truth ligand tensor
:param protein_ca_x0: ground-truth protein Ca-atom tensor
:param batch: A batch dictionary
:param sigma: standard deviation of the harmonic noise
:return: tensor of harmonic noise
"""
indexer = batch["indexer"]
metadata = batch["metadata"]
lig_num_nodes = lig_x0.size(0) * lig_x0.size(1)
num_molid_per_sample = max(metadata["num_molid_per_sample"])
# NOTE: here, we distinguish each ligand chain in a complex for harmonic chain sampling
lig_bid = indexer["gather_idx_i_molid"]
protein_sigma = (
segment_mean(
torch.square(protein_ca_x0).flatten(0, 1),
indexer["gather_idx_a_structid"],
metadata["num_structid"],
).mean(dim=-1)
** 0.5
).repeat_interleave(num_molid_per_sample)
sde = DiffusionSDE(protein_sigma * sigma)
edges = torch.stack(
(
indexer["gather_idx_ij_i"],
indexer["gather_idx_ij_j"],
)
)
edges = edges[:, edges[0] < edges[1]] # de-duplicate edges
ptr = torch.cumsum(torch.bincount(lig_bid), dim=0)
ptr = torch.cat((torch.tensor([0], device=lig_bid.device), ptr))
try:
D, P = HarmonicSDE.diagonalize(
lig_num_nodes,
ptr,
edges=edges.T,
lamb=sde.lamb[lig_bid],
)
except Exception as e:
log.error(
f"Failed to call HarmonicSDE.diagonalize() for ligand(s) {metadata['sample_ID_per_sample']} due to: {e}"
)
raise e
noise = torch.randn((lig_num_nodes, 3), device=lig_x0.device)
prior = P @ (noise / torch.sqrt(D)[:, None])
return prior.view(lig_x0.size()).contiguous()
def sample_complex_harmonic_prior(
x0: torch.Tensor,
latent_converter: LatentCoordinateConverter,
batch: MODEL_BATCH,
x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample protein-ligand complex noise from a harmonic prior distribution.
From: https://github.com/bjing2016/alphaflow
:param x0: ground-truth tensor
:param latent_converter: The latent coordinate converter
:param batch: A batch dictionary
:param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
:return: tuple of ground-truth and predicted tensors with additive Gaussian and harmonic prior noise, respectively
"""
ca_lat, cother_lat, lig_lat = x0.split(
[
latent_converter._n_res_per_sample,
latent_converter._n_cother_per_sample,
latent_converter._n_ligha_per_sample,
],
dim=1,
)
harmonic_ca_lat, gaussian_cother_lat = sample_protein_harmonic_prior(
ca_lat,
cother_lat,
batch,
)
harmonic_lig_lat = sample_ligand_harmonic_prior(lig_lat, harmonic_ca_lat, batch)
x1 = torch.cat(
[
# NOTE: the following normalization steps assume that `self.latent_model == "default"`
harmonic_ca_lat / latent_converter.ca_scale,
gaussian_cother_lat / latent_converter.other_scale,
harmonic_lig_lat / latent_converter.other_scale,
],
dim=1,
)
gaussian_prior = torch.randn_like(x0)
return x0 + gaussian_prior * x0_sigma, x1
def sample_esmfold_prior(
x0: torch.Tensor, x1: torch.Tensor, sigma: float, x0_sigma: float = 1e-4
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Sample noise from an ESMFold prior distribution.
:param x0: ground-truth tensor
:param x1: predicted tensor
:param sigma: standard deviation of the ESMFold prior's additive Gaussian noise
:param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
:return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
"""
prior_noise = torch.randn_like(x0)
return x0 + prior_noise * x0_sigma, x1 + prior_noise * sigma
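The harmonic priors above hinge on `HarmonicSDE.diagonalize`: in the chain case (`edges=None`) it builds the path-graph Laplacian with spring constant `a`, eigendecomposes it, and inverts the eigenvalues with the translational null mode zeroed out, so that samples `P @ (sqrt(D_inv) * z)` have covariance `pinv(J)`. A dependency-light NumPy sketch of that chain case (a standalone mirror for illustration, not the FlowDock API; `N` is arbitrary, while `a = 3 / 3.8**2` follows the protein code above):

```python
import numpy as np

# Path-graph Laplacian for a chain of N beads, as in the edges=None branch
# of HarmonicSDE.diagonalize above.
N, a = 5, 3 / (3.8**2)
J = np.zeros((N, N))
for i in range(N - 1):
    J[i, i] += a
    J[i + 1, i + 1] += a
    J[i, i + 1] = J[i + 1, i] = -a

# Eigendecompose and pseudo-invert: the smallest eigenvalue is the global
# translation mode (numerically ~0) and its inverse is clamped to 0.
D, P = np.linalg.eigh(J)
D_inv = np.concatenate(([0.0], 1.0 / D[1:]))

# Samples P @ (sqrt(D_inv)[:, None] * z) then have covariance pinv(J).
cov = (P * D_inv) @ P.T
assert np.allclose(cov, np.linalg.pinv(J), atol=1e-8)
```

The clamped zero mode is what keeps the prior well-defined for a free chain: without it, the variance along global translations would be infinite.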


@@ -0,0 +1,241 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
import rootutils
import torch
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils.model_utils import segment_mean
class LatentCoordinateConverter:
"""Transform the batched feature dict to latent coordinate arrays."""
def __init__(self, config, prot_atom37_namemap, lig_namemap):
"""Initialize the converter."""
super().__init__()
self.config = config
self.prot_namemap = prot_atom37_namemap
self.lig_namemap = lig_namemap
self.cached_noise = None
self._last_pred_ca_trace = None
@staticmethod
def nested_get(dic, keys):
"""Get the value in the nested dictionary."""
for key in keys:
dic = dic[key]
return dic
@staticmethod
def nested_set(dic, keys, value):
"""Set the value in the nested dictionary."""
for key in keys[:-1]:
dic = dic.setdefault(key, {})
dic[keys[-1]] = value
def to_latent(self, batch):
"""Convert the batched feature dict to latent coordinates."""
return None
def assign_to_batch(self, batch, x_int):
"""Assign the latent coordinates to the batched feature dict."""
return None
class DefaultPLCoordinateConverter(LatentCoordinateConverter):
"""Minimal conversion, using internal coords for sidechains and global coords for others."""
def __init__(self, config, prot_atom37_namemap, lig_namemap):
"""Initialize the converter."""
super().__init__(config, prot_atom37_namemap, lig_namemap)
# Scale parameters in Angstrom
self.ca_scale = config.global_max_sigma
self.other_scale = config.internal_max_sigma
def to_latent(self, batch: dict):
"""Convert the batched feature dict to latent coordinates."""
indexer = batch["indexer"]
metadata = batch["metadata"]
self._batch_size = metadata["num_structid"]
atom37_mask = batch["features"]["res_atom_mask"].bool()
self._cother_mask = atom37_mask.clone()
self._cother_mask[:, 1] = False
atom37_coords = self.nested_get(batch, self.prot_namemap[0])
try:
apo_available = True
apo_atom37_coords = self.nested_get(
batch, self.prot_namemap[0][:-1] + ("apo_" + self.prot_namemap[0][-1],)
)
except KeyError:
apo_available = False
apo_atom37_coords = torch.zeros_like(atom37_coords)
ca_atom_centroid_coords = segment_mean(
# NOTE: in contrast to NeuralPLexer, we center all coordinates at the origin using the Ca atom centroids
atom37_coords[:, 1],
indexer["gather_idx_a_structid"],
self._batch_size,
)
if apo_available:
apo_ca_atom_centroid_coords = segment_mean(
apo_atom37_coords[:, 1],
indexer["gather_idx_a_structid"],
self._batch_size,
)
else:
apo_ca_atom_centroid_coords = torch.zeros_like(ca_atom_centroid_coords)
ca_coords_glob = (
(atom37_coords[:, 1] - ca_atom_centroid_coords[indexer["gather_idx_a_structid"]])
.contiguous()
.view(self._batch_size, -1, 3)
)
if apo_available:
apo_ca_coords_glob = (
(
apo_atom37_coords[:, 1]
- apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"]]
)
.contiguous()
.view(self._batch_size, -1, 3)
)
else:
apo_ca_coords_glob = torch.zeros_like(ca_coords_glob)
cother_coords_int = (
(atom37_coords - ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None])[
self._cother_mask
]
.contiguous()
.view(self._batch_size, -1, 3)
)
if apo_available:
apo_cother_coords_int = (
(
apo_atom37_coords
- apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None]
)[self._cother_mask]
.contiguous()
.view(self._batch_size, -1, 3)
)
else:
apo_cother_coords_int = torch.zeros_like(cother_coords_int)
self._n_res_per_sample = ca_coords_glob.shape[1]
self._n_cother_per_sample = cother_coords_int.shape[1]
if batch["misc"]["protein_only"]:
self._n_ligha_per_sample = 0
x_int = torch.cat(
[
ca_coords_glob / self.ca_scale,
apo_ca_coords_glob / self.ca_scale,
cother_coords_int / self.other_scale,
apo_cother_coords_int / self.other_scale,
],
dim=1,
)
return x_int
lig_ha_coords = self.nested_get(batch, self.lig_namemap[0])
lig_ha_coords_int = (
lig_ha_coords - ca_atom_centroid_coords[indexer["gather_idx_i_structid"]]
)
lig_ha_coords_int = lig_ha_coords_int.contiguous().view(self._batch_size, -1, 3)
ca_atom_centroid_coords = ca_atom_centroid_coords.contiguous().view(
self._batch_size, -1, 3
)
apo_ca_atom_centroid_coords = apo_ca_atom_centroid_coords.contiguous().view(
self._batch_size, -1, 3
)
x_int = torch.cat(
[
ca_coords_glob / self.ca_scale,
apo_ca_coords_glob / self.ca_scale,
cother_coords_int / self.other_scale,
apo_cother_coords_int / self.other_scale,
ca_atom_centroid_coords / self.ca_scale,
apo_ca_atom_centroid_coords / self.ca_scale,
lig_ha_coords_int / self.other_scale,
],
dim=1,
)
# NOTE: since we use the Ca atom centroids for centralization, we have only one molid per sample
self._n_molid_per_sample = ca_atom_centroid_coords.shape[1]
self._n_ligha_per_sample = lig_ha_coords_int.shape[1]
return x_int
def assign_to_batch(self, batch: dict, x_lat: torch.Tensor):
"""Assign the latent coordinates to the batched feature dict."""
indexer = batch["indexer"]
new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
apo_new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
if batch["misc"]["protein_only"]:
ca_lat, apo_ca_lat, cother_lat, apo_cother_lat = torch.split(
x_lat,
[
self._n_res_per_sample,
self._n_res_per_sample,
self._n_cother_per_sample,
self._n_cother_per_sample,
],
dim=1,
)
else:
(
ca_lat,
apo_ca_lat,
cother_lat,
apo_cother_lat,
ca_cent_lat,
_,
lig_lat,
) = torch.split(
x_lat,
[
self._n_res_per_sample,
self._n_res_per_sample,
self._n_cother_per_sample,
self._n_cother_per_sample,
self._n_molid_per_sample,
self._n_molid_per_sample,
self._n_ligha_per_sample,
],
dim=1,
)
new_ca_glob = (ca_lat * self.ca_scale).contiguous().flatten(0, 1)
apo_new_ca_glob = (apo_ca_lat * self.ca_scale).contiguous().flatten(0, 1)
new_atom37_coords[self._cother_mask] = (
(cother_lat * self.other_scale).contiguous().flatten(0, 1)
)
apo_new_atom37_coords[self._cother_mask] = (
(apo_cother_lat * self.other_scale).contiguous().flatten(0, 1)
)
new_atom37_coords[~self._cother_mask] = 0
apo_new_atom37_coords[~self._cother_mask] = 0
new_atom37_coords[:, 1] = new_ca_glob
apo_new_atom37_coords[:, 1] = apo_new_ca_glob
self.nested_set(batch, self.prot_namemap[1], new_atom37_coords)
self.nested_set(
batch,
self.prot_namemap[1][:-1] + ("apo_" + self.prot_namemap[1][-1],),
apo_new_atom37_coords,
)
if batch["misc"]["protein_only"]:
self.nested_set(batch, self.lig_namemap[1], None)
self.empty_cache()
return batch
new_ligha_coords_int = (lig_lat * self.other_scale).contiguous().flatten(0, 1)
new_ligha_coords_cent = (ca_cent_lat * self.ca_scale).contiguous().flatten(0, 1)
new_ligha_coords = (
new_ligha_coords_int + new_ligha_coords_cent[indexer["gather_idx_i_structid"]]
)
self.nested_set(batch, self.lig_namemap[1], new_ligha_coords)
self.empty_cache()
return batch
def empty_cache(self):
"""Empty the cached variables."""
self._batch_size = None
self._cother_mask = None
self._n_res_per_sample = None
self._n_cother_per_sample = None
self._n_ligha_per_sample = None
self._n_molid_per_sample = None
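The converter's normalization is an affine round trip: `to_latent` centers coordinates at the per-sample Ca centroid and divides by a scale, while `assign_to_batch` multiplies back by the same scale. A toy standalone sketch of that round trip (the `ca_scale` value and array shapes here are arbitrary assumptions, not FlowDock defaults):

```python
import numpy as np

ca_scale = 18.0  # hypothetical scale in Angstrom, analogous to config.global_max_sigma
ca_coords = np.arange(21, dtype=float).reshape(7, 3)  # toy Ca positions

# to_latent direction: center at the Ca centroid, then normalize.
centroid = ca_coords.mean(axis=0, keepdims=True)
ca_lat = (ca_coords - centroid) / ca_scale

# assign_to_batch direction: rescale back to centered global coordinates.
ca_glob = ca_lat * ca_scale
assert np.allclose(ca_glob, ca_coords - centroid)
```

Adding the centroid back recovers the original frame exactly, which is why the ligand branch above carries the Ca centroid through the latent vector separately.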


@@ -0,0 +1,943 @@
import os
import esm
import numpy as np
import rootutils
import torch
from beartype.typing import Any, Dict, Literal, Optional, Union
from lightning import LightningModule
from omegaconf import DictConfig
from torchmetrics.functional.regression import (
mean_absolute_error,
mean_squared_error,
pearson_corrcoef,
spearman_corrcoef,
)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.models.components.losses import (
eval_auxiliary_estimation_losses,
eval_structure_prediction_losses,
)
from flowdock.utils import RankedLogger
from flowdock.utils.data_utils import pdb_filepath_to_protein, prepare_batch
from flowdock.utils.model_utils import extract_esm_embeddings
from flowdock.utils.sampling_utils import multi_pose_sampling
from flowdock.utils.visualization_utils import (
construct_prot_lig_pairs,
write_prot_lig_pairs_to_pdb_file,
)
MODEL_BATCH = Dict[str, Any]
MODEL_STAGE = Literal["train", "val", "test", "predict"]
LOSS_MODES_LIST = [
"structure_prediction",
"auxiliary_estimation",
"auxiliary_estimation_without_structure_prediction",
]
LOSS_MODES = Literal[
"structure_prediction",
"auxiliary_estimation",
"auxiliary_estimation_without_structure_prediction",
]
AUX_ESTIMATION_STAGES = ["train", "val", "test"]
log = RankedLogger(__name__, rank_zero_only=True)
class FlowDockFMLitModule(LightningModule):
"""A `LightningModule` for geometric flow matching (FM) with FlowDock.
A `LightningModule` implements 7 key methods:
```python
def __init__(self):
# Define initialization code here.
def setup(self, stage):
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
# This hook is called on every process when using DDP.
def training_step(self, batch, batch_idx):
# The complete training step.
def validation_step(self, batch, batch_idx):
# The complete validation step.
def test_step(self, batch, batch_idx):
# The complete test step.
def predict_step(self, batch, batch_idx):
# The complete predict step.
def configure_optimizers(self):
# Define and configure optimizers and LR schedulers.
```
Docs:
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
"""
def __init__(
self,
net: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
compile: bool,
cfg: DictConfig,
**kwargs: Dict[str, Any],
):
"""Initialize a `FlowDockFMLitModule`.
:param net: The model to train.
:param optimizer: The optimizer to use for training.
:param scheduler: The learning rate scheduler to use for training.
:param compile: Whether to compile the model before training.
:param cfg: The model configuration.
:param kwargs: Additional keyword arguments.
"""
super().__init__()
# the model along with its hyperparameters
self.net = net(cfg)
# this line allows access to init params with the 'self.hparams' attribute
# and ensures init params will be stored in the checkpoint
self.save_hyperparameters(logger=False, ignore=["net"])
# for validating input arguments
if self.hparams.cfg.task.loss_mode not in LOSS_MODES_LIST:
raise ValueError(
f"Invalid loss mode: {self.hparams.cfg.task.loss_mode}. Must be one of {LOSS_MODES}."
)
# for inspecting the model's outputs during validation and testing
(
self.training_step_outputs,
self.validation_step_outputs,
self.test_step_outputs,
self.predict_step_outputs,
) = (
[],
[],
[],
[],
)
def forward(
self,
batch: MODEL_BATCH,
iter_id: Union[int, str] = 0,
observed_block_contacts: Optional[torch.Tensor] = None,
contact_prediction: bool = True,
infer_geometry_prior: bool = False,
score: bool = False,
affinity: bool = True,
use_template: bool = False,
**kwargs: Dict[str, Any],
) -> MODEL_BATCH:
"""Perform a forward pass through the model.
:param batch: A batch dictionary.
:param iter_id: The current iteration ID.
:param observed_block_contacts: Observed block contacts.
:param contact_prediction: Whether to predict contacts.
:param infer_geometry_prior: Whether to predict using a geometry prior.
:param score: Whether to predict a denoised complex structure.
:param affinity: Whether to predict ligand binding affinity.
:param use_template: Whether to use a template protein structure.
:param kwargs: Additional keyword arguments.
:return: Batch dictionary with outputs.
"""
return self.net(
batch,
iter_id=iter_id,
observed_block_contacts=observed_block_contacts,
contact_prediction=contact_prediction,
infer_geometry_prior=infer_geometry_prior,
score=score,
affinity=affinity,
use_template=use_template,
training=self.training,
**kwargs,
)
def model_step(
self,
batch: MODEL_BATCH,
batch_idx: int,
stage: MODEL_STAGE,
loss_mode: Optional[LOSS_MODES] = None,
) -> MODEL_BATCH:
"""Perform a single model step on a batch of data.
:param batch: A batch dictionary.
:param batch_idx: The index of the current batch.
:param stage: The current model stage (i.e., `train`, `val`, `test`, or `predict`).
:param loss_mode: The loss mode to use for training.
:return: Batch dictionary with losses.
"""
prepare_batch(batch)
predicting_aux_outputs = (
self.hparams.cfg.confidence.enabled or self.hparams.cfg.affinity.enabled
)
is_aux_loss_stage = stage in AUX_ESTIMATION_STAGES
is_aux_batch = batch_idx % self.hparams.cfg.task.aux_batch_freq == 0
struct_pred_loss_mode_requested = (
loss_mode is not None and loss_mode == "structure_prediction"
)
should_eval_aux_loss = (
predicting_aux_outputs
and is_aux_loss_stage
and is_aux_batch
and not struct_pred_loss_mode_requested
and (
not self.hparams.cfg.task.freeze_confidence
or (
not self.hparams.cfg.task.freeze_affinity
and batch["features"]["affinity"].any().item()
)
)
)
eval_aux_loss_mode_requested = (
predicting_aux_outputs
and loss_mode is not None
and "auxiliary_estimation" in loss_mode
)
if should_eval_aux_loss or eval_aux_loss_mode_requested:
return eval_auxiliary_estimation_losses(
self, batch, stage, loss_mode, training=self.training
)
loss_fn = eval_structure_prediction_losses
return loss_fn(self, batch, batch_idx, self.device, stage, t_1=1.0)
def on_train_start(self):
"""Lightning hook that is called when training begins."""
pass
def training_step(self, batch: MODEL_BATCH, batch_idx: int) -> torch.Tensor:
"""Perform a single training step on a batch of data from the training set.
:param batch: A batch dictionary.
:param batch_idx: The index of the current batch.
:return: A tensor of losses between model predictions and targets.
"""
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
name == self.hparams.cfg.task.overfitting_example_name
for name in batch["metadata"]["sample_ID_per_sample"]
):
return None
try:
batch = self.model_step(batch, batch_idx, "train")
except Exception as e:
log.error(
f"Failed to perform training step for batch index {batch_idx} due to: {e}. Skipping example."
)
return None
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
training_outputs = {
"affinity_logits": batch["outputs"]["affinity_logits"],
"affinity": batch["features"]["affinity"],
}
self.training_step_outputs.append(training_outputs)
# return loss or backpropagation will fail
return batch["outputs"]["loss"]
def on_train_epoch_end(self):
"""Lightning hook that is called when a training epoch ends."""
if self.hparams.cfg.affinity.enabled and any(
"affinity_logits" in output for output in self.training_step_outputs
):
affinity_logits = torch.cat(
[
output["affinity_logits"]
for output in self.training_step_outputs
if "affinity_logits" in output
]
)
affinity = torch.cat(
[
output["affinity"]
for output in self.training_step_outputs
if "affinity_logits" in output
]
)
affinity_logits = affinity_logits[~affinity.isnan()]
affinity = affinity[~affinity.isnan()]
if affinity.numel() > 1:
# NOTE: at least two valid affinity values are required to properly score the affinity predictions
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
aff_mae = mean_absolute_error(affinity_logits, affinity)
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
self.log(
"train_affinity/RMSE",
aff_rmse.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=False,
)
self.log(
"train_affinity/MAE",
aff_mae.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=False,
)
self.log(
"train_affinity/Pearson",
aff_pearson.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=False,
)
self.log(
"train_affinity/Spearman",
aff_spearman.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=False,
)
self.training_step_outputs.clear() # free memory
def on_validation_start(self):
"""Lightning hook that is called when validation begins."""
# create a directory to store model outputs from each validation epoch
os.makedirs(
os.path.join(self.trainer.default_root_dir, "validation_epoch_outputs"), exist_ok=True
)
def validation_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
"""Perform a single validation step on a batch of data from the validation set.
:param batch: A batch dictionary.
:param batch_idx: The index of the current batch.
:param dataloader_idx: The index of the current dataloader.
"""
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
name == self.hparams.cfg.task.overfitting_example_name
for name in batch["metadata"]["sample_ID_per_sample"]
):
return None
try:
prepare_batch(batch)
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler="VDODE",
sampler_eta=1.0,
num_steps=10,
start_time=1.0,
exact_prior=False,
return_all_states=True,
eval_input_protein=True,
)
all_frames = sampling_stats["all_frames"]
del sampling_stats["all_frames"]
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"val_sampling/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler="VDODE",
sampler_eta=1.0,
num_steps=10,
start_time=1.0,
exact_prior=False,
use_template=False,
)
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"val_sampling_notemplate/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler="VDODE",
sampler_eta=1.0,
num_steps=10,
start_time=1.0,
return_summary_stats=True,
exact_prior=True,
)
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"val_sampling_trueprior/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
batch = self.model_step(batch, batch_idx, "val")
except Exception as e:
log.error(
f"Failed to perform validation step for batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}. Skipping example."
)
return None
# store model outputs for inspection
validation_outputs = {}
if self.hparams.cfg.task.visualize_generated_samples:
validation_outputs = {
"name": batch["metadata"]["sample_ID_per_sample"],
"batch_size": batch["metadata"]["num_structid"],
"aatype": batch["features"]["res_type"].long().cpu().numpy(),
"res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
"protein_coordinates_list": [
frame["receptor_padded"].cpu().numpy() for frame in all_frames
],
"ligand_coordinates_list": [
frame["ligands"].cpu().numpy() for frame in all_frames
],
"ligand_mol": batch["metadata"]["mol_per_sample"],
"protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"].cpu().numpy(),
"ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"].cpu().numpy(),
"gt_protein_coordinates": batch["features"]["res_atom_positions"].cpu().numpy(),
"gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
"dataloader_idx": dataloader_idx,
}
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
validation_outputs.update(
{
"affinity_logits": batch["outputs"]["affinity_logits"],
"affinity": batch["features"]["affinity"],
"dataloader_idx": dataloader_idx,
}
)
if validation_outputs:
self.validation_step_outputs.append(validation_outputs)
def on_validation_epoch_end(self):
"Lightning hook that is called when a validation epoch ends."
if self.hparams.cfg.task.visualize_generated_samples:
for i, outputs in enumerate(self.validation_step_outputs):
for batch_index in range(outputs["batch_size"]):
prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
write_prot_lig_pairs_to_pdb_file(
prot_lig_pairs,
os.path.join(
self.trainer.default_root_dir,
"validation_epoch_outputs",
f"{outputs['name'][batch_index]}_validation_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
),
)
if self.hparams.cfg.affinity.enabled and any(
"affinity_logits" in output for output in self.validation_step_outputs
):
affinity_logits = torch.cat(
[
output["affinity_logits"]
for output in self.validation_step_outputs
if "affinity_logits" in output
]
)
affinity = torch.cat(
[
output["affinity"]
for output in self.validation_step_outputs
if "affinity_logits" in output
]
)
affinity_logits = affinity_logits[~affinity.isnan()]
affinity = affinity[~affinity.isnan()]
if affinity.numel() > 1:
# NOTE: at least two valid affinity values are required to properly score the affinity predictions
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
aff_mae = mean_absolute_error(affinity_logits, affinity)
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
self.log(
"val_affinity/RMSE",
aff_rmse.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"val_affinity/MAE",
aff_mae.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"val_affinity/Pearson",
aff_pearson.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"val_affinity/Spearman",
aff_spearman.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.validation_step_outputs.clear() # free memory
def on_test_start(self):
"""Lightning hook that is called when testing begins."""
# create a directory to store model outputs from each test epoch
os.makedirs(
os.path.join(self.trainer.default_root_dir, "test_epoch_outputs"), exist_ok=True
)
def test_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
"""Perform a single test step on a batch of data from the test set.
:param batch: A batch dictionary.
:param batch_idx: The index of the current batch.
:param dataloader_idx: The index of the current dataloader.
"""
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
name == self.hparams.cfg.task.overfitting_example_name
for name in batch["metadata"]["sample_ID_per_sample"]
):
return None
try:
prepare_batch(batch)
if self.hparams.cfg.task.eval_structure_prediction:
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler=self.hparams.cfg.task.sampler,
sampler_eta=self.hparams.cfg.task.sampler_eta,
num_steps=self.hparams.cfg.task.num_steps,
start_time=self.hparams.cfg.task.start_time,
exact_prior=False,
return_all_states=True,
eval_input_protein=True,
)
all_frames = sampling_stats["all_frames"]
del sampling_stats["all_frames"]
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"test_sampling/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler=self.hparams.cfg.task.sampler,
sampler_eta=self.hparams.cfg.task.sampler_eta,
num_steps=self.hparams.cfg.task.num_steps,
start_time=self.hparams.cfg.task.start_time,
exact_prior=False,
use_template=False,
)
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"test_sampling_notemplate/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
sampling_stats = self.net.sample_pl_complex_structures(
batch,
sampler=self.hparams.cfg.task.sampler,
sampler_eta=self.hparams.cfg.task.sampler_eta,
num_steps=self.hparams.cfg.task.num_steps,
start_time=self.hparams.cfg.task.start_time,
return_summary_stats=True,
exact_prior=True,
)
for metric_name in sampling_stats.keys():
log_stat = sampling_stats[metric_name].mean().detach()
batch_size = sampling_stats[metric_name].shape[0]
self.log(
f"test_sampling_trueprior/{metric_name}",
log_stat,
on_step=True,
on_epoch=True,
batch_size=batch_size,
)
batch = self.model_step(
batch, batch_idx, "test", loss_mode=self.hparams.cfg.task.loss_mode
)
except Exception as e:
log.error(
f"Failed to perform test step for {batch['metadata']['sample_ID_per_sample']} with batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}."
)
raise e
# store model outputs for inspection
test_outputs = {}
if (
self.hparams.cfg.task.visualize_generated_samples
and self.hparams.cfg.task.eval_structure_prediction
):
test_outputs.update(
{
"name": batch["metadata"]["sample_ID_per_sample"],
"batch_size": batch["metadata"]["num_structid"],
"aatype": batch["features"]["res_type"].long().cpu().numpy(),
"res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
"protein_coordinates_list": [
frame["receptor_padded"].cpu().numpy() for frame in all_frames
],
"ligand_coordinates_list": [
frame["ligands"].cpu().numpy() for frame in all_frames
],
"ligand_mol": batch["metadata"]["mol_per_sample"],
"protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"]
.cpu()
.numpy(),
"ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"]
.cpu()
.numpy(),
"gt_protein_coordinates": batch["features"]["res_atom_positions"]
.cpu()
.numpy(),
"gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
"dataloader_idx": dataloader_idx,
}
)
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
test_outputs.update(
{
"affinity_logits": batch["outputs"]["affinity_logits"],
"affinity": batch["features"]["affinity"],
"dataloader_idx": dataloader_idx,
}
)
if test_outputs:
self.test_step_outputs.append(test_outputs)
def on_test_epoch_end(self):
"""Lightning hook that is called when a test epoch ends."""
if (
self.hparams.cfg.task.visualize_generated_samples
and self.hparams.cfg.task.eval_structure_prediction
):
for i, outputs in enumerate(self.test_step_outputs):
for batch_index in range(outputs["batch_size"]):
prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
write_prot_lig_pairs_to_pdb_file(
prot_lig_pairs,
os.path.join(
self.trainer.default_root_dir,
"test_epoch_outputs",
f"{outputs['name'][batch_index]}_test_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
),
)
if self.hparams.cfg.affinity.enabled and any(
"affinity_logits" in output for output in self.test_step_outputs
):
affinity_logits = torch.cat(
[
output["affinity_logits"]
for output in self.test_step_outputs
if "affinity_logits" in output
]
)
affinity = torch.cat(
[
output["affinity"]
for output in self.test_step_outputs
if "affinity_logits" in output
]
)
affinity_logits = affinity_logits[~affinity.isnan()]
affinity = affinity[~affinity.isnan()]
if affinity.numel() > 1:
                # NOTE: there must be at least two (non-NaN) affinity values to properly score the affinity predictions
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
aff_mae = mean_absolute_error(affinity_logits, affinity)
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
self.log(
"test_affinity/RMSE",
aff_rmse.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"test_affinity/MAE",
aff_mae.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"test_affinity/Pearson",
aff_pearson.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.log(
"test_affinity/Spearman",
aff_spearman.detach(),
on_epoch=True,
batch_size=len(affinity),
sync_dist=True,
)
self.test_step_outputs.clear() # free memory
def on_predict_start(self):
"""Lightning hook that is called when testing begins."""
# create a directory to store model outputs from each predict epoch
os.makedirs(
os.path.join(self.trainer.default_root_dir, "predict_epoch_outputs"), exist_ok=True
)
log.info("Loading pretrained ESM model...")
esm_model, self.esm_alphabet = esm.pretrained.load_model_and_alphabet_hub(
self.hparams.cfg.model.cfg.protein_encoder.esm_version
)
self.esm_model = esm_model.eval().float()
self.esm_batch_converter = self.esm_alphabet.get_batch_converter()
self.esm_model.cpu()
skip_loading_esmfold_weights = (
# skip loading ESMFold weights if the template protein structure for a single complex input is provided
self.hparams.cfg.task.csv_path is None
and self.hparams.cfg.task.input_template is not None
and os.path.exists(self.hparams.cfg.task.input_template)
)
if not skip_loading_esmfold_weights:
log.info("Loading pretrained ESMFold model...")
esmfold_model = esm.pretrained.esmfold_v1()
self.esmfold_model = esmfold_model.eval().float()
self.esmfold_model.set_chunk_size(self.hparams.cfg.esmfold_chunk_size)
self.esmfold_model.cpu()
def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
"""Perform a single predict step on a batch of data from the predict set.
:param batch: A batch dictionary.
:param batch_idx: The index of the current batch.
:param dataloader_idx: The index of the current dataloader.
"""
rec_path = batch["rec_path"][0]
ligand_paths = list(
path[0] for path in batch["lig_paths"]
) # unpack a list of (batched) single-element string tuples
sample_id = batch["sample_id"][0] if "sample_id" in batch else "sample"
input_template = batch["input_template"][0] if "input_template" in batch else None
out_path = (
os.path.join(self.hparams.cfg.out_path, sample_id)
if "sample_id" in batch
else self.hparams.cfg.out_path
)
# generate ESM embeddings for the protein
protein = pdb_filepath_to_protein(rec_path)
sequences = [
"".join(np.array(list(chain_seq))[chain_mask])
for (_, chain_seq, chain_mask) in protein.letter_sequences
]
esm_embeddings = extract_esm_embeddings(
self.esm_model,
self.esm_alphabet,
self.esm_batch_converter,
sequences,
device="cpu",
esm_repr_layer=self.hparams.cfg.model.cfg.protein_encoder.esm_repr_layer,
)
sequences_to_embeddings = {
f"{seq}:{i}": esm_embeddings[i].cpu().numpy() for i, seq in enumerate(sequences)
}
# generate initial ESMFold-predicted structure for the protein if a template is not provided
apo_rec_path = None
if input_template and os.path.exists(input_template):
apo_protein = pdb_filepath_to_protein(input_template)
apo_chain_seq_masked = "".join(
"".join(np.array(list(chain_seq))[chain_mask])
for (_, chain_seq, chain_mask) in apo_protein.letter_sequences
)
chain_seq_masked = "".join(
"".join(np.array(list(chain_seq))[chain_mask])
for (_, chain_seq, chain_mask) in protein.letter_sequences
)
if apo_chain_seq_masked != chain_seq_masked:
log.error(
f"Provided template protein structure {input_template} does not match the input protein sequence within {rec_path}. Skipping example {sample_id} at batch index {batch_idx} of dataloader {dataloader_idx}."
)
return None
log.info(f"Starting from provided template protein structure: {input_template}")
apo_rec_path = input_template
if apo_rec_path is None and self.hparams.cfg.prior_type == "esmfold":
esmfold_sequence = ":".join(sequences)
apo_rec_path = rec_path.replace(".pdb", "_apo.pdb")
with torch.no_grad():
esmfold_pdb_output = self.esmfold_model.infer_pdb(esmfold_sequence)
with open(apo_rec_path, "w") as f:
f.write(esmfold_pdb_output)
_, _, _, _, _, all_frames, batch_all, b_factors, plddt_rankings = multi_pose_sampling(
rec_path,
ligand_paths,
self.hparams.cfg,
self,
out_path,
separate_pdb=self.hparams.cfg.separate_pdb,
apo_receptor_path=apo_rec_path,
sample_id=sample_id,
protein=protein,
sequences_to_embeddings=sequences_to_embeddings,
return_all_states=self.hparams.cfg.task.visualize_generated_samples,
auxiliary_estimation_only=self.hparams.cfg.task.auxiliary_estimation_only,
)
# store model outputs for inspection
if self.hparams.cfg.task.visualize_generated_samples:
predict_outputs = {
"name": batch_all["metadata"]["sample_ID_per_sample"],
"batch_size": batch_all["metadata"]["num_structid"],
"aatype": batch_all["features"]["res_type"].long().cpu().numpy(),
"res_atom_mask": batch_all["features"]["res_atom_mask"].cpu().numpy(),
"protein_coordinates_list": [
frame["receptor_padded"].cpu().numpy() for frame in all_frames
],
"ligand_coordinates_list": [
frame["ligands"].cpu().numpy() for frame in all_frames
],
"ligand_mol": batch_all["metadata"]["mol_per_sample"],
"protein_batch_indexer": batch_all["indexer"]["gather_idx_a_structid"]
.cpu()
.numpy(),
"ligand_batch_indexer": batch_all["indexer"]["gather_idx_i_structid"]
.cpu()
.numpy(),
"b_factors": b_factors,
"plddt_rankings": plddt_rankings,
}
self.predict_step_outputs.append(predict_outputs)
def on_predict_epoch_end(self):
"""Lightning hook that is called when a predict epoch ends."""
if self.hparams.cfg.task.visualize_generated_samples:
for i, outputs in enumerate(self.predict_step_outputs):
for batch_index in range(outputs["batch_size"]):
prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
ranking = (
outputs["plddt_rankings"][batch_index]
if "plddt_rankings" in outputs
else None
)
write_prot_lig_pairs_to_pdb_file(
prot_lig_pairs,
os.path.join(
self.hparams.cfg.out_path,
outputs["name"][batch_index],
"predict_epoch_outputs",
f"{outputs['name'][batch_index]}{f'_rank{ranking + 1}' if ranking is not None else ''}_predict_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}.pdb",
),
)
self.predict_step_outputs.clear() # free memory
def on_after_backward(self):
"""Skip updates in case of unstable gradients.
Reference: https://github.com/Lightning-AI/lightning/issues/4956
"""
valid_gradients = True
for _, param in self.named_parameters():
if param.grad is not None:
valid_gradients = not (
torch.isnan(param.grad).any() or torch.isinf(param.grad).any()
)
if not valid_gradients:
break
if not valid_gradients:
log.warning(
"Detected `inf` or `nan` values in gradients. Not updating model parameters."
)
self.zero_grad()
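The hook above invalidates an entire update if any parameter gradient contains a NaN or Inf entry. A pure-Python sketch of that check, with gradients represented as flat lists of floats instead of parameter tensors (`gradients_are_valid` is a hypothetical name for illustration):

```python
import math


def gradients_are_valid(grads) -> bool:
    """Return False if any gradient entry is NaN or Inf, as in `on_after_backward`.

    `grads` is an iterable of per-parameter gradients; `None` entries
    (parameters without gradients) are skipped, matching the hook above.
    """
    return not any(
        g is not None and any(math.isnan(x) or math.isinf(x) for x in g)
        for g in grads
    )
```

A single bad entry anywhere is enough to skip the whole optimizer step.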
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_closure,
):
"""Override the optimizer step to dynamically update the learning rate.
:param epoch: The current epoch.
:param batch_idx: The index of the current batch.
:param optimizer: The optimizer to use for training.
:param optimizer_closure: The optimizer closure.
"""
# update params
optimizer = optimizer.optimizer
optimizer.step(closure=optimizer_closure)
# warm up learning rate
if self.trainer.global_step < 1000:
lr_scale = min(1.0, float(self.trainer.global_step + 1) / 1000.0)
for pg in optimizer.param_groups:
# NOTE: `self.hparams.optimizer.keywords["lr"]` refers to the optimizer's initial learning rate
pg["lr"] = lr_scale * self.hparams.optimizer.keywords["lr"]
def setup(self, stage: str):
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
test, or predict.
This is a good hook when you need to build models dynamically or adjust something about
them. This hook is called on every process when using DDP.
:param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
"""
if self.hparams.compile and stage == "fit":
self.net = torch.compile(self.net)
def configure_optimizers(self) -> Dict[str, Any]:
"""Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples:
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
:return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
"""
try:
optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
except TypeError:
# NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
if self.hparams.scheduler is not None:
scheduler = self.hparams.scheduler(optimizer=optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val/loss",
"interval": "epoch",
"frequency": 1,
},
}
return {"optimizer": optimizer}
if __name__ == "__main__":
_ = FlowDockFMLitModule(None, None, None, None)

flowdock/sample.py Normal file

@@ -0,0 +1,284 @@
import os
import hydra
import lightning as L
import lovely_tensors as lt
import pandas as pd
import rootutils
import torch
from beartype.typing import Any, Dict, List, Tuple
from lightning import LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader
lt.monkey_patch()
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from flowdock import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from flowdock.utils import (
RankedLogger,
extras,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
from flowdock.utils.data_utils import (
create_full_pdb_with_zero_coordinates,
create_temp_ligand_frag_files,
)
log = RankedLogger(__name__, rank_zero_only=True)
AVAILABLE_SAMPLING_TASKS = ["batched_structure_sampling"]
class SamplingDataset(torch.utils.data.Dataset):
"""Dataset for sampling."""
def __init__(self, cfg: DictConfig):
"""Initializes the SamplingDataset."""
if cfg.sampling_task == "batched_structure_sampling":
if cfg.csv_path is not None:
# handle variable CSV inputs
df_rows = []
self.df = pd.read_csv(cfg.csv_path)
for _, row in self.df.iterrows():
sample_id = row.id
input_receptor = row.input_receptor
input_ligand = row.input_ligand
input_template = row.input_template
assert input_receptor is not None, "Receptor path is required for sampling."
if input_ligand is not None:
if input_ligand.endswith(".sdf"):
ligand_paths = create_temp_ligand_frag_files(input_ligand)
else:
ligand_paths = list(input_ligand.split("|"))
else:
ligand_paths = None # handle `null` ligand input
if not input_receptor.endswith(".pdb"):
log.warning(
"Assuming the provided receptor input is a protein sequence. Creating a dummy PDB file."
)
create_full_pdb_with_zero_coordinates(
input_receptor, os.path.join(cfg.out_path, f"input_{sample_id}.pdb")
)
input_receptor = os.path.join(cfg.out_path, f"input_{sample_id}.pdb")
df_row = {
"sample_id": sample_id,
"rec_path": input_receptor,
"lig_paths": ligand_paths,
}
if input_template is not None:
df_row["input_template"] = input_template
df_rows.append(df_row)
self.df = pd.DataFrame(df_rows)
else:
sample_id = cfg.sample_id
input_receptor = cfg.input_receptor
input_ligand = cfg.input_ligand
if input_ligand is not None:
if input_ligand.endswith(".sdf"):
ligand_paths = create_temp_ligand_frag_files(input_ligand)
else:
ligand_paths = list(input_ligand.split("|"))
else:
ligand_paths = None # handle `null` ligand input
if not input_receptor.endswith(".pdb"):
log.warning(
"Assuming the provided receptor input is a protein sequence. Creating a dummy PDB file."
)
create_full_pdb_with_zero_coordinates(
input_receptor, os.path.join(cfg.out_path, "input.pdb")
)
input_receptor = os.path.join(cfg.out_path, "input.pdb")
self.df = pd.DataFrame(
[
{
"sample_id": sample_id,
"rec_path": input_receptor,
"lig_paths": ligand_paths,
}
]
)
if cfg.input_template is not None:
self.df["input_template"] = cfg.input_template
else:
raise NotImplementedError(f"Sampling task {cfg.sampling_task} is not implemented.")
def __len__(self):
"""Returns the length of the dataset."""
return len(self.df)
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Returns the sampling inputs (receptor path, ligand paths, etc.) for one example."""
return self.df.iloc[idx].to_dict()
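The ligand-input convention handled above can be isolated: an `.sdf` path is fragmented into temporary files (omitted here), any other non-null string is treated as a `"|"`-separated list of ligand file paths, and `None` stays `None`. A minimal sketch covering the non-SDF branches (`split_ligand_paths` is a hypothetical helper name):

```python
def split_ligand_paths(input_ligand):
    """Split a non-SDF ligand input on "|", passing `None` through unchanged."""
    if input_ligand is None:
        return None  # handle `null` ligand input
    return list(input_ligand.split("|"))
```

A single path without separators yields a one-element list, so downstream code can always iterate.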
@task_wrapper
def sample(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Samples using given checkpoint on a datamodule predictset.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""
assert cfg.ckpt_path, "Please provide a checkpoint path with which to sample!"
assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"
assert (
cfg.sampling_task in AVAILABLE_SAMPLING_TASKS
), f"Sampling task {cfg.sampling_task} is not one of the following available tasks: {AVAILABLE_SAMPLING_TASKS}."
assert (cfg.input_receptor is not None and cfg.input_ligand is not None) or (
cfg.csv_path is not None and os.path.exists(cfg.csv_path)
), "Please provide either an input receptor and ligand or a CSV file with receptor and ligand sequences/filepaths."
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(
f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
)
torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)
# Establish model input arguments
with open_dict(cfg):
# NOTE: Structure trajectories will not be visualized when performing auxiliary estimation only
cfg.model.cfg.prior_type = cfg.prior_type
cfg.model.cfg.task.detect_covalent = cfg.detect_covalent
cfg.model.cfg.task.use_template = cfg.use_template
cfg.model.cfg.task.csv_path = cfg.csv_path
cfg.model.cfg.task.input_receptor = cfg.input_receptor
cfg.model.cfg.task.input_ligand = cfg.input_ligand
cfg.model.cfg.task.input_template = cfg.input_template
cfg.model.cfg.task.visualize_generated_samples = (
cfg.visualize_sample_trajectories and not cfg.auxiliary_estimation_only
)
cfg.model.cfg.task.auxiliary_estimation_only = cfg.auxiliary_estimation_only
if cfg.latent_model is not None:
with open_dict(cfg):
cfg.model.cfg.latent_model = cfg.latent_model
with open_dict(cfg):
if cfg.start_time == "auto":
cfg.start_time = 1.0
else:
cfg.start_time = float(cfg.start_time)
log.info("Converting sampling inputs into a <SamplingDataset>")
dataloaders: List[DataLoader] = [
DataLoader(
SamplingDataset(cfg),
batch_size=1,
shuffle=False,
num_workers=0,
pin_memory=False,
)
]
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
model.hparams.cfg.update(cfg) # update model config with the sampling config
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
plugins = None
if "_target_" in cfg.environment:
log.info(f"Instantiating environment <{cfg.environment._target_}>")
plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
strategy = getattr(cfg.trainer, "strategy", None)
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if (
"mixed_precision" in strategy.__dict__
and getattr(strategy, "mixed_precision", None) is not None
):
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = (
hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
strategy=strategy,
)
if strategy is not None
else hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
)
)
object_dict = {
"cfg": cfg,
"model": model,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
metric_dict = trainer.callback_metrics
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="sample.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for sampling.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
sample(cfg)
if __name__ == "__main__":
register_custom_omegaconf_resolvers()
main()

flowdock/train.py Normal file

@@ -0,0 +1,196 @@
import os
import hydra
import lightning as L
import lovely_tensors as lt
import rootutils
import torch
from beartype.typing import Any, Dict, List, Optional, Tuple
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig
lt.monkey_patch()
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from flowdock import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from flowdock.utils import (
RankedLogger,
extras,
get_metric_value,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
log = RankedLogger(__name__, rank_zero_only=True)
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: A DictConfig configuration composed by Hydra.
:return: A tuple with metrics and dict with all instantiated objects.
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(
f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
)
torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="fit")
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
plugins = None
if "_target_" in cfg.environment:
log.info(f"Instantiating environment <{cfg.environment._target_}>")
plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
strategy = getattr(cfg.trainer, "strategy", None)
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if (
"mixed_precision" in strategy.__dict__
and getattr(strategy, "mixed_precision", None) is not None
):
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = (
hydra.utils.instantiate(
cfg.trainer,
callbacks=callbacks,
logger=logger,
plugins=plugins,
strategy=strategy,
)
if strategy is not None
else hydra.utils.instantiate(
cfg.trainer,
callbacks=callbacks,
logger=logger,
plugins=plugins,
)
)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
if cfg.get("train"):
log.info("Starting training!")
ckpt_path = None
if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
ckpt_path = cfg.get("ckpt_path")
elif cfg.get("ckpt_path"):
log.warning(
"`ckpt_path` was given, but the path does not exist. Training with new model weights."
)
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
train_metrics = trainer.callback_metrics
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info(f"Best ckpt path: {ckpt_path}")
test_metrics = trainer.callback_metrics
# merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
"""Main entry point for training.
:param cfg: DictConfig configuration composed by Hydra.
:return: Optional[float] with optimized metric value.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
# train the model
metric_dict, _ = train(cfg)
# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = get_metric_value(
metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
)
# return optimized metric
return metric_value
if __name__ == "__main__":
register_custom_omegaconf_resolvers()
main()


@@ -0,0 +1,5 @@
from flowdock.utils.instantiators import instantiate_callbacks, instantiate_loggers
from flowdock.utils.logging_utils import log_hyperparameters
from flowdock.utils.pylogger import RankedLogger
from flowdock.utils.rich_utils import enforce_tags, print_config_tree
from flowdock.utils.utils import extras, get_metric_value, task_wrapper

flowdock/utils/data_utils.py Normal file

File diff suppressed because it is too large

@@ -0,0 +1,77 @@
import torch
from beartype.typing import Optional
class RigidTransform:
"""Rigid Transform class."""
def __init__(self, t: torch.Tensor, R: Optional[torch.Tensor] = None):
"""Initialize Rigid Transform."""
self.t = t
if R is None:
R = t.new_zeros(*t.shape, 3)
self.R = R
def __getitem__(self, key):
"""Get item from Rigid Transform."""
return RigidTransform(self.t[key], self.R[key])
def unsqueeze(self, dim):
"""Unsqueeze Rigid Transform."""
return RigidTransform(self.t.unsqueeze(dim), self.R.unsqueeze(dim))
def squeeze(self, dim):
"""Squeeze Rigid Transform."""
return RigidTransform(self.t.squeeze(dim), self.R.squeeze(dim))
def concatenate(self, other, dim=0):
"""Concatenate Rigid Transform."""
return RigidTransform(
torch.cat([self.t, other.t], dim=dim),
torch.cat([self.R, other.R], dim=dim),
)
def get_frame_matrix(
ri: torch.Tensor, rj: torch.Tensor, rk: torch.Tensor, eps: float = 1e-4, strict: bool = False
):
"""Get frame matrix from three points using the regularized Gram-Schmidt algorithm.
Note that this implementation allows for shearing.
"""
v1 = ri - rj
v2 = rk - rj
if strict:
# v1 = v1 + torch.randn_like(rj).mul(eps)
# v2 = v2 + torch.randn_like(rj).mul(eps)
e1 = v1 / v1.norm(dim=-1, keepdim=True)
# Project and pad
u2 = v2 - e1.mul(e1.mul(v2).sum(-1, keepdim=True))
e2 = u2 / u2.norm(dim=-1, keepdim=True)
else:
e1 = v1 / v1.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
# Project and pad
u2 = v2 - e1.mul(e1.mul(v2).sum(-1, keepdim=True))
e2 = u2 / u2.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
e3 = torch.cross(e1, e2, dim=-1)
# Rows - lab frame, columns - internal frame
rot_j = torch.stack([e1, e2, e3], dim=-1)
return RigidTransform(rj, torch.nan_to_num(rot_j, 0.0))
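The non-strict branch of the regularized Gram–Schmidt construction above can be checked on a single point triple in plain Python (the real code is batched torch; `frame_axes` is a hypothetical name and `eps` is applied inside the square root, matching the `.add(eps).sqrt()` calls above):

```python
def frame_axes(ri, rj, rk, eps=1e-4):
    """Build an (approximately) orthonormal frame from three 3D points."""
    sub = lambda a, b: [x - y for x, y in zip(a, b)]
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    norm = lambda a: (dot(a, a) + eps) ** 0.5  # regularized norm, as above
    scale = lambda a, c: [x * c for x in a]
    v1, v2 = sub(ri, rj), sub(rk, rj)
    e1 = scale(v1, 1.0 / norm(v1))
    u2 = sub(v2, scale(e1, dot(e1, v2)))  # project out the e1 component
    e2 = scale(u2, 1.0 / norm(u2))
    e3 = [  # cross product e1 x e2
        e1[1] * e2[2] - e1[2] * e2[1],
        e1[2] * e2[0] - e1[0] * e2[2],
        e1[0] * e2[1] - e1[1] * e2[0],
    ]
    return e1, e2, e3
```

For well-separated points the axes are orthogonal and close to unit length; the `eps` regularization shrinks them slightly, which is the price paid for stability at near-degenerate inputs.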
def cartesian_to_internal(rs: torch.Tensor, frames: RigidTransform):
"""Right-multiply the pose matrix."""
rs_loc = rs - frames.t
rs_loc = torch.matmul(rs_loc.unsqueeze(-2), frames.R)
return rs_loc.squeeze(-2)
def apply_similarity_transform(
X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
"""Apply a similarity transform to a set of points X.
From: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/ops/points_alignment.html
"""
X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
return X
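The similarity transform above applies `X' = s * (X @ R) + T` per batch element with a row-vector convention (`torch.bmm(X, R)`). A plain-Python check of that convention for a single list of 3D points (`apply_similarity` is a hypothetical, unbatched stand-in):

```python
def apply_similarity(X, R, T, s):
    """Compute s * (x @ R) + T for each 3D row vector x in X."""
    out = []
    for x in X:
        xr = [sum(x[k] * R[k][j] for k in range(3)) for j in range(3)]
        out.append([s * xr[j] + T[j] for j in range(3)])
    return out
```

For example, a 90-degree rotation about z (row-vector form), uniform scale 2, and a shift of (1, 0, 0) maps the unit x-vector to (1, 2, 0).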


@@ -0,0 +1,4 @@
import wandb
if __name__ == "__main__":
print(f"Generated WandB run ID: {wandb.util.generate_id()}")


@@ -0,0 +1,55 @@
import argparse
import torch
from beartype import beartype
from beartype.typing import Literal
def clamp_tensor(value: torch.Tensor, min: float = 1e-6, max: float = 1 - 1e-6) -> torch.Tensor:
"""Set the upper and lower bounds of a tensor via clamping.
:param value: The tensor to clamp.
:param min: The minimum value to clamp to. Default is `1e-6`.
:param max: The maximum value to clamp to. Default is `1 - 1e-6`.
:return: The clamped tensor.
"""
return value.clamp(min=min, max=max)
@beartype
def main(
start_time: float, num_steps: int, sampler: Literal["ODE", "VDODE"] = "VDODE", eta: float = 1.0
):
"""Inspect different ODE samplers by printing the left hand side (LHS) and right hand side.
(RHS) of their time ratio schedules. Note that the LHS and RHS are clamped to the range
`[1e-6, 1 - 1e-6]` by default.
:param start_time: The start time of the ODE sampler.
:param num_steps: The number of steps to take.
:param sampler: The ODE sampler to use.
:param eta: The variance diminishing factor.
"""
assert 0 < start_time <= 1.0, "The argument `start_time` must be in the range (0, 1]."
schedule = torch.linspace(start_time, 0, num_steps + 1)
for t, s in zip(schedule[:-1], schedule[1:]):
if sampler == "ODE":
# Baseline ODE
print(
f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - t) * x0_hat: {clamp_tensor((1 - t)):.3f}; RHS -> t * xt: {clamp_tensor(t):.3f}"
)
elif sampler == "VDODE":
# Variance Diminishing ODE (VD-ODE)
print(
f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - ((s / t) * eta)) * x0_hat: {clamp_tensor(1 - ((s / t) * eta)):.3f}; RHS -> ((s / t) * eta) * xt: {clamp_tensor((s / t) * eta):.3f}"
)
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--start_time", type=float, default=1.0)
argparser.add_argument("--num_steps", type=int, default=40)
argparser.add_argument("--sampler", type=str, choices=["ODE", "VDODE"], default="VDODE")
argparser.add_argument("--eta", type=float, default=1.0)
args = argparser.parse_args()
main(args.start_time, args.num_steps, sampler=args.sampler, eta=args.eta)
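The VD-ODE update above blends the denoised estimate `x0_hat` (weight `1 - (s / t) * eta`) and the current state `xt` (weight `(s / t) * eta`), so the two weights always sum to one. A torch-free sketch of the schedule (clamping omitted for brevity; `vdode_weights` is a hypothetical helper name):

```python
# Compute the (LHS, RHS) weight pairs for a linear VD-ODE time schedule.
def vdode_weights(start_time=1.0, num_steps=4, eta=1.0):
    step = start_time / num_steps
    schedule = [start_time - i * step for i in range(num_steps + 1)]
    pairs = []
    for t, s in zip(schedule[:-1], schedule[1:]):
        rhs = (s / t) * eta   # weight on the current state xt
        lhs = 1.0 - rhs       # weight on the denoised estimate x0_hat
        pairs.append((round(lhs, 3), round(rhs, 3)))
    return pairs

print(vdode_weights())
# [(0.25, 0.75), (0.333, 0.667), (0.5, 0.5), (1.0, 0.0)]
```

Note that the final step places all weight on `x0_hat`, i.e. the sampler terminates at the denoised estimate.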


@@ -0,0 +1,58 @@
import hydra
import rootutils
from beartype.typing import List
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils import pylogger
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
"""Instantiates callbacks from config.
:param callbacks_cfg: A DictConfig object containing callback configurations.
:return: A list of instantiated callbacks.
"""
callbacks: List[Callback] = []
if not callbacks_cfg:
log.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callbacks_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiates loggers from config.
:param logger_cfg: A DictConfig object containing logger configurations.
:return: A list of instantiated loggers.
"""
logger: List[Logger] = []
if not logger_cfg:
log.warning("No logger configs found! Skipping...")
return logger
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))
return logger
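Both helpers above rely on `hydra.utils.instantiate` resolving a `_target_` dotted path to a class and calling it with the remaining keys. An illustrative-only, stdlib approximation of that resolution (real Hydra also handles recursive instantiation, interpolation, partials, etc.):

```python
import importlib

# Resolve a `_target_` dotted path and construct the object from the rest of the config.
def instantiate(conf: dict):
    module_path, _, cls_name = conf["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), cls_name)
    kwargs = {k: v for k, v in conf.items() if k != "_target_"}
    return cls(**kwargs)

cb = instantiate({"_target_": "collections.Counter", "red": 2})
print(cb)  # Counter({'red': 2})
```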


@@ -0,0 +1,59 @@
import rootutils
from beartype.typing import Any, Dict
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils import pylogger
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
"""Controls which config parts are saved by Lightning loggers.
Additionally saves:
- Number of model parameters
:param object_dict: A dictionary containing the following objects:
- `"cfg"`: A DictConfig object containing the main config.
- `"model"`: The Lightning model.
- `"trainer"`: The Lightning trainer.
"""
hparams = {}
cfg = OmegaConf.to_container(object_dict["cfg"])
model = object_dict["model"]
trainer = object_dict["trainer"]
if not trainer.logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return
hparams["model"] = cfg["model"]
# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)
hparams["data"] = cfg["data"]
hparams["trainer"] = cfg["trainer"]
hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")
hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)


@@ -0,0 +1,88 @@
import subprocess # nosec
import torch
from beartype import beartype
from beartype.typing import Any, Dict, List, Optional, Tuple
MODEL_BATCH = Dict[str, Any]
@beartype
def calculate_usalign_metrics(
pred_pdb_filepath: str,
reference_pdb_filepath: str,
usalign_exec_path: str,
flags: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Calculates US-align structural metrics between predicted and reference macromolecular
structures.
:param pred_pdb_filepath: Filepath to predicted macromolecular structure in PDB format.
:param reference_pdb_filepath: Filepath to reference macromolecular structure in PDB format.
:param usalign_exec_path: Path to US-align executable.
:param flags: Command-line flags to pass to US-align, optional.
:return: Dictionary containing macromolecular US-align structural metrics and metadata.
"""
# run US-align with subprocess and capture output
cmd = [usalign_exec_path, pred_pdb_filepath, reference_pdb_filepath]
if flags is not None:
cmd += flags
output = subprocess.check_output(cmd, text=True, stderr=subprocess.PIPE) # nosec
# parse US-align output to extract structural metrics
metrics = {}
for line in output.splitlines():
line = line.strip()
if line.startswith("Name of Structure_1:"):
metrics["Name of Structure_1"] = line.split(": ", 1)[1]
elif line.startswith("Name of Structure_2:"):
metrics["Name of Structure_2"] = line.split(": ", 1)[1]
elif line.startswith("Length of Structure_1:"):
metrics["Length of Structure_1"] = int(line.split(": ")[1].split()[0])
elif line.startswith("Length of Structure_2:"):
metrics["Length of Structure_2"] = int(line.split(": ")[1].split()[0])
elif line.startswith("Aligned length="):
aligned_length = line.split("=")[1].split(",")[0]
rmsd = line.split("=")[2].split(",")[0]
seq_id = line.split("=")[4]
metrics["Aligned length"] = int(aligned_length.strip())
metrics["RMSD"] = float(rmsd.strip())
metrics["Seq_ID"] = float(seq_id.strip())
elif line.startswith("TM-score="):
if "normalized by length of Structure_1" in line:
metrics["TM-score_1"] = float(line.split("=")[1].split()[0])
elif "normalized by length of Structure_2" in line:
metrics["TM-score_2"] = float(line.split("=")[1].split()[0])
return metrics
def compute_per_atom_lddt(
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes per-atom local distance difference test (LDDT) between predicted and target
coordinates.
:param batch: Dictionary containing metadata and target coordinates.
:param pred_coords: Predicted atomic coordinates.
:param target_coords: Target atomic coordinates.
:return: Tuple of lDDT and lDDT list.
"""
pred_coords = pred_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
target_coords = target_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
target_dist = (target_coords[:, :, None] - target_coords[:, None, :]).norm(dim=-1)
pred_dist = (pred_coords[:, :, None] - pred_coords[:, None, :]).norm(dim=-1)
conserved_mask = target_dist < 15.0
lddt_list = []
thresholds = [0, 0.5, 1, 2, 4, 6, 8, 12, 1e9]
for threshold_idx in range(8):
distdiff = (pred_dist - target_dist).abs()
bin_fraction = (distdiff > thresholds[threshold_idx]) & (
distdiff < thresholds[threshold_idx + 1]
)
lddt_list.append(
bin_fraction.mul(conserved_mask).long().sum(dim=2) / conserved_mask.long().sum(dim=2)
)
lddt_list = torch.stack(lddt_list, dim=-1)
lddt = torch.cumsum(lddt_list[:, :, :4], dim=-1).mean(dim=-1)
return lddt, lddt_list
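A torch-free sketch of the lDDT idea implemented in `compute_per_atom_lddt` above: per atom, score the fraction of reference distances (below the 15 Å cutoff) that the prediction reproduces within 0.5/1/2/4 Å, averaged over the four thresholds. (The tensorized version bins the differences and takes a cumulative sum; boundary handling differs slightly.)

```python
import math

# Per-atom lDDT over all conserved reference pairs, averaged across thresholds.
def lddt(pred, ref, cutoff=15.0, thresholds=(0.5, 1.0, 2.0, 4.0)):
    scores = []
    for i in range(len(ref)):
        hits, total = 0, 0
        for j in range(len(ref)):
            if i == j or math.dist(ref[i], ref[j]) >= cutoff:
                continue
            total += 1
            diff = abs(math.dist(pred[i], pred[j]) - math.dist(ref[i], ref[j]))
            hits += sum(diff < t for t in thresholds)
        scores.append(hits / (total * len(thresholds)) if total else 0.0)
    return scores

ref = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 4.0, 0.0)]
print(lddt(ref, ref))  # identical structures -> [1.0, 1.0, 1.0]
```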


@@ -0,0 +1,446 @@
# Adapted from: https://github.com/zrqiao/NeuralPLexer
import rootutils
import torch
import torch.nn.functional as F
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
from torch_scatter import scatter_max, scatter_min
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils import RankedLogger
MODEL_BATCH = Dict[str, Any]
STATE_DICT = Dict[str, Any]
log = RankedLogger(__name__, rank_zero_only=True)
class GELUMLP(torch.nn.Module):
"""Simple MLP with post-LayerNorm."""
def __init__(
self,
n_in_feats: int,
n_out_feats: int,
n_hidden_feats: Optional[int] = None,
dropout: float = 0.0,
zero_init: bool = False,
):
"""Initialize the GELUMLP."""
super().__init__()
self.dropout = dropout
if n_hidden_feats is None:
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_in_feats, n_in_feats),
torch.nn.GELU(),
torch.nn.LayerNorm(n_in_feats),
torch.nn.Linear(n_in_feats, n_out_feats),
)
else:
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_in_feats, n_hidden_feats),
torch.nn.GELU(),
torch.nn.Dropout(p=self.dropout),
torch.nn.Linear(n_hidden_feats, n_hidden_feats),
torch.nn.GELU(),
torch.nn.LayerNorm(n_hidden_feats),
torch.nn.Linear(n_hidden_feats, n_out_feats),
)
torch.nn.init.xavier_uniform_(self.layers[0].weight, gain=1)
# zero init for residual branches
if zero_init:
self.layers[-1].weight.data.fill_(0.0)
else:
torch.nn.init.xavier_uniform_(self.layers[-1].weight, gain=1)
def _zero_init(self, module):
"""Zero-initialize weights and biases."""
if isinstance(module, torch.nn.Linear):
module.weight.data.zero_()
if module.bias is not None:
module.bias.data.zero_()
def forward(self, x: torch.Tensor):
"""Forward pass through the GELUMLP."""
x = F.dropout(x, p=self.dropout, training=self.training)
return self.layers(x)
class SumPooling(torch.nn.Module):
"""Sum pooling layer."""
def __init__(self, learnable: bool, hidden_dim: int = 1):
"""Initialize the SumPooling layer."""
super().__init__()
self.pooled_transform = (
torch.nn.Linear(hidden_dim, hidden_dim) if learnable else torch.nn.Identity()
)
def forward(self, x, dst_idx, dst_size):
"""Forward pass through the SumPooling layer."""
return self.pooled_transform(segment_sum(x, dst_idx, dst_size))
class AveragePooling(torch.nn.Module):
"""Average pooling layer."""
def __init__(self, learnable: bool, hidden_dim: int = 1):
"""Initialize the AveragePooling layer."""
super().__init__()
self.pooled_transform = (
torch.nn.Linear(hidden_dim, hidden_dim) if learnable else torch.nn.Identity()
)
def forward(self, x, dst_idx, dst_size):
"""Forward pass through the AveragePooling layer."""
out = torch.zeros(
dst_size,
*x.shape[1:],
dtype=x.dtype,
device=x.device,
).index_add_(0, dst_idx, x)
nmr = torch.zeros(
dst_size,
*x.shape[1:],
dtype=x.dtype,
device=x.device,
).index_add_(0, dst_idx, torch.ones_like(x))
return self.pooled_transform(out / (nmr + 1e-8))
def init_weights(m):
"""Initialize weights with Kaiming uniform."""
if isinstance(m, torch.nn.Linear):
torch.nn.init.kaiming_uniform_(m.weight)
def segment_sum(src, dst_idx, dst_size):
"""Computes the sum of each segment in a tensor."""
out = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, src)
return out
def segment_mean(src, dst_idx, dst_size):
"""Computes the mean value of each segment in a tensor."""
out = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, src)
denom = (
torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, torch.ones_like(src))
+ 1e-8
)
return out / denom
def segment_argmin(scores, dst_idx, dst_size, randomize: bool = False) -> torch.Tensor:
"""Samples the index of the minimum value in each segment."""
if randomize:
noise = torch.rand_like(scores)
scores = scores - torch.log(-torch.log(noise))
_, sampled_idx = scatter_min(scores, dst_idx, dim=0, dim_size=dst_size)
return sampled_idx
def segment_logsumexp(src, dst_idx, dst_size, extra_dims=None):
"""Computes the logsumexp of each segment in a tensor."""
src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
if extra_dims is not None:
src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
src = src - src_max[dst_idx]
out = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, torch.exp(src))
if extra_dims is not None:
out = torch.sum(out, dim=extra_dims)
return torch.log(out + 1e-8) + src_max.view(*out.shape)
def segment_softmax(src, dst_idx, dst_size, extra_dims=None, floor_value=None):
"""Computes the softmax of each segment in a tensor."""
src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
if extra_dims is not None:
src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
src = src - src_max[dst_idx]
exp1 = torch.exp(src)
exp0 = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, exp1)
if extra_dims is not None:
exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
exp = exp1.div(exp0 + 1e-8)
if floor_value is not None:
exp = exp.clamp(min=floor_value)
exp0 = torch.zeros(
dst_size,
*src.shape[1:],
dtype=src.dtype,
device=src.device,
).index_add_(0, dst_idx, exp)
if extra_dims is not None:
exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
exp = exp.div(exp0 + 1e-8)
return exp
def batched_sample_onehot(logits, dim=0, max_only=False):
"""Implements the Gumbel-Max trick to sample from a one-hot distribution."""
if max_only:
sampled_idx = torch.argmax(logits, dim=dim, keepdim=True)
else:
noise = torch.rand_like(logits)
sampled_idx = torch.argmax(logits - torch.log(-torch.log(noise)), dim=dim, keepdim=True)
out_onehot = torch.zeros_like(logits, dtype=torch.bool)
out_onehot.scatter_(dim=dim, index=sampled_idx, value=1)
return out_onehot
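A plain-Python sketch of the Gumbel-Max trick behind `batched_sample_onehot`: adding Gumbel noise `-log(-log(U))` to each logit and taking the argmax draws an index with probability `softmax(logits)`.

```python
import math
import random
from collections import Counter

# Empirically check that Gumbel-Max argmax samples follow softmax(logits).
random.seed(0)
logits = [math.log(0.7), math.log(0.2), math.log(0.1)]
counts = Counter(
    max(range(3), key=lambda i: logits[i] - math.log(-math.log(random.random())))
    for _ in range(20000)
)
freqs = [counts[i] / 20000 for i in range(3)]
print([round(f, 2) for f in freqs])  # approximately [0.7, 0.2, 0.1]
```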
def topk_edge_mask_from_logits(scores, k, randomize=False):
"""Samples the top-k edges from a set of logits."""
assert len(scores.shape) == 3, "Scores should have shape [B, N, N]"
if randomize:
noise = torch.rand_like(scores)
scores = scores - torch.log(-torch.log(noise))
node_degree = min(k, scores.shape[2])
_, topk_idx = torch.topk(scores, node_degree, dim=-1, largest=True)
edge_mask = scores.new_zeros(scores.shape, dtype=torch.bool)
edge_mask = edge_mask.scatter_(dim=2, index=topk_idx, value=1).bool()
return edge_mask
def sample_inplace_to_torch(sample):
"""Convert NumPy sample to PyTorch tensors."""
if sample is None:
return None
sample["features"] = {k: torch.FloatTensor(v) for k, v in sample["features"].items()}
sample["indexer"] = {k: torch.LongTensor(v) for k, v in sample["indexer"].items()}
if "labels" in sample.keys():
sample["labels"] = {k: torch.FloatTensor(v) for k, v in sample["labels"].items()}
return sample
def inplace_to_device(sample, device):
"""Move sample to device."""
sample["features"] = {k: v.to(device) for k, v in sample["features"].items()}
sample["indexer"] = {k: v.to(device) for k, v in sample["indexer"].items()}
if "labels" in sample.keys():
sample["labels"] = sample["labels"].to(device)
return sample
def inplace_to_torch(sample):
"""Convert NumPy sample to PyTorch tensors."""
if sample is None:
return None
sample["features"] = {k: torch.FloatTensor(v) for k, v in sample["features"].items()}
sample["indexer"] = {k: torch.LongTensor(v) for k, v in sample["indexer"].items()}
if "labels" in sample.keys():
sample["labels"] = {k: torch.FloatTensor(v) for k, v in sample["labels"].items()}
return sample
def distance_to_gaussian_contact_logits(
x: torch.Tensor, contact_scale: float, cutoff: Optional[float] = None
) -> torch.Tensor:
"""Convert distance to Gaussian contact logits.
:param x: Distance tensor.
:param contact_scale: The contact scale.
:param cutoff: The distance cutoff.
:return: Gaussian contact logits.
"""
if cutoff is None:
cutoff = contact_scale * 2
return torch.log(torch.clamp(1 - (x / cutoff), min=1e-9))
def distogram_to_gaussian_contact_logits(
dgram: torch.Tensor, dist_bins: torch.Tensor, contact_scale: float
) -> torch.Tensor:
"""Convert a distance histogram (distogram) matrix to a Gaussian contact map.
:param dgram: A distogram matrix.
:return: A Gaussian contact map.
"""
return torch.logsumexp(
dgram + distance_to_gaussian_contact_logits(dist_bins, contact_scale),
dim=-1,
)
def eval_true_contact_maps(
batch: MODEL_BATCH, contact_scale: float, **kwargs: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Evaluate true contact maps.
:param batch: A batch dictionary.
:param contact_scale: The contact scale.
:param kwargs: Additional keyword arguments.
    :return: Residue-ligand distance matrix and flattened Gaussian contact logits.
"""
indexer = batch["indexer"]
batch_size = batch["metadata"]["num_structid"]
with torch.no_grad():
# Residue centroids
res_cent_coords = (
batch["features"]["res_atom_positions"]
.mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
.sum(dim=1)
.div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
)
res_lig_dist = (
res_cent_coords.view(batch_size, -1, 3)[:, :, None]
- batch["features"]["sdf_coordinates"][indexer["gather_idx_U_u"]].view(
batch_size, -1, 3
)[:, None, :]
).norm(dim=-1)
res_lig_contact_logit = distance_to_gaussian_contact_logits(
res_lig_dist, contact_scale, **kwargs
)
return res_lig_dist, res_lig_contact_logit.flatten()
def sample_reslig_contact_matrix(
batch: MODEL_BATCH, res_lig_logits: torch.Tensor, last: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Sample residue-ligand contact matrix.
:param batch: A batch dictionary.
:param res_lig_logits: Residue-ligand contact logits.
:param last: The last contact matrix.
:return: Sampled residue-ligand contact matrix.
"""
metadata = batch["metadata"]
batch_size = metadata["num_structid"]
n_a_per_sample = max(metadata["num_a_per_sample"])
n_I_per_sample = max(metadata["num_I_per_sample"])
res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
# Sampling from unoccupied lattice sites
if last is None:
last = torch.zeros_like(res_lig_logits, dtype=torch.bool)
# Column-graph-wise masking for already sampled ligands
# sampled_ligand_mask = torch.amax(last, dim=1, keepdim=True)
sampled_frame_mask = torch.sum(last, dim=1, keepdim=True).contiguous()
masked_logits = res_lig_logits - sampled_frame_mask * 1e9
sampled_block_onehot = batched_sample_onehot(masked_logits.flatten(1, 2), dim=1).view(
batch_size, n_a_per_sample, n_I_per_sample
)
new_block_contact_mat = last + sampled_block_onehot
# Remove non-contact patches
valid_logit_mask = res_lig_logits > -16.0
new_block_contact_mat = (new_block_contact_mat * valid_logit_mask).bool()
return new_block_contact_mat
def merge_res_lig_logits_to_graph(
batch: MODEL_BATCH,
res_lig_logits: torch.Tensor,
single_protein_batch: bool,
) -> torch.Tensor:
"""Patch merging [B, N_res, N_atm] -> [B, N_res, N_graph].
:param batch: A batch dictionary.
:param res_lig_logits: Residue-ligand contact logits.
:param single_protein_batch: Whether to use single protein batch.
:return: Merged residue-ligand logits.
"""
assert single_protein_batch, "Only single protein batch is supported."
metadata = batch["metadata"]
indexer = batch["indexer"]
batch_size = metadata["num_structid"]
n_mol_per_sample = max(metadata["num_molid_per_sample"])
n_a_per_sample = max(metadata["num_a_per_sample"])
n_I_per_sample = max(metadata["num_I_per_sample"])
res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
graph_wise_logits = segment_logsumexp(
res_lig_logits.permute(2, 0, 1),
indexer["gather_idx_I_molid"][:n_I_per_sample],
n_mol_per_sample,
).permute(1, 2, 0)
return graph_wise_logits
def sample_res_rowmask_from_contacts(
batch: MODEL_BATCH,
res_ligatm_logits: torch.Tensor,
single_protein_batch: bool,
) -> torch.Tensor:
"""Sample residue row mask from contacts.
:param batch: A batch dictionary.
:param res_ligatm_logits: Residue-ligand atom contact logits.
:return: Sampled residue row mask.
"""
metadata = batch["metadata"]
lig_wise_logits = (
merge_res_lig_logits_to_graph(batch, res_ligatm_logits, single_protein_batch)
.permute(0, 2, 1)
.contiguous()
)
sampled_res_onehot_mask = batched_sample_onehot(lig_wise_logits.flatten(0, 1), dim=1)
return sampled_res_onehot_mask
def extract_esm_embeddings(
esm_model: torch.nn.Module,
esm_alphabet: torch.nn.Module,
esm_batch_converter: torch.nn.Module,
sequences: List[str],
device: Union[str, torch.device],
esm_repr_layer: int = 33,
) -> List[torch.Tensor]:
"""Extract embeddings from ESM model.
:param esm_model: ESM model.
:param esm_alphabet: ESM alphabet.
:param esm_batch_converter: ESM batch converter.
:param sequences: A list of sequences.
:param device: Device to use.
:param esm_repr_layer: ESM representation layer index from which to extract embeddings.
:return: A corresponding list of embeddings.
"""
# Disable dropout for deterministic results
esm_model.eval()
    # Prepare data as (label, sequence) pairs for the ESM batch converter
data = [(str(i), seq) for i, seq in enumerate(sequences)]
_, _, batch_tokens = esm_batch_converter(data)
batch_tokens = batch_tokens.to(device)
batch_lens = (batch_tokens != esm_alphabet.padding_idx).sum(1)
    # Extract per-residue representations
with torch.no_grad():
results = esm_model(batch_tokens, repr_layers=[esm_repr_layer], return_contacts=True)
token_representations = results["representations"][esm_repr_layer]
# Generate per-residue representations
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
sequence_representations.append(token_representations[i, 1 : tokens_len - 1])
return sequence_representations
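Several helpers in this module (`segment_sum`, `segment_mean`, the pooling layers) follow the same scatter-add pattern. A self-contained illustration: rows of `src` are scatter-added into `dst_size` buckets selected by `dst_idx`, and the mean divides by a scatter-added count (plus a small epsilon, as above).

```python
import torch

# Segment sum and mean via index_add_, the core pattern used above.
src = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
dst_idx = torch.tensor([0, 0, 1, 1])
dst_size = 2

seg_sum = torch.zeros(dst_size, 1).index_add_(0, dst_idx, src)
counts = torch.zeros(dst_size, 1).index_add_(0, dst_idx, torch.ones_like(src))
seg_mean = seg_sum / (counts + 1e-8)

assert torch.allclose(seg_sum, torch.tensor([[3.0], [7.0]]))
assert torch.allclose(seg_mean, torch.tensor([[1.5], [3.5]]), atol=1e-6)
```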


@@ -0,0 +1,51 @@
import logging
from beartype.typing import Mapping, Optional
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
class RankedLogger(logging.LoggerAdapter):
"""A multi-GPU-friendly python command line logger."""
def __init__(
self,
name: str = __name__,
rank_zero_only: bool = False,
extra: Optional[Mapping[str, object]] = None,
) -> None:
"""Initializes a multi-GPU-friendly python command line logger that logs on all processes
with their rank prefixed in the log message.
:param name: The name of the logger. Default is ``__name__``.
:param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
:param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
"""
logger = logging.getLogger(name)
super().__init__(logger=logger, extra=extra)
self.rank_zero_only = rank_zero_only
def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
"""Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `rank` is provided, then the log will only
occur on that rank/process.
:param level: The level to log at. Look at `logging.__init__.py` for more information.
:param msg: The message to log.
:param rank: The rank to log at.
:param args: Additional args to pass to the underlying logging function.
:param kwargs: Any additional keyword args to pass to the underlying logging function.
"""
if self.isEnabledFor(level):
msg, kwargs = self.process(msg, kwargs)
current_rank = getattr(rank_zero_only, "rank", None)
if current_rank is None:
raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
msg = rank_prefixed_message(msg, current_rank)
if self.rank_zero_only:
if current_rank == 0:
self.logger.log(level, msg, *args, **kwargs)
else:
if rank is None:
self.logger.log(level, msg, *args, **kwargs)
elif current_rank == rank:
self.logger.log(level, msg, *args, **kwargs)
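A stdlib-only sketch (not the FlowDock class) of the rank-gating idea above: a `LoggerAdapter` that prefixes the process rank and drops messages on non-zero ranks. The rank is read from a `RANK` environment variable here purely for illustration.

```python
import logging
import os

class RankZeroLogger(logging.LoggerAdapter):
    """Log only on rank 0, with the rank prefixed to each message."""

    def log(self, level, msg, *args, **kwargs):
        rank = int(os.environ.get("RANK", "0"))
        if rank != 0:
            return  # rank-zero-only behavior: silently drop on other ranks
        if self.isEnabledFor(level):
            self.logger.log(level, f"[rank: {rank}] {msg}", *args, **kwargs)

logging.basicConfig(level=logging.INFO)
log = RankZeroLogger(logging.getLogger("demo"), {})
log.log(logging.INFO, "hello from the trainer")  # emitted only on rank 0
```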


@@ -0,0 +1,102 @@
from pathlib import Path
import rich
import rich.syntax
import rich.tree
import rootutils
from beartype.typing import Sequence
from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils import pylogger
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
@rank_zero_only
def print_config_tree(
cfg: DictConfig,
print_order: Sequence[str] = (
"data",
"model",
"callbacks",
"logger",
"trainer",
"paths",
"extras",
),
resolve: bool = False,
save_to_file: bool = False,
) -> None:
"""Prints the contents of a DictConfig as a tree structure using the Rich library.
:param cfg: A DictConfig composed by Hydra.
:param print_order: Determines in what order config components are printed. Default is ``("data", "model",
"callbacks", "logger", "trainer", "paths", "extras")``.
:param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
:param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
queue = []
# add fields from `print_order` to queue
for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, style=style, guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
else:
branch_content = str(config_group)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
# print config tree
rich.print(tree)
# save config tree to file
if save_to_file:
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config.
:param cfg: A DictConfig composed by Hydra.
:param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t.strip()]
with open_dict(cfg):
cfg.tags = tags
log.info(f"Tags: {cfg.tags}")
if save_to_file:
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
rich.print(cfg.tags, file=file)


@@ -0,0 +1,427 @@
import os
import numpy as np
import pandas as pd
import rootutils
import torch
from beartype.typing import Any, Dict, List, Optional, Tuple
from lightning import LightningModule
from omegaconf import DictConfig
from rdkit import Chem
from rdkit.Chem import AllChem
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.data.components.mol_features import (
collate_numpy_samples,
process_mol_file,
)
from flowdock.utils import RankedLogger
from flowdock.utils.data_utils import (
FDProtein,
merge_protein_and_ligands,
pdb_filepath_to_protein,
prepare_batch,
process_protein,
)
from flowdock.utils.model_utils import inplace_to_device, inplace_to_torch, segment_mean
from flowdock.utils.visualization_utils import (
write_conformer_sdf,
write_pdb_models,
write_pdb_single,
)
log = RankedLogger(__name__, rank_zero_only=True)
def featurize_protein_and_ligands(
rec_path: str,
lig_paths: List[str],
n_lig_patches: int,
apo_rec_path: Optional[str] = None,
chain_id: Optional[str] = None,
protein: Optional[FDProtein] = None,
sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
enforce_sanitization: bool = False,
discard_sdf_coords: bool = False,
**kwargs: Dict[str, Any],
):
"""Featurize a protein-ligand complex.
:param rec_path: Path to the receptor file.
:param lig_paths: List of paths to the ligand files.
:param n_lig_patches: Number of ligand patches.
:param apo_rec_path: Path to the apo receptor file.
:param chain_id: Chain ID of the receptor.
:param protein: Optional protein object.
:param sequences_to_embeddings: Mapping of sequences to embeddings.
:param enforce_sanitization: Whether to enforce sanitization.
:param discard_sdf_coords: Whether to discard SDF coordinates.
:param kwargs: Additional keyword arguments.
:return: Featurized protein-ligand complex.
"""
assert rec_path is not None
if lig_paths is None:
lig_paths = []
if isinstance(lig_paths, str):
lig_paths = [lig_paths]
out_mol = None
lig_samples = []
for lig_path in lig_paths:
try:
lig_sample, mol_ref = process_mol_file(
lig_path,
sanitize=True,
return_mol=True,
discard_coords=discard_sdf_coords,
)
except Exception as e:
if enforce_sanitization:
raise
log.warning(
f"RDKit sanitization failed for ligand {lig_path} due to: {e}. Loading raw attributes."
)
lig_sample, mol_ref = process_mol_file(
lig_path,
sanitize=False,
return_mol=True,
discard_coords=discard_sdf_coords,
)
lig_samples.append(lig_sample)
if out_mol is None:
out_mol = mol_ref
else:
out_mol = AllChem.CombineMols(out_mol, mol_ref)
protein = protein if protein is not None else pdb_filepath_to_protein(rec_path)
rec_sample = process_protein(
protein,
chain_id=chain_id,
sequences_to_embeddings=None if apo_rec_path is not None else sequences_to_embeddings,
**kwargs,
)
if apo_rec_path is not None:
apo_protein = pdb_filepath_to_protein(apo_rec_path)
apo_rec_sample = process_protein(
apo_protein,
chain_id=chain_id,
sequences_to_embeddings=sequences_to_embeddings,
**kwargs,
)
for key in rec_sample.keys():
for subkey, value in apo_rec_sample[key].items():
rec_sample[key]["apo_" + subkey] = value
merged_sample = merge_protein_and_ligands(
lig_samples,
rec_sample,
n_lig_patches=n_lig_patches,
label=None,
)
return merged_sample, out_mol
def multi_pose_sampling(
receptor_path: str,
ligand_path: str,
cfg: DictConfig,
lit_module: LightningModule,
out_path: str,
save_pdb: bool = True,
separate_pdb: bool = True,
chain_id: Optional[str] = None,
apo_receptor_path: Optional[str] = None,
sample_id: Optional[str] = None,
protein: Optional[FDProtein] = None,
sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
confidence: bool = True,
affinity: bool = True,
return_all_states: bool = False,
auxiliary_estimation_only: bool = False,
**kwargs: Dict[str, Any],
) -> Tuple[
Optional[Chem.Mol],
Optional[List[float]],
Optional[List[float]],
Optional[List[float]],
Optional[List[Any]],
Optional[Any],
Optional[np.ndarray],
Optional[np.ndarray],
]:
"""Sample multiple poses of a protein-ligand complex.
:param receptor_path: Path to the receptor file.
:param ligand_path: Path to the ligand file.
:param cfg: Config dictionary.
:param lit_module: LightningModule instance.
:param out_path: Path to save the output files.
:param save_pdb: Whether to save PDB files.
:param separate_pdb: Whether to save separate PDB files for each pose.
:param chain_id: Chain ID of the receptor.
:param apo_receptor_path: Path to the optional apo receptor file.
:param sample_id: Optional sample ID.
:param protein: Optional protein object.
:param sequences_to_embeddings: Mapping of sequences to embeddings.
:param confidence: Whether to estimate confidence scores.
:param affinity: Whether to estimate affinity scores.
:param return_all_states: Whether to return all states.
:param auxiliary_estimation_only: Whether to only estimate auxiliary outputs (e.g., confidence,
affinity) for the input (generated) samples (potentially derived from external sources).
:param kwargs: Additional keyword arguments.
:return: Reference molecule, protein plDDTs, ligand plDDTs, ligand fragment plDDTs, estimated
binding affinities, structure trajectories, input batch, B-factors, and structure rankings.
"""
if return_all_states and auxiliary_estimation_only:
# NOTE: If auxiliary estimation is solely enabled, structure trajectory sampling will be disabled
return_all_states = False
struct_res_all, lig_res_all = [], []
plddt_all, plddt_lig_all, plddt_ligs_all, res_plddt_all = [], [], [], []
affinity_all, ligs_affinity_all = [], []
frames_all = []
chunk_size = cfg.chunk_size
for _ in range(cfg.n_samples // chunk_size):
# Resample anchor node frames
np_sample, mol = featurize_protein_and_ligands(
receptor_path,
ligand_path,
n_lig_patches=lit_module.hparams.cfg.mol_encoder.n_patches,
apo_rec_path=apo_receptor_path,
chain_id=chain_id,
protein=protein,
sequences_to_embeddings=sequences_to_embeddings,
discard_sdf_coords=cfg.discard_sdf_coords and not auxiliary_estimation_only,
**kwargs,
)
np_sample_batched = collate_numpy_samples([np_sample for _ in range(chunk_size)])
sample = inplace_to_device(inplace_to_torch(np_sample_batched), device=lit_module.device)
prepare_batch(sample)
if auxiliary_estimation_only:
# Predict auxiliary quantities using the provided input protein and ligand structures
if "num_molid" in sample["metadata"].keys() and sample["metadata"]["num_molid"] > 0:
sample["misc"]["protein_only"] = False
else:
sample["misc"]["protein_only"] = True
output_struct = {
"receptor": sample["features"]["res_atom_positions"].flatten(0, 1),
"receptor_padded": sample["features"]["res_atom_positions"],
"ligands": sample["features"]["sdf_coordinates"],
}
else:
output_struct = lit_module.net.sample_pl_complex_structures(
sample,
sampler=cfg.sampler,
sampler_eta=cfg.sampler_eta,
num_steps=cfg.num_steps,
return_all_states=return_all_states,
start_time=cfg.start_time,
exact_prior=cfg.exact_prior,
)
frames_all.append(output_struct.get("all_frames", None))
if mol is not None:
ref_mol = AllChem.Mol(mol)
out_x1 = np.split(output_struct["ligands"].cpu().numpy(), cfg.chunk_size)
out_x2 = np.split(output_struct["receptor_padded"].cpu().numpy(), cfg.chunk_size)
if confidence and affinity:
assert (
lit_module.net.confidence_cfg.enabled
), "Confidence estimation must be enabled in the model configuration."
assert (
lit_module.net.affinity_cfg.enabled
), "Affinity estimation must be enabled in the model configuration."
plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
sample,
output_struct,
return_avg_stats=True,
training=False,
)
aff = sample["outputs"]["affinity_logits"]
elif confidence:
assert (
lit_module.net.confidence_cfg.enabled
), "Confidence estimation must be enabled in the model configuration."
plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
sample,
output_struct,
return_avg_stats=True,
training=False,
)
elif affinity:
assert (
lit_module.net.affinity_cfg.enabled
), "Affinity estimation must be enabled in the model configuration."
lit_module.net.run_auxiliary_estimation(
sample, output_struct, return_avg_stats=True, training=False
)
plddt, plddt_lig = None, None
            aff = sample["outputs"]["affinity_logits"]  # keep on the same device as the indexer tensors below
mol_idx_i_structid = segment_mean(
sample["indexer"]["gather_idx_i_structid"],
sample["indexer"]["gather_idx_i_molid"],
sample["metadata"]["num_molid"],
).long()
for struct_idx in range(cfg.chunk_size):
struct_res = {
"features": {
"asym_id": np_sample["features"]["res_chain_id"],
"residue_index": np.arange(len(np_sample["features"]["res_type"])) + 1,
"aatype": np_sample["features"]["res_type"],
},
"structure_module": {
"final_atom_positions": out_x2[struct_idx],
"final_atom_mask": sample["features"]["res_atom_mask"].bool().cpu().numpy(),
},
}
struct_res_all.append(struct_res)
if mol is not None:
lig_res_all.append(out_x1[struct_idx])
if confidence:
plddt_all.append(plddt[struct_idx].item())
res_plddt_all.append(
sample["outputs"]["plddt"][
struct_idx, : sample["metadata"]["num_a_per_sample"][0]
]
.cpu()
.numpy()
)
if plddt_lig is None:
plddt_lig_all.append(None)
else:
plddt_lig_all.append(plddt_lig[struct_idx].item())
if plddt_ligs is None:
plddt_ligs_all.append(None)
else:
plddt_ligs_all.append(plddt_ligs[mol_idx_i_structid == struct_idx].tolist())
if affinity:
# collect the average affinity across all ligands in each complex
ligs_aff = aff[mol_idx_i_structid == struct_idx]
affinity_all.append(ligs_aff.mean().item())
ligs_affinity_all.append(ligs_aff.tolist())
if confidence and cfg.rank_outputs_by_confidence:
        plddt_lig_predicted = all(p is not None for p in plddt_lig_all)  # NOTE: checks for `None` explicitly, so a valid plDDT of `0.0` still counts as predicted
if cfg.plddt_ranking_type == "protein":
struct_plddts = np.array(plddt_all) # rank outputs using average protein plDDT
elif cfg.plddt_ranking_type == "ligand":
struct_plddts = np.array(
plddt_lig_all if plddt_lig_predicted else plddt_all
) # rank outputs using average ligand plDDT if available
if not plddt_lig_predicted:
log.warning(
"Ligand plDDT not available for all samples, using protein plDDT instead"
)
        elif cfg.plddt_ranking_type == "protein_ligand":
            struct_plddts = np.array(
                np.array(plddt_all) + np.array(plddt_lig_all)
                if plddt_lig_predicted
                else plddt_all
            )  # rank outputs using the elementwise sum of the average protein and ligand plDDTs if ligand plDDT is available
            if not plddt_lig_predicted:
                log.warning(
                    "Ligand plDDT not available for all samples, using protein plDDT instead"
                )
        else:
            raise ValueError(f"Unrecognized `cfg.plddt_ranking_type`: {cfg.plddt_ranking_type}")
struct_plddt_rankings = np.argsort(
-struct_plddts
).argsort() # ensure that higher plDDTs have a higher rank (e.g., `rank1`)
receptor_plddt = np.array(res_plddt_all) if confidence else None
b_factors = (
np.repeat(
receptor_plddt[..., None],
struct_res_all[0]["structure_module"]["final_atom_mask"].shape[-1],
axis=-1,
)
if confidence
else None
)
if save_pdb:
if separate_pdb:
for struct_id, struct_res in enumerate(struct_res_all):
if confidence and cfg.rank_outputs_by_confidence:
write_pdb_single(
struct_res,
out_path=os.path.join(
out_path,
f"prot_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
),
b_factors=b_factors[struct_id] if confidence else None,
)
else:
write_pdb_single(
struct_res,
out_path=os.path.join(
out_path,
f"prot_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
),
b_factors=b_factors[struct_id] if confidence else None,
)
write_pdb_models(
struct_res_all, out_path=os.path.join(out_path, "prot_all.pdb"), b_factors=b_factors
)
if mol is not None:
write_conformer_sdf(ref_mol, None, out_path=os.path.join(out_path, "lig_ref.sdf"))
lig_res_all = np.array(lig_res_all)
write_conformer_sdf(mol, lig_res_all, out_path=os.path.join(out_path, "lig_all.sdf"))
for struct_id in range(len(lig_res_all)):
if confidence and cfg.rank_outputs_by_confidence:
write_conformer_sdf(
mol,
lig_res_all[struct_id : struct_id + 1],
out_path=os.path.join(
out_path,
f"lig_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
),
)
else:
write_conformer_sdf(
mol,
lig_res_all[struct_id : struct_id + 1],
out_path=os.path.join(
out_path,
f"lig_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
),
)
if confidence:
            aux_estimation_all_df = pd.DataFrame(
                {
                    "sample_id": [sample_id] * len(struct_res_all),
                    "rank": struct_plddt_rankings + 1 if cfg.rank_outputs_by_confidence else None,
                    "plddt_ligs": plddt_ligs_all,
                    "affinity_ligs": ligs_affinity_all if affinity else None,
                }
            )
aux_estimation_all_df.to_csv(
os.path.join(out_path, f"{sample_id if sample_id is not None else 'sample'}_auxiliary_estimation.csv"), index=False
)
else:
ref_mol = None
if not confidence:
plddt_all, plddt_lig_all, plddt_ligs_all = None, None, None
if not affinity:
affinity_all = None
if return_all_states:
if mol is not None:
np_sample["metadata"]["sample_ID"] = sample_id if sample_id is not None else "sample"
np_sample["metadata"]["mol"] = ref_mol
batch_all = inplace_to_torch(
collate_numpy_samples([np_sample for _ in range(cfg.n_samples)])
)
merge_frames_all = frames_all[0]
for frames in frames_all[1:]:
for frame_index, frame in enumerate(frames):
for key in frame.keys():
merge_frames_all[frame_index][key] = torch.cat(
[merge_frames_all[frame_index][key], frame[key]], dim=0
)
frames_all = merge_frames_all
else:
frames_all = None
batch_all = None
if not (confidence and cfg.rank_outputs_by_confidence):
struct_plddt_rankings = None
return (
ref_mol,
plddt_all,
plddt_lig_all,
plddt_ligs_all,
affinity_all,
frames_all,
batch_all,
b_factors,
struct_plddt_rankings,
)
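The ranking step in `multi_pose_sampling` relies on the double-argsort idiom (`np.argsort(-struct_plddts).argsort()`) to turn confidence scores into zero-based ranks, with the highest score receiving rank 0. A minimal, self-contained sketch of that idiom (the score values here are made up):

```python
import numpy as np

# Hypothetical per-pose confidence scores, as produced by the ranking step above.
struct_plddts = np.array([0.62, 0.91, 0.47, 0.80])

# The first argsort on the negated scores orders pose indices from best to worst;
# the second argsort inverts that permutation, giving each pose its rank.
rankings = np.argsort(-struct_plddts).argsort()

print(rankings.tolist())  # [2, 0, 3, 1] -> pose 1 (plDDT 0.91) gets rank 0
```

Adding 1 to these ranks, as done when naming the output files, yields the familiar `rank1`, `rank2`, ... labels.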

153
flowdock/utils/utils.py Normal file
View File

@@ -0,0 +1,153 @@
import warnings
from importlib.util import find_spec
import rootutils
from beartype.typing import Any, Callable, Dict, List, Optional, Tuple
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.utils import pylogger, rich_utils
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
:param cfg: A DictConfig object containing the config tree.
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
rich_utils.enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
```
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
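The try/except/finally pattern above guarantees that cleanup (e.g., closing the wandb run) executes regardless of task outcome. A stripped-down, dependency-free sketch of the same control flow (all names here are hypothetical):

```python
from typing import Callable

closed = []  # records whether cleanup ran


def mini_task_wrapper(task_func: Callable) -> Callable:
    def wrap():
        try:
            return task_func()
        except Exception:
            # the real wrapper logs the exception before re-raising
            raise
        finally:
            closed.append(True)  # stands in for `wandb.finish()`

    return wrap


@mini_task_wrapper
def failing_task():
    raise RuntimeError("boom")


try:
    failing_task()
except RuntimeError:
    pass

print(closed)  # cleanup ran despite the exception
```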
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: If provided, the name of the metric to retrieve.
:return: If a metric name was provided, the value of the metric.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise Exception(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
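Because Lightning logs metrics as tensors, `get_metric_value` calls `.item()` on the retrieved entry. A dependency-free sketch of the same lookup logic, using a hypothetical stand-in for a scalar tensor:

```python
class FakeTensor:
    """Minimal stand-in for a torch scalar tensor."""

    def __init__(self, value: float):
        self.value = value

    def item(self) -> float:
        return self.value


def get_metric_value_sketch(metric_dict, metric_name):
    # mirrors the logic above: a None name yields None, a missing key raises,
    # and otherwise the stored tensor is unwrapped via `.item()`
    if not metric_name:
        return None
    if metric_name not in metric_dict:
        raise KeyError(f"Metric value not found! <metric_name={metric_name}>")
    return metric_dict[metric_name].item()


metrics = {"val/loss": FakeTensor(0.25)}
print(get_metric_value_sketch(metrics, "val/loss"))  # 0.25
print(get_metric_value_sketch(metrics, None))  # None
```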
def read_strings_from_txt(path: str) -> List[str]:
"""Reads strings from a text file and returns them as a list.
:param path: Path to the text file.
:return: List of strings.
"""
with open(path) as file:
# NOTE: every line will be one element of the returned list
lines = file.readlines()
return [line.rstrip() for line in lines]
def fasta_to_dict(filename: str) -> Dict[str, str]:
"""Converts a FASTA file to a dictionary where the keys are the sequence IDs and the values are
the sequences.
:param filename: Path to the FASTA file.
:return: Dictionary with sequence IDs as keys and sequences as values.
"""
fasta_dict = {}
with open(filename) as file:
for line in file:
line = line.rstrip() # remove trailing whitespace
if line.startswith(">"): # identifier line
seq_id = line[1:] # remove the '>' character
fasta_dict[seq_id] = ""
else: # sequence line
fasta_dict[seq_id] += line
return fasta_dict
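The parser above accumulates multi-line sequences under the most recent `>` header. A quick in-memory check of the same loop (no file I/O; the sequence data is made up):

```python
fasta_lines = [">chainA", "MKT", "AYIAK", ">chainB", "GGG"]

fasta_dict = {}
seq_id = None
for line in fasta_lines:
    line = line.rstrip()
    if line.startswith(">"):  # identifier line: start a new entry
        seq_id = line[1:]
        fasta_dict[seq_id] = ""
    else:  # sequence line: append to the current entry
        fasta_dict[seq_id] += line

print(fasta_dict)  # {'chainA': 'MKTAYIAK', 'chainB': 'GGG'}
```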

View File

@@ -0,0 +1,365 @@
import os
import numpy as np
import rootutils
from beartype import beartype
from beartype.typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from openfold.np.protein import Protein as OFProtein
from rdkit import Chem
from rdkit.Geometry.rdGeometry import Point3D
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from flowdock.data.components import residue_constants
from flowdock.utils.data_utils import (
PDB_CHAIN_IDS,
PDB_MAX_CHAINS,
FDProtein,
create_full_prot,
get_mol_with_new_conformer_coords,
)
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PROT_LIG_PAIRS = List[Tuple[OFProtein, Tuple[Chem.Mol, ...]]]
@beartype
def _chain_end(
atom_index: Union[int, np.int64],
end_resname: str,
chain_name: str,
residue_index: Union[int, np.int64],
) -> str:
"""Returns a PDB `TER` record for the end of a chain.
Adapted from: https://github.com/jasonkyuyim/se3_diffusion
:param atom_index: The index of the last atom in the chain.
:param end_resname: The residue name of the last residue in the chain.
:param chain_name: The chain name of the last residue in the chain.
:param residue_index: The residue index of the last residue in the chain.
:return: A PDB `TER` record.
"""
chain_end = "TER"
return (
        f"{chain_end:<6}{atom_index:>5}      {end_resname:>3} "
f"{chain_name:>1}{residue_index:>4}"
)
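Since PDB is a fixed-width format, the `TER` record built above must land the atom serial in columns 7-11 and the residue name in columns 18-20. A standalone sketch of that layout (the values are made up, and the format string is written with its padding spelled out):

```python
atom_index, end_resname, chain_name, residue_index = 1234, "ALA", "A", 56

ter_line = (
    f"{'TER':<6}{atom_index:>5}      {end_resname:>3} "
    f"{chain_name:>1}{residue_index:>4}"
)

# columns (1-indexed): record name 1-6, serial 7-11, resName 18-20,
# chainID 22, resSeq 23-26
print(repr(ter_line))
```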
@beartype
def res_1to3(restypes: List[str], r: Union[int, np.int64]) -> str:
"""Convert a residue type from 1-letter to 3-letter code.
:param restypes: List of residue types.
:param r: Residue type index.
:return: 3-letter code as a string.
"""
return residue_constants.restype_1to3.get(restypes[r], "UNK")
@beartype
def to_pdb(prot: Union[OFProtein, FDProtein], model: int = 1, add_end: bool = True, add_endmdl: bool = True) -> str:
"""Converts a `Protein` instance to a PDB string.
Adapted from: https://github.com/jasonkyuyim/se3_diffusion
:param prot: The protein to convert to PDB.
:param model: The model number to use.
:param add_end: Whether to add an `END` record.
:param add_endmdl: Whether to add an `ENDMDL` record.
:return: PDB string.
"""
restypes = residue_constants.restypes + ["X"]
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(int)
chain_index = prot.chain_index.astype(int)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# construct a mapping from chain integer indices to chain ID strings
chain_ids = {}
for i in np.unique(chain_index): # NOTE: `np.unique` gives sorted output
if i >= PDB_MAX_CHAINS:
raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
chain_ids[i] = PDB_CHAIN_IDS[i]
    pdb_lines.append(f"MODEL     {model}")
atom_index = 1
last_chain_index = chain_index[0]
# add all atom sites
for i in range(aatype.shape[0]):
# close the previous chain if in a multichain PDB
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(restypes, aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1],
)
)
last_chain_index = chain_index[i]
atom_index += 1 # NOTE: atom index increases at the `TER` symbol
res_name_3 = res_1to3(restypes, aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0] # NOTE: `Protein` supports only C, N, O, S, this works
charge = ""
# NOTE: PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
pdb_lines.append(atom_line)
atom_index += 1
# close the final chain
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(restypes, aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1],
)
)
if add_endmdl:
pdb_lines.append("ENDMDL")
if add_end:
pdb_lines.append("END")
# pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return "\n".join(pdb_lines) + "\n" # add terminating newline
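The `ATOM` records in `to_pdb` follow the same columnar discipline: coordinates occupy columns 31-54 and the element symbol columns 77-78, for 80 columns in total. A self-contained check with hypothetical values:

```python
record_type, atom_index, name, alt_loc = "ATOM", 7, " CA ", ""
res_name_3, chain_id, residue_index, insertion_code = "GLY", "A", 3, ""
pos, occupancy, b_factor, element, charge = (1.0, -2.5, 0.125), 1.00, 35.2, "C", ""

atom_line = (
    f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
    f"{res_name_3:>3} {chain_id:>1}"
    f"{residue_index:>4}{insertion_code:>1}   "
    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
    f"{element:>2}{charge:>2}"
)

print(len(atom_line))  # 80 columns, before any trailing padding
```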
@beartype
def construct_prot_lig_pairs(outputs: Dict[str, Any], batch_index: int) -> PROT_LIG_PAIRS:
"""Construct protein-ligand pairs from model outputs.
:param outputs: The model outputs.
:param batch_index: The index of the current batch.
:return: A list of protein-ligand object pairs.
"""
protein_batch_indexer = outputs["protein_batch_indexer"]
ligand_batch_indexer = outputs["ligand_batch_indexer"]
protein_all_atom_mask = outputs["res_atom_mask"][protein_batch_indexer == batch_index]
protein_all_atom_coordinates_mask = np.broadcast_to(
np.expand_dims(protein_all_atom_mask, -1), (protein_all_atom_mask.shape[0], 37, 3)
)
protein_aatype = outputs["aatype"][protein_batch_indexer == batch_index]
# assemble predicted structures
prot_lig_pairs = []
for protein_coordinates, ligand_coordinates in zip(
outputs["protein_coordinates_list"], outputs["ligand_coordinates_list"]
):
protein_all_atom_coordinates = (
protein_coordinates[protein_batch_indexer == batch_index]
* protein_all_atom_coordinates_mask
)
protein = create_full_prot(
protein_all_atom_coordinates,
protein_all_atom_mask,
protein_aatype,
b_factors=outputs["b_factors"][batch_index] if "b_factors" in outputs else None,
)
ligand = get_mol_with_new_conformer_coords(
outputs["ligand_mol"][batch_index],
ligand_coordinates[ligand_batch_indexer == batch_index],
)
ligands = tuple(Chem.GetMolFrags(ligand, asMols=True, sanitizeFrags=False))
prot_lig_pairs.append((protein, ligands))
# assemble ground-truth structures
if "gt_protein_coordinates" in outputs and "gt_ligand_coordinates" in outputs:
protein_gt_all_atom_coordinates = (
outputs["gt_protein_coordinates"][protein_batch_indexer == batch_index]
* protein_all_atom_coordinates_mask
)
gt_protein = create_full_prot(
protein_gt_all_atom_coordinates,
protein_all_atom_mask,
protein_aatype,
)
gt_ligand = get_mol_with_new_conformer_coords(
outputs["ligand_mol"][batch_index],
outputs["gt_ligand_coordinates"][ligand_batch_indexer == batch_index],
)
gt_ligands = tuple(Chem.GetMolFrags(gt_ligand, asMols=True, sanitizeFrags=False))
prot_lig_pairs.append((gt_protein, gt_ligands))
return prot_lig_pairs
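`construct_prot_lig_pairs` expands the per-atom mask from shape `(n_res, 37)` to `(n_res, 37, 3)` so that masking zeroes out all three coordinates of an absent atom at once. The same broadcasting pattern in isolation (the mask contents here are made up):

```python
import numpy as np

n_res = 2
atom_mask = np.zeros((n_res, 37))
atom_mask[0, :4] = 1.0  # pretend only the four backbone atoms of residue 0 are present

# expand (n_res, 37) -> (n_res, 37, 1) -> broadcast to (n_res, 37, 3)
coord_mask = np.broadcast_to(np.expand_dims(atom_mask, -1), (n_res, 37, 3))

coords = np.ones((n_res, 37, 3))
masked = coords * coord_mask  # coordinates of absent atoms are zeroed

print(coord_mask.shape, masked.sum())  # (2, 37, 3) 12.0
```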
@beartype
def write_prot_lig_pairs_to_pdb_file(prot_lig_pairs: PROT_LIG_PAIRS, output_filepath: str):
"""Write a list of protein-ligand pairs to a PDB file.
:param prot_lig_pairs: List of protein-ligand object pairs, where each ligand may consist of
multiple ligand chains.
:param output_filepath: Output file path.
"""
os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
with open(output_filepath, "w") as f:
model_id = 1
for prot, lig_mols in prot_lig_pairs:
pdb_prot = to_pdb(prot, model=model_id, add_end=False, add_endmdl=False)
f.write(pdb_prot)
for lig_mol in lig_mols:
f.write(
Chem.MolToPDBBlock(lig_mol).replace(
"END\n", "TER\n"
) # enable proper ligand chain separation
)
f.write("END\n")
f.write("ENDMDL\n") # add `ENDMDL` line to separate models
model_id += 1
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = False,
) -> FDProtein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values.
Returns:
A protein instance.
"""
fold_output = result["structure_module"]
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if "asym_id" in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
else:
chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))
if b_factors is None:
b_factors = np.zeros_like(fold_output["final_atom_mask"])
return FDProtein(
letter_sequences=None,
aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=fold_output["final_atom_positions"],
atom_mask=fold_output["final_atom_mask"],
residue_index=_maybe_remove_leading_dim(features["residue_index"]),
chain_index=chain_index,
b_factors=b_factors,
atomtypes=None,
)
def write_pdb_single(
result: ModelOutput,
out_path: str = os.path.join("test_results", "debug.pdb"),
model: int = 1,
b_factors: Optional[np.ndarray] = None,
):
"""Write a single model to a PDB file.
:param result: Model results batch.
:param out_path: Output path.
:param model: Model ID.
:param b_factors: Optional B-factors.
"""
os.makedirs(os.path.dirname(out_path), exist_ok=True)
protein = from_prediction(result["features"], result, b_factors=b_factors)
out_string = to_pdb(protein, model=model)
with open(out_path, "w") as of:
of.write(out_string)
def write_pdb_models(
results,
out_path: str = os.path.join("test_results", "debug.pdb"),
b_factors: Optional[np.ndarray] = None,
):
"""Write multiple models to a PDB file.
:param results: Model results.
:param out_path: Output path.
:param b_factors: Optional B-factors.
"""
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w") as of:
for mid, result in enumerate(results):
protein = from_prediction(
result["features"],
result,
b_factors=b_factors[mid] if b_factors is not None else None,
)
out_string = to_pdb(protein, model=mid + 1)
of.write(out_string)
of.write("END")
def write_conformer_sdf(
mol: Chem.Mol,
    confs: Optional[np.ndarray] = None,
out_path: str = os.path.join("test_results", "debug.sdf"),
):
"""Write a molecule with conformers to an SDF file.
:param mol: RDKit molecule.
:param confs: Conformers.
:param out_path: Output path.
"""
os.makedirs(os.path.dirname(out_path), exist_ok=True)
if confs is None:
w = Chem.SDWriter(out_path)
w.write(mol)
w.close()
return 0
mol.RemoveAllConformers()
for i in range(len(confs)):
conf = Chem.Conformer(mol.GetNumAtoms())
for j in range(mol.GetNumAtoms()):
x, y, z = confs[i, j].tolist()
conf.SetAtomPosition(j, Point3D(x, y, z))
mol.AddConformer(conf, assignId=True)
w = Chem.SDWriter(out_path)
try:
for cid in range(len(confs)):
w.write(mol, confId=cid)
    except Exception:
        # fall back to writing without kekulization; restart the file so that
        # conformers already written before the failure are not duplicated
        w.close()
        w = Chem.SDWriter(out_path)
        w.SetKekulize(False)
        for cid in range(len(confs)):
            w.write(mol, confId=cid)
w.close()
return 0

Some files were not shown because too many files have changed in this diff