commit a3ffec6a07c252ec2cae0fca88497d98d7dfbdef Author: Olamide Isreal Date: Mon Mar 16 15:23:29 2026 +0100 Initial commit: FlowDock pipeline configured for WES execution diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..3a206c5 --- /dev/null +++ b/.env.example @@ -0,0 +1,6 @@ +# example of file for storing private and user specific environment variables, like keys or system paths +# rename it to ".env" (excluded from version control by default) +# .env is loaded by train.py automatically +# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} + +PLINDER_MOUNT="$(pwd)/data/PLINDER" diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..410bcd8 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## What does this PR do? + + + +Fixes #\ + +## Before submitting + +- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? +- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? +- [ ] Did you list all the **breaking changes** introduced by this pull request? +- [ ] Did you **test your PR locally** with `pytest` command? +- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? + +## Did you have fun? 
+ +Make sure you had fun coding 🙃 diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 0000000..c66853c --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + # measures overall project coverage + project: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + # measures PR or single commit coverage + patch: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + + # project: off + # patch: off diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5a861fd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" + ignore: + - dependency-name: "pytorch-lightning" + update-types: ["version-update:semver-patch"] + - dependency-name: "torchmetrics" + update-types: ["version-update:semver-patch"] diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000..59af159 --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,44 @@ +name-template: "v$RESOLVED_VERSION" +tag-template: "v$RESOLVED_VERSION" + +categories: + - title: "🚀 Features" + labels: + - "feature" + - "enhancement" + - title: "🐛 Bug Fixes" + labels: + - "fix" + - "bugfix" + - "bug" + - title: "🧹 Maintenance" + labels: + - "maintenance" + - "dependencies" + - "refactoring" + - "cosmetic" + - "chore" + - title: "📝️ Documentation" + labels: + - "documentation" + - "docs" + +change-template: 
"- $TITLE @$AUTHOR (#$NUMBER)" +change-title-escapes: '\<*_&' # You can add # and @ to disable mentions + +version-resolver: + major: + labels: + - "major" + minor: + labels: + - "minor" + patch: + labels: + - "patch" + default: patch + +template: | + ## Changes + + $CHANGES diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml new file mode 100644 index 0000000..88b7220 --- /dev/null +++ b/.github/workflows/code-quality-main.yaml @@ -0,0 +1,22 @@ +# Same as `code-quality-pr.yaml` but triggered on commit to main branch +# and runs on all files (instead of only the changed ones) + +name: Code Quality Main + +on: + push: + branches: [main] + +jobs: + code-quality: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Run pre-commits + uses: pre-commit/action@v2.0.3 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml new file mode 100644 index 0000000..e58df42 --- /dev/null +++ b/.github/workflows/code-quality-pr.yaml @@ -0,0 +1,36 @@ +# This workflow finds which files were changed, prints them, +# and runs `pre-commit` on those files. 
+ +# Inspired by the sktime library: +# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml + +name: Code Quality PR + +on: + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + code-quality: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Find modified files + id: file_changes + uses: trilom/file-changes-action@v1.2.4 + with: + output: " " + + - name: List modified files + run: echo '${{ steps.file_changes.outputs.files}}' + + - name: Run pre-commits + uses: pre-commit/action@v2.0.3 + with: + extra_args: --files ${{ steps.file_changes.outputs.files}} diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml new file mode 100644 index 0000000..6a45e15 --- /dev/null +++ b/.github/workflows/release-drafter.yml @@ -0,0 +1,27 @@ +name: Release Drafter + +on: + push: + # branches to consider in the event; optional, defaults to all + branches: + - main + +permissions: + contents: read + +jobs: + update_release_draft: + permissions: + # write permission is required to create a github release + contents: write + # write permission is required for autolabeler + # otherwise, read permission is required at least + pull-requests: write + + runs-on: ubuntu-latest + + steps: + # Drafts your next Release notes as Pull Requests are merged into "master" + - uses: release-drafter/release-drafter@v5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..ca09def --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,139 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + run_tests_ubuntu: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + python-version: ["3.8", "3.9", "3.10"] + + 
timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + conda env create -f environment.yml + python -m pip install --upgrade pip + pip install pytest + pip install sh + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + run_tests_macos: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ["macos-latest"] + python-version: ["3.8", "3.9", "3.10"] + + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + conda env create -f environment.yml + python -m pip install --upgrade pip + pip install pytest + pip install sh + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + run_tests_windows: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ["windows-latest"] + python-version: ["3.8", "3.9", "3.10"] + + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + conda env create -f environment.yml + python -m pip install --upgrade pip + pip install pytest + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + # upload code coverage report + code-coverage: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: "3.10" + + - name: Install 
dependencies + run: | + conda env create -f environment.yml + python -m pip install --upgrade pip + pip install pytest + pip install pytest-cov[toml] + pip install sh + + - name: Run tests and collect coverage + run: pytest --cov flowdock # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3a6e528 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +work/ +.nextflow/ +.nextflow.log* +*.log.* +results/ +__pycache__/ +*.pyc +.vscode/ +.idea/ +*.tmp +*.swp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3de51e1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,150 @@ +default_language_version: + python: python3 + +exclude: "^forks/" + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-executables-have-shebangs + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + args: ["--maxkb=20000"] + + # python code formatting + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + args: [--line-length, "99"] + + # python import sorting + - repo: https://github.com/PyCQA/isort + rev: 6.0.1 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + # python upgrading syntax to newer version + - repo: https://github.com/asottile/pyupgrade + rev: v3.19.1 + hooks: + - id: pyupgrade + args: [--py38-plus] + + # python docstring formatting + - repo: https://github.com/myint/docformatter + rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 # Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed + hooks: + - id: docformatter + args: + 
[ + --in-place, + --wrap-summaries=99, + --wrap-descriptions=99, + --style=sphinx, + --black, + ] + + # python docstring coverage checking + - repo: https://github.com/econchick/interrogate + rev: 1.7.0 # or master if you're bold + hooks: + - id: interrogate + args: + [ + --verbose, + --fail-under=80, + --ignore-init-module, + --ignore-init-method, + --ignore-module, + --ignore-nested-functions, + -vv, + ] + + # python check (PEP8), programming errors and code complexity + - repo: https://github.com/PyCQA/flake8 + rev: 7.1.2 + hooks: + - id: flake8 + args: + [ + "--extend-ignore", + "E203,E402,E501,F401,F841,RST2,RST301", + "--exclude", + "logs/*,data/*", + ] + additional_dependencies: [flake8-rst-docstrings==0.3.0] + + # python security linter + - repo: https://github.com/PyCQA/bandit + rev: "1.8.3" + hooks: + - id: bandit + args: ["-s", "B101"] + + # yaml formatting + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + types: [yaml] + exclude: "environment.yaml" + + # shell scripts linter + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.10.0.1 + hooks: + - id: shellcheck + + # md formatting + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.22 + hooks: + - id: mdformat + args: ["--number"] + additional_dependencies: + - mdformat-gfm + - mdformat-tables + - mdformat_frontmatter + # - mdformat-toc + # - mdformat-black + + # word spelling linter + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 + hooks: + - id: codespell + args: + - --skip=logs/**,data/**,*.ipynb,flowdock/data/components/constants.py,flowdock/data/components/process_mols.py,flowdock/data/components/residue_constants.py,flowdock/data/components/uff_parameters.csv,flowdock/data/components/chemical/*,flowdock/utils/data_utils.py + # - --ignore-words-list=abc,def + + # jupyter notebook cell output clearing + - repo: https://github.com/kynan/nbstripout + rev: 0.8.1 + hooks: + - id: nbstripout + + # 
jupyter notebook linting + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.9.1 + hooks: + - id: nbqa-black + args: ["--line-length=99"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/.project-root b/.project-root new file mode 100644 index 0000000..63eab77 --- /dev/null +++ b/.project-root @@ -0,0 +1,2 @@ +# this file is required for inferring the project root directory +# do not delete diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..166c512 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,49 @@ +FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-runtime + +LABEL authors="BioinfoMachineLearning" + +# Install system requirements +RUN apt-get update && \ + apt-get install -y --reinstall ca-certificates && \ + apt-get install -y --no-install-recommends \ + git \ + wget \ + libxml2 \ + libgl-dev \ + libgl1 \ + gcc \ + g++ \ + procps && \ + rm -rf /var/lib/apt/lists/* + +# Set working directory +RUN mkdir -p /software/flowdock +WORKDIR /software/flowdock + +# Clone FlowDock repository +RUN git clone https://github.com/BioinfoMachineLearning/FlowDock /software/flowdock + +# Create conda environment +RUN conda env create -f /software/flowdock/environments/flowdock_environment.yaml + +# Install local package and ProDy +RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \ + conda activate FlowDock && \ + pip install --no-cache-dir -e /software/flowdock && \ + pip install --no-cache-dir --no-dependencies prody==2.4.1" + +# Create checkpoints directory +RUN mkdir -p /software/flowdock/checkpoints + +# Download pretrained weights +RUN wget -q https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz && \ + tar -xzf flowdock_checkpoints.tar.gz && \ + rm flowdock_checkpoints.tar.gz + +# Activate conda environment by default +RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate 
FlowDock" >> ~/.bashrc + +# Default shell +SHELL ["/bin/bash", "-l", "-c"] +CMD ["/bin/bash"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..42402d0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 BioinfoMachineLearning + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4b3d93a --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ + +help: ## Show help + @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +clean: ## Clean autogenerated files + rm -rf dist + find . -type f -name "*.DS_Store" -ls -delete + find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf + find . | grep -E ".pytest_cache" | xargs rm -rf + find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf + rm -f .coverage + +clean-logs: ## Clean logs + rm -rf logs/** + +format: ## Run pre-commit hooks + pre-commit run -a + +sync: ## Merge changes from main branch to your current branch + git pull + git pull origin main + +test: ## Run not slow tests + pytest -k "not slow" + +test-full: ## Run all tests + pytest + +train: ## Train the model + python flowdock/train.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..621da05 --- /dev/null +++ b/README.md @@ -0,0 +1,471 @@ +
+ +# FlowDock + +PyTorch +Lightning +Config: Hydra + + + +[![Paper](http://img.shields.io/badge/paper-arxiv.2412.10966-B31B1B.svg)](https://arxiv.org/abs/2412.10966) +[![Conference](http://img.shields.io/badge/ISMB-2025-4b44ce.svg)](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366) +[![Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15066450.svg)](https://doi.org/10.5281/zenodo.15066450) + + + +
+ +## Description + +This is the official codebase of the paper + +**FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction** + +\[[arXiv](https://arxiv.org/abs/2412.10966)\] \[[ISMB](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)\] \[[Neurosnap](https://neurosnap.ai/service/FlowDock)\] \[[Tamarind Bio](https://app.tamarind.bio/tools/flowdock)\] + +
+ +![Animation of a flow model-predicted 3D protein-ligand complex structure visualized successively](img/6I67.gif) +![Animation of a flow model-predicted 3D protein-multi-ligand complex structure visualized successively](img/T1152.gif) + +
+ +## Contents + +- [FlowDock](#flowdock) + - [Description](#description) + - [Contents](#contents) + - [Installation](#installation) + - [How to prepare data for `FlowDock`](#how-to-prepare-data-for-flowdock) + - [Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)](#generating-esm2-embeddings-for-each-protein-optional-cached-input-data-available-on-sharepoint) + - [Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)](#predicting-apo-protein-structures-using-esmfold-optional-cached-data-available-on-zenodo) + - [How to train `FlowDock`](#how-to-train-flowdock) + - [How to evaluate `FlowDock`](#how-to-evaluate-flowdock) + - [How to create comparative plots of benchmarking results](#how-to-create-comparative-plots-of-benchmarking-results) + - [How to predict new protein-ligand complex structures and their affinities using `FlowDock`](#how-to-predict-new-protein-ligand-complex-structures-and-their-affinities-using-flowdock) + - [For developers](#for-developers) + - [Docker](#docker) + - [Acknowledgements](#acknowledgements) + - [License](#license) + - [Citing this work](#citing-this-work) + +## Installation + +
+ +Install Mamba + +```bash +wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" +bash Miniforge3-$(uname)-$(uname -m).sh # accept all terms and install to the default location +rm Miniforge3-$(uname)-$(uname -m).sh # (optionally) remove installer after using it +source ~/.bashrc # alternatively, one can restart their shell session to achieve the same result +``` + +Install dependencies + +```bash +# clone project +git clone https://github.com/BioinfoMachineLearning/FlowDock +cd FlowDock + +# create conda environment +mamba env create -f environments/flowdock_environment.yaml +conda activate FlowDock # NOTE: one still needs to use `conda` to (de)activate environments +pip3 install -e . # install local project as package +pip3 install prody==2.4.1 --no-dependencies # install ProDy without NumPy dependency +``` + +Download checkpoints + +```bash +# pretrained NeuralPLexer weights +cd checkpoints/ +wget https://zenodo.org/records/10373581/files/neuralplexermodels_downstream_datasets_predictions.zip +unzip neuralplexermodels_downstream_datasets_predictions.zip +rm neuralplexermodels_downstream_datasets_predictions.zip +cd ../ +``` + +```bash +# pretrained FlowDock weights +wget https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz +tar -xzf flowdock_checkpoints.tar.gz +rm flowdock_checkpoints.tar.gz +``` + +Download preprocessed datasets + +```bash +# cached input data for training/validation/testing +wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ER1hctIBhDVFjM7YepOI6WcBXNBm4_e6EBjFEHAM1A3y5g?download=1" +tar -xzf flowdock_data_cache.tar.gz +rm flowdock_data_cache.tar.gz + +# cached data for PDBBind, Binding MOAD, DockGen, and the PDB-based van der Mers (vdM) dataset +wget https://zenodo.org/records/15066450/files/flowdock_pdbbind_data.tar.gz +tar -xzf flowdock_pdbbind_data.tar.gz +rm flowdock_pdbbind_data.tar.gz + +wget 
https://zenodo.org/records/15066450/files/flowdock_moad_data.tar.gz +tar -xzf flowdock_moad_data.tar.gz +rm flowdock_moad_data.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_dockgen_data.tar.gz +tar -xzf flowdock_dockgen_data.tar.gz +rm flowdock_dockgen_data.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_pdbsidechain_data.tar.gz +tar -xzf flowdock_pdbsidechain_data.tar.gz +rm flowdock_pdbsidechain_data.tar.gz +``` + +
+ +## How to prepare data for `FlowDock` + +
+ +**NOTE:** The following steps (besides downloading PDBBind and Binding MOAD's PDB files) are only necessary if one wants to fully process each of the following datasets manually. +Otherwise, preprocessed versions of each dataset can be found on [Zenodo](https://zenodo.org/records/15066450). + +Download data + +```bash +# fetch preprocessed PDBBind and Binding MOAD (as well as the optional DockGen and vdM datasets) +cd data/ + +wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1" +wget https://zenodo.org/records/10656052/files/BindingMOAD_2020_processed.tar +wget https://zenodo.org/records/10656052/files/DockGen.tar +wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz + +mv EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1 PDBBind.tar.gz + +tar -xzf PDBBind.tar.gz +tar -xf BindingMOAD_2020_processed.tar +tar -xf DockGen.tar +tar -xzf pdb_2021aug02.tar.gz + +rm PDBBind.tar.gz BindingMOAD_2020_processed.tar DockGen.tar pdb_2021aug02.tar.gz + +mkdir pdbbind/ moad/ pdbsidechain/ +mv PDBBind_processed/ pdbbind/ +mv BindingMOAD_2020_processed/ moad/ +mv pdb_2021aug02/ pdbsidechain/ + +cd ../ +``` + +Lastly, to finetune `FlowDock` using the `PLINDER` dataset, one must first prepare this data for training + +```bash +# fetch PLINDER data (NOTE: requires ~1 hour to download and ~750G of storage) +export PLINDER_MOUNT="$(pwd)/data/PLINDER" +mkdir -p "$PLINDER_MOUNT" # create the directory if it doesn't exist + +plinder_download -y +``` + +### Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint) + +To generate the ESM2 embeddings for the protein inputs, +first create all the corresponding FASTA files for each protein sequence + +```bash +python flowdock/data/components/esm_embedding_preparation.py --dataset pdbbind --data_dir data/pdbbind/PDBBind_processed/ --out_file data/pdbbind/pdbbind_sequences.fasta +python 
flowdock/data/components/esm_embedding_preparation.py --dataset moad --data_dir data/moad/BindingMOAD_2020_processed/pdb_protein/ --out_file data/moad/moad_sequences.fasta +python flowdock/data/components/esm_embedding_preparation.py --dataset dockgen --data_dir data/DockGen/processed_files/ --out_file data/DockGen/dockgen_sequences.fasta +python flowdock/data/components/esm_embedding_preparation.py --dataset pdbsidechain --data_dir data/pdbsidechain/pdb_2021aug02/pdb/ --out_file data/pdbsidechain/pdbsidechain_sequences.fasta +``` + +Then, generate all ESM2 embeddings in batch using the ESM repository's helper script + +```bash +python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbbind/pdbbind_sequences.fasta data/pdbbind/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0 +python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/moad/moad_sequences.fasta data/moad/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0 +python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/DockGen/dockgen_sequences.fasta data/DockGen/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0 +python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbsidechain/pdbsidechain_sequences.fasta data/pdbsidechain/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0 +``` + +### Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo) + +To generate the apo version of each protein structure, +first create ESMFold-ready versions of the combined FASTA files +prepared above by the script `esm_embedding_preparation.py` +for the PDBBind, Binding MOAD, DockGen, and PDBSidechain datasets, respectively + +```bash +python 
flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbbind +python flowdock/data/components/esmfold_sequence_preparation.py dataset=moad +python flowdock/data/components/esmfold_sequence_preparation.py dataset=dockgen +python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbsidechain +``` + +Then, predict each apo protein structure using ESMFold's batch +inference script + +```bash +# Note: Having a CUDA-enabled device available when running this script is highly recommended +python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbbind/pdbbind_esmfold_sequences.fasta -o data/pdbbind/pdbbind_esmfold_structures --cuda-device-index 0 --skip-existing +python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/moad/moad_esmfold_sequences.fasta -o data/moad/moad_esmfold_structures --cuda-device-index 0 --skip-existing +python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/DockGen/dockgen_esmfold_sequences.fasta -o data/DockGen/dockgen_esmfold_structures --cuda-device-index 0 --skip-existing +python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbsidechain/pdbsidechain_esmfold_sequences.fasta -o data/pdbsidechain/pdbsidechain_esmfold_structures --cuda-device-index 0 --skip-existing +``` + +Align each apo protein structure to its corresponding +holo protein structure counterpart in PDBBind, Binding MOAD, and PDBSidechain, +taking ligand conformations into account during each alignment + +```bash +python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbbind num_workers=1 +python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=moad num_workers=1 +python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=dockgen num_workers=1 +python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbsidechain num_workers=1 +``` + +Lastly, assess the apo-to-holo alignments in terms of statistics 
and structural metrics +to enable runtime-dynamic dataset filtering using such information + +```bash +python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbbind usalign_exec_path=$MY_USALIGN_EXEC_PATH +python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=moad usalign_exec_path=$MY_USALIGN_EXEC_PATH +python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=dockgen usalign_exec_path=$MY_USALIGN_EXEC_PATH +python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbsidechain usalign_exec_path=$MY_USALIGN_EXEC_PATH +``` + +
+ +## How to train `FlowDock` + +
+ +Train model with default configuration + +```bash +# train on CPU +python flowdock/train.py trainer=cpu + +# train on GPU +python flowdock/train.py trainer=gpu +``` + +Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/) + +```bash +python flowdock/train.py experiment=experiment_name.yaml +``` + +For example, reproduce `FlowDock`'s default model training run + +```bash +python flowdock/train.py experiment=flowdock_fm +``` + +**Note:** You can override any parameter from command line like this + +```bash +python flowdock/train.py experiment=flowdock_fm trainer.max_epochs=20 data.batch_size=8 +``` + +For example, override parameters to finetune `FlowDock`'s pretrained weights using a new dataset such as [PLINDER](https://www.plinder.sh/) + +```bash +python flowdock/train.py experiment=flowdock_fm data=plinder ckpt_path=checkpoints/esmfold_prior_paper_weights.ckpt +``` + +
+ +## How to evaluate `FlowDock` + +
+ +To reproduce `FlowDock`'s evaluation results for structure prediction, please refer to its documentation in version `0.6.0-FlowDock` of the [PoseBench](https://github.com/BioinfoMachineLearning/PoseBench/tree/0.6.0-FlowDock?tab=readme-ov-file#how-to-run-inference-with-flowdock) GitHub repository. + +To reproduce `FlowDock`'s evaluation results for binding affinity prediction using the PDBBind dataset + +```bash +python flowdock/eval.py data.test_datasets=[pdbbind] ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt trainer=gpu +... # re-run two more times to gather triplicate results +``` + +
+ +## How to create comparative plots of benchmarking results + +
+ +Download baseline method predictions and results + +```bash +# cached predictions and evaluation metrics for reproducing structure prediction paper results +wget https://zenodo.org/records/15066450/files/alphafold3_baseline_method_predictions.tar.gz +tar -xzf alphafold3_baseline_method_predictions.tar.gz +rm alphafold3_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/chai_baseline_method_predictions.tar.gz +tar -xzf chai_baseline_method_predictions.tar.gz +rm chai_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/diffdock_baseline_method_predictions.tar.gz +tar -xzf diffdock_baseline_method_predictions.tar.gz +rm diffdock_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/dynamicbind_baseline_method_predictions.tar.gz +tar -xzf dynamicbind_baseline_method_predictions.tar.gz +rm dynamicbind_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_baseline_method_predictions.tar.gz +tar -xzf flowdock_baseline_method_predictions.tar.gz +rm flowdock_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_aft_baseline_method_predictions.tar.gz +tar -xzf flowdock_aft_baseline_method_predictions.tar.gz +rm flowdock_aft_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_pft_baseline_method_predictions.tar.gz +tar -xzf flowdock_pft_baseline_method_predictions.tar.gz +rm flowdock_pft_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_esmfold_baseline_method_predictions.tar.gz +tar -xzf flowdock_esmfold_baseline_method_predictions.tar.gz +rm flowdock_esmfold_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/flowdock_chai_baseline_method_predictions.tar.gz +tar -xzf flowdock_chai_baseline_method_predictions.tar.gz +rm flowdock_chai_baseline_method_predictions.tar.gz + +wget 
https://zenodo.org/records/15066450/files/flowdock_hp_baseline_method_predictions.tar.gz +tar -xzf flowdock_hp_baseline_method_predictions.tar.gz +rm flowdock_hp_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/neuralplexer_baseline_method_predictions.tar.gz +tar -xzf neuralplexer_baseline_method_predictions.tar.gz +rm neuralplexer_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/vina_p2rank_baseline_method_predictions.tar.gz +tar -xzf vina_p2rank_baseline_method_predictions.tar.gz +rm vina_p2rank_baseline_method_predictions.tar.gz + +wget https://zenodo.org/records/15066450/files/rfaa_baseline_method_predictions.tar.gz +tar -xzf rfaa_baseline_method_predictions.tar.gz +rm rfaa_baseline_method_predictions.tar.gz +``` + +Reproduce paper result figures + +```bash +jupyter notebook notebooks/casp16_binding_affinity_prediction_results_plotting.ipynb +jupyter notebook notebooks/casp16_flowdock_vs_multicom_ligand_structure_prediction_results_plotting.ipynb +jupyter notebook notebooks/dockgen_structure_prediction_results_plotting.ipynb +jupyter notebook notebooks/posebusters_benchmark_structure_prediction_chemical_similarity_analysis.ipynb +jupyter notebook notebooks/posebusters_benchmark_structure_prediction_results_plotting.ipynb +``` + +
+ +## How to predict new protein-ligand complex structures and their affinities using `FlowDock` + +
+ +For example, generate new protein-ligand complexes for a pair of protein sequence and ligand SMILES strings such as those of the PDBBind 2020 test target `6i67` + +```bash +python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' input_ligand='"c1cc2c(cc1O)CCCC2"' input_template=data/pdbbind/pdbbind_holo_aligned_esmfold_structures/6i67_holo_aligned_esmfold_protein.pdb sample_id='6i67' out_path='./6i67_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu +``` + +Or, for example, generate new protein-ligand complexes for pairs of protein sequences and (multi-)ligand SMILES strings (delimited via `|`) such as those of the CASP15 target `T1152` + +```bash +python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIPN' input_ligand='"CC(=O)NC1C(O)OC(CO)C(OC2OC(CO)C(OC3OC(CO)C(O)C(O)C3NC(C)=O)C(O)C2NC(C)=O)C1O"' input_template=data/test_cases/predicted_structures/T1152.pdb sample_id='T1152' out_path='./T1152_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu +``` + +If you do not already have a template protein 
structure available for your target of interest, set `input_template=null` to instead have the sampling script predict the ESMFold structure of your provided `input_receptor` sequence before running the sampling pipeline. For more information regarding the input arguments available for sampling, please refer to the config at `configs/sample.yaml`. + +**NOTE:** To optimize prediction runtimes, a `csv_path` can be specified instead of the `input_receptor`, `input_ligand`, and `input_template` CLI arguments to perform *batched* prediction for a collection of protein-ligand sequence pairs, each represented as a CSV row containing column values for `id`, `input_receptor`, `input_ligand`, and `input_template`. Additionally, disabling `visualize_sample_trajectories` may reduce storage requirements when predicting a large batch of inputs. + +For instance, one can perform batched prediction as follows: + +```bash +python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling csv_path='./data/test_cases/prediction_inputs/flowdock_batched_inputs.csv' out_path='./T1152_batch_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=false auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu +``` + +
+ +## For developers + +
+ +Set up `pre-commit` (one time only) for automatic code linting and formatting upon each `git commit` + +```bash +pre-commit install +``` + +Manually reformat all files in the project, as desired + +```bash +pre-commit run -a +``` + +Update dependencies in a `*_environment.yaml` file + +```bash +mamba env export > env.yaml # e.g., run this after installing new dependencies locally +diff environments/flowdock_environment.yaml env.yaml # note the differences and copy accepted changes back into e.g., `environments/flowdock_environment.yaml` +rm env.yaml # clean up temporary environment file +``` + +
+ +## Docker + +
+ +Given that this tool has a number of dependencies, it may be easier to run it in a Docker container. + +Pull from [Docker Hub](https://hub.docker.com/repository/docker/cford38/flowdock): `docker pull cford38/flowdock:latest` + +Alternatively, build the Docker image locally: + +```bash +docker build --platform linux/amd64 -t flowdock . +``` + +Then, run the Docker container (and mount your local `checkpoints/` directory) + +```bash +docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it flowdock /bin/bash + +# docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it cford38/flowdock:latest /bin/bash +``` + +
+ +## Acknowledgements + +`FlowDock` builds upon the source code and data from the following projects: + +- [DiffDock](https://github.com/gcorso/DiffDock) +- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) +- [NeuralPLexer](https://github.com/zrqiao/NeuralPLexer) + +We thank all their contributors and maintainers! + +## License + +This project is covered under the **MIT License**. + +## Citing this work + +If you use the code or data associated with this package or otherwise find this work useful, please cite: + +```bibtex +@inproceedings{morehead2025flowdock, + title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction}, + author={Alex Morehead and Jianlin Cheng}, + booktitle={Intelligent Systems for Molecular Biology (ISMB)}, + year=2025, +} +``` diff --git a/citation.bib b/citation.bib new file mode 100644 index 0000000..8831101 --- /dev/null +++ b/citation.bib @@ -0,0 +1,6 @@ +@inproceedings{morehead2025flowdock, + title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction}, + author={Alex Morehead and Jianlin Cheng}, + booktitle={Intelligent Systems for Molecular Biology (ISMB)}, + year=2025, +} diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..56bf7f4 --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000..4722528 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,21 @@ +defaults: + - ema + - last_model_checkpoint + - learning_rate_monitor + - model_checkpoint + - model_summary + - rich_progress_bar + - _self_ + +last_model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "last" + monitor: null + verbose: True + auto_insert_metric_name: False + every_n_epochs: 1 + 
save_on_train_epoch_end: True + enable_version_counter: False + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000..c826c8d --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,15 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/ema.yaml b/configs/callbacks/ema.yaml new file mode 100644 index 0000000..5bff7bb --- /dev/null +++ b/configs/callbacks/ema.yaml @@ -0,0 +1,10 @@ +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py + +# Maintains an exponential moving average (EMA) of model weights. +# Look at the above link for more detailed information regarding the original implementation. 
+ema: + _target_: flowdock.models.components.callbacks.ema.EMA + decay: 0.999 + validate_original_weights: false + every_n_steps: 4 + cpu_offload: false diff --git a/configs/callbacks/last_model_checkpoint.yaml b/configs/callbacks/last_model_checkpoint.yaml new file mode 100644 index 0000000..0f780d0 --- /dev/null +++ b/configs/callbacks/last_model_checkpoint.yaml @@ -0,0 +1,21 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +last_model_checkpoint: + # NOTE: this is a direct copy of `model_checkpoint`, + # which is necessary to make to work around the + # key-duplication limitations of YAML config files + _target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation + enable_version_counter: True # enables versioning for checkpoint names diff --git a/configs/callbacks/learning_rate_monitor.yaml b/configs/callbacks/learning_rate_monitor.yaml new file mode 100644 index 0000000..fa8ecd6 --- /dev/null +++ 
b/configs/callbacks/learning_rate_monitor.yaml @@ -0,0 +1,7 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html + +learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: null # set to `epoch` or `step` to log learning rate of all optimizers at the same interval, or set to `null` to log at individual interval according to the interval key of each scheduler + log_momentum: false # whether to also log the momentum values of the optimizer, if the optimizer has the `momentum` or `betas` attribute + log_weight_decay: false # whether to also log the weight decay values of the optimizer, if the optimizer has the `weight_decay` attribute diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000..8443e2d --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,18 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint + dirpath: null # directory to save the model file + filename: "best" # checkpoint filename + monitor: val_sampling/ligand_hit_score_2A_epoch # name of the logged metric which determines when model is improving + verbose: True # verbosity mode + save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "max" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null 
# number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation + enable_version_counter: False # enables versioning for checkpoint names diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000..b75981d --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml new file mode 100644 index 0000000..e69de29 diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000..de6f1cc --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100644 index 0000000..1886902 --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: null +logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 
1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100644 index 0000000..7f2d34f --- /dev/null +++ b/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100644 index 0000000..514d77f --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100644 index 0000000..9906586 --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100644 index 0000000..2bd7da8 --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/configs/environment/default.yaml b/configs/environment/default.yaml new file mode 100644 index 0000000..758b22b --- /dev/null +++ b/configs/environment/default.yaml @@ -0,0 +1,2 @@ +defaults: + - _self_ diff --git 
a/configs/environment/lightning.yaml b/configs/environment/lightning.yaml new file mode 100644 index 0000000..a5e70f4 --- /dev/null +++ b/configs/environment/lightning.yaml @@ -0,0 +1 @@ +_target_: lightning.fabric.plugins.environments.LightningEnvironment diff --git a/configs/environment/slurm.yaml b/configs/environment/slurm.yaml new file mode 100644 index 0000000..55a5522 --- /dev/null +++ b/configs/environment/slurm.yaml @@ -0,0 +1,3 @@ +_target_: lightning.fabric.plugins.environments.SLURMEnvironment +auto_requeue: true +requeue_signal: null diff --git a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 0000000..eeff37f --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +defaults: + - data: combined # choose datamodule with `test_dataloader()` for evaluation + - model: flowdock_fm + - logger: null + - strategy: default + - trainer: default + - paths: default + - extras: default + - hydra: default + - environment: default + - _self_ + +task_name: "eval" + +tags: ["eval", "combined", "flowdock_fm"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? 
+ +# seed for random number generators in pytorch, numpy and python.random +seed: null + +# model arguments +model: + cfg: + mol_encoder: + from_pretrained: false + protein_encoder: + from_pretrained: false + relational_reasoning: + from_pretrained: false + contact_predictor: + from_pretrained: false + score_head: + from_pretrained: false + confidence: + from_pretrained: false + affinity: + from_pretrained: false + task: + freeze_mol_encoder: true + freeze_protein_encoder: false + freeze_relational_reasoning: false + freeze_contact_predictor: false + freeze_score_head: false + freeze_confidence: true + freeze_affinity: false diff --git a/configs/experiment/flowdock_fm.yaml b/configs/experiment/flowdock_fm.yaml new file mode 100644 index 0000000..b14b3e2 --- /dev/null +++ b/configs/experiment/flowdock_fm.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=flowdock_fm + +defaults: + - override /data: combined + - override /model: flowdock_fm + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["flowdock_fm", "combined_dataset"] + +seed: 496 + +trainer: + max_epochs: 300 + check_val_every_n_epoch: 5 # NOTE: we increase this since validation steps involve full model sampling and evaluation + reload_dataloaders_every_n_epochs: 1 + +model: + optimizer: + lr: 2e-4 + compile: false + +data: + batch_size: 16 + +logger: + wandb: + tags: ${tags} + group: "FlowDock-FM" diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000..b9c6b62 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using 
Rich library +print_config: True diff --git a/configs/hparams_search/combined_optuna.yaml b/configs/hparams_search/combined_optuna.yaml new file mode 100644 index 0000000..43960de --- /dev/null +++ b/configs/hparams_search/combined_optuna.yaml @@ -0,0 +1,50 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=combined_optuna experiment=flowdock_fm + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! +optimized_metric: "val/loss" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: minimize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(2, 4, 8, 16) + model.net.hidden_dim: choice(64, 128, 256) diff --git a/configs/hydra/default.yaml 
b/configs/hydra/default.yaml new file mode 100644 index 0000000..a61e9b3 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100644 index 0000000..8f9f6ad --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
+ system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000..9de4323 --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "FlowDock_FM" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000..fa028e9 --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000..dd58680 --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000..f8fb7e6 --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to 
open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000..874a3fc --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/FlowDock_FM + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000..2bd31f6 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000..c66c04e --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! 
+ anonymous: null # enable anonymous logging + project: "FlowDock_FM" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + entity: "bml-lab" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/flowdock_fm.yaml b/configs/model/flowdock_fm.yaml new file mode 100644 index 0000000..2812281 --- /dev/null +++ b/configs/model/flowdock_fm.yaml @@ -0,0 +1,148 @@ +_target_: flowdock.models.flowdock_fm_module.FlowDockFMLitModule + +net: + _target_: flowdock.models.components.flowdock.FlowDock + _partial_: true + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 2e-4 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + _partial_: true + T_0: ${int_divide:${trainer.max_epochs},15} + T_mult: 2 + eta_min: 1e-8 + verbose: true + +# compile model for faster training with pytorch 2.0 +compile: false + +# model arguments +cfg: + mol_encoder: + node_channels: 512 + pair_channels: 64 + n_atom_encodings: 23 + n_bond_encodings: 4 + n_atom_pos_encodings: 6 + n_stereo_encodings: 14 + n_attention_heads: 8 + attention_head_dim: 8 + hidden_dim: 2048 + max_path_integral_length: 6 + n_transformer_stacks: 8 + n_heads: 8 + n_patches: ${data.n_lig_patches} + checkpoint_file: ${oc.env:PROJECT_ROOT}/checkpoints/neuralplexermodels_downstream_datasets_predictions/models/complex_structure_prediction.ckpt + megamolbart: null + from_pretrained: true + + protein_encoder: + use_esm_embedding: true + esm_version: esm2_t33_650M_UR50D + esm_repr_layer: 33 + residue_dim: 512 + plm_embed_dim: 1280 + n_aa_types: 21 + atom_padding_dim: 37 + n_atom_types: 4 # [C, N, O, S] + n_patches: ${data.n_protein_patches} + n_attention_heads: 8 + scalar_dim: 16 + point_dim: 4 + pair_dim: 64 + n_heads: 8 + head_dim: 8 + max_residue_degree: 32 + n_encoder_stacks: 2 + from_pretrained: true + + relational_reasoning: + from_pretrained: true + + contact_predictor: + 
n_stacks: 4 + dropout: 0.01 + from_pretrained: true + + score_head: + fiber_dim: 64 + hidden_dim: 512 + n_stacks: 4 + max_atom_degree: 8 + from_pretrained: true + + confidence: + enabled: true # whether the confidence prediction head is to be used e.g., during inference + fiber_dim: ${..score_head.fiber_dim} + hidden_dim: ${..score_head.hidden_dim} + n_stacks: ${..score_head.n_stacks} + from_pretrained: true + + affinity: + enabled: true # whether the affinity prediction head is to be used e.g., during inference + fiber_dim: ${..score_head.fiber_dim} + hidden_dim: ${..score_head.hidden_dim} + n_stacks: ${..score_head.n_stacks} + ligand_pooling: sum # NOTE: must be a value in (`sum`, `mean`) + dropout: 0.01 + from_pretrained: false + + latent_model: default + prior_type: esmfold # NOTE: must be a value in (`gaussian`, `harmonic`, `esmfold`) + + task: + pretrained: null + ligands: true + epoch_frac: ${data.epoch_frac} + label_name: null + sequence_crop_size: 1600 + edge_crop_size: ${data.edge_crop_size} # NOTE: for dynamic batching via `max_n_edges` + max_masking_rate: 0.0 + n_modes: 8 + dropout: 0.01 + # pretraining: true + freeze_mol_encoder: true + freeze_protein_encoder: false + freeze_relational_reasoning: false + freeze_contact_predictor: true + freeze_score_head: false + freeze_confidence: true + freeze_affinity: false + use_template: true + use_plddt: false + block_contact_decoding_scheme: "beam" + frozen_ligand_backbone: false + frozen_protein_backbone: false + single_protein_batch: true + contact_loss_weight: 0.2 + global_score_loss_weight: 0.2 + ligand_score_loss_weight: 0.1 + clash_loss_weight: 10.0 + local_distgeom_loss_weight: 10.0 + drmsd_loss_weight: 2.0 + distogram_loss_weight: 0.05 + plddt_loss_weight: 1.0 + affinity_loss_weight: 0.1 + aux_batch_freq: 10 # NOTE: e.g., `10` means that auxiliary estimation losses will be calculated every 10th batch + global_max_sigma: 5.0 + internal_max_sigma: 2.0 + detect_covalent: true + # runtime configs + 
float32_matmul_precision: highest + # sampling configs + constrained_inpainting: false + visualize_generated_samples: true + # testing configs + loss_mode: auxiliary_estimation # NOTE: must be one of (`structure_prediction`, `auxiliary_estimation`, `auxiliary_estimation_without_structure_prediction`) + num_steps: 20 + sampler: VDODE # NOTE: must be one of (`ODE`, `VDODE`) + sampler_eta: 1.0 # NOTE: this corresponds to the variance diminishing factor for the `VDODE` sampler, which offers a trade-off between exploration (1.0) and exploitation (> 1.0) + start_time: 1.0 + eval_structure_prediction: false # whether to evaluate structure prediction performance (`true`) or instead only binding affinity performance (`false`) when running a test epoch + # overfitting configs + overfitting_example_name: ${data.overfitting_example_name} diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000..3898b28 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,21 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." 
if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} + +# path to the directory containing the model checkpoints +ckpt_dir: ${paths.root_dir}/checkpoints/ diff --git a/configs/sample.yaml b/configs/sample.yaml new file mode 100644 index 0000000..6d02973 --- /dev/null +++ b/configs/sample.yaml @@ -0,0 +1,78 @@ +# @package _global_ + +defaults: + - data: combined # NOTE: this will not be referenced during sampling + - model: flowdock_fm + - logger: null + - strategy: default + - trainer: default + - paths: default + - extras: default + - hydra: default + - environment: default + - _self_ + +task_name: "sample" + +tags: ["sample", "combined", "flowdock_fm"] + +# passing checkpoint path is necessary for sampling +ckpt_path: ??? + +# seed for random number generators in pytorch, numpy and python.random +seed: null + +# sampling arguments +sampling_task: batched_structure_sampling # NOTE: must be one of (`batched_structure_sampling`) +sample_id: null # optional identifier for the sampling run +input_receptor: null # NOTE: must be either a protein sequence string (with chains separated by `|`) or a path to a PDB file (from which protein chain sequences will be parsed) +input_ligand: null # NOTE: must be either a ligand SMILES string (with chains/fragments separated by `|`) or a path to a ligand SDF file (from which ligand SMILES will be parsed) +input_template: null # path to a protein PDB file to use as a starting protein template for sampling (with an ESMFold prior model) +out_path: ??? 
# path to which to save the output PDB and SDF files +n_samples: 5 # number of structures to sample +chunk_size: 5 # number of structures to concurrently sample within each batch segment - NOTE: `n_samples` should be evenly divisible by `chunk_size` to produce the expected number of outputs +num_steps: 40 # number of sampling steps to perform +latent_model: null # if provided, the type of latent model to use +sampler: VDODE # sampling algorithm to use - NOTE: must be one of (`ODE`, `VDODE`) +sampler_eta: 1.0 # the variance diminishing factor for the `VDODE` sampler - NOTE: offers a trade-off between exploration (1.0) and exploitation (> 1.0) +start_time: "1.0" # time at which to start sampling +max_chain_encoding_k: -1 # maximum number of chains to encode in the chain encoding +exact_prior: false # whether to use the "ground-truth" binding site for sampling, if available +prior_type: esmfold # the type of prior to use for sampling - NOTE: must be one of (`gaussian`, `harmonic`, `esmfold`) +discard_ligand: false # whether to discard a given input ligand during sampling +discard_sdf_coords: true # whether to discard the input ligand's 3D structure during sampling, if available +detect_covalent: false # whether to detect covalent bonds between the input receptor and ligand +use_template: true # whether to use the input protein template for sampling if one is provided +separate_pdb: true # whether to save separate PDB files for each sampled structure instead of simply a single PDB file +rank_outputs_by_confidence: true # whether to rank the sampled structures by estimated confidence +plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`) +visualize_sample_trajectories: false # whether to visualize the generated samples' trajectories +auxiliary_estimation_only: false # whether to only estimate auxiliary outputs (e.g., confidence, affinity) for the input (generated) samples 
(potentially derived from external sources) +csv_path: null # if provided, the CSV file (with columns `id`, `input_receptor`, `input_ligand`, and `input_template`) from which to parse input receptors and ligands for sampling, overriding the `input_receptor` and `input_ligand` arguments in the process and ignoring the `input_template` for now +esmfold_chunk_size: null # chunks axial attention computation to reduce memory usage from O(L^2) to O(L); equivalent to running a for loop over chunks of each dimension; lower values will result in lower memory usage at the cost of speed; recommended values: 128, 64, 32 + +# model arguments +model: + cfg: + mol_encoder: + from_pretrained: false + protein_encoder: + from_pretrained: false + relational_reasoning: + from_pretrained: false + contact_predictor: + from_pretrained: false + score_head: + from_pretrained: false + confidence: + from_pretrained: false + affinity: + from_pretrained: false + task: + freeze_mol_encoder: true + freeze_protein_encoder: false + freeze_relational_reasoning: false + freeze_contact_predictor: false + freeze_score_head: false + freeze_confidence: true + freeze_affinity: false diff --git a/configs/strategy/ddp.yaml b/configs/strategy/ddp.yaml new file mode 100644 index 0000000..14933f3 --- /dev/null +++ b/configs/strategy/ddp.yaml @@ -0,0 +1,4 @@ +_target_: lightning.pytorch.strategies.DDPStrategy +static_graph: false +gradient_as_bucket_view: false +find_unused_parameters: true diff --git a/configs/strategy/ddp_spawn.yaml b/configs/strategy/ddp_spawn.yaml new file mode 100644 index 0000000..e51e4f2 --- /dev/null +++ b/configs/strategy/ddp_spawn.yaml @@ -0,0 +1,5 @@ +_target_: lightning.pytorch.strategies.DDPStrategy +static_graph: false +gradient_as_bucket_view: false +find_unused_parameters: true +start_method: spawn diff --git a/configs/strategy/deepspeed.yaml b/configs/strategy/deepspeed.yaml new file mode 100644 index 0000000..3c05b4c --- /dev/null +++ b/configs/strategy/deepspeed.yaml @@ 
-0,0 +1,5 @@ +_target_: lightning.pytorch.strategies.DeepSpeedStrategy +stage: 2 +offload_optimizer: false +allgather_bucket_size: 200_000_000 +reduce_bucket_size: 200_000_000 diff --git a/configs/strategy/default.yaml b/configs/strategy/default.yaml new file mode 100644 index 0000000..758b22b --- /dev/null +++ b/configs/strategy/default.yaml @@ -0,0 +1,2 @@ +defaults: + - _self_ diff --git a/configs/strategy/fsdp.yaml b/configs/strategy/fsdp.yaml new file mode 100644 index 0000000..12e29e8 --- /dev/null +++ b/configs/strategy/fsdp.yaml @@ -0,0 +1,12 @@ +_target_: lightning.pytorch.strategies.FSDPStrategy +sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD} +cpu_offload: null +activation_checkpointing: null +mixed_precision: + _target_: torch.distributed.fsdp.MixedPrecision + param_dtype: null + reduce_dtype: null + buffer_dtype: null + keep_low_precision_grads: false + cast_forward_inputs: false + cast_root_forward_inputs: true diff --git a/configs/strategy/optimized_ddp.yaml b/configs/strategy/optimized_ddp.yaml new file mode 100644 index 0000000..b8b4b31 --- /dev/null +++ b/configs/strategy/optimized_ddp.yaml @@ -0,0 +1,4 @@ +_target_: lightning.pytorch.strategies.DDPStrategy +static_graph: true +gradient_as_bucket_view: true +find_unused_parameters: false diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000..d2205a4 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: combined + - model: flowdock_fm + - callbacks: default + - logger: null # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) + - strategy: default + - trainer: default + - paths: default + - extras: default + - hydra: default + - environment: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default`) + - debug: null + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["train", "combined", "flowdock_fm"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: False + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random
seed: null diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000..b7d6767 --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000..ab8f890 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000..8404419 --- /dev/null +++ 
b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/ddp_spawn.yaml b/configs/trainer/ddp_spawn.yaml new file mode 100644 index 0000000..021150a --- /dev/null +++ b/configs/trainer/ddp_spawn.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp_spawn + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000..ca4a10b --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,29 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 10 + +accelerator: cpu +devices: 1 + +# mixed precision for extra speed-up
# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False + +# determine the frequency of how often to reload the dataloaders +reload_dataloaders_every_n_epochs: 1 + +# if `gradient_clip_val` is not `null`, gradients will be norm-clipped during training +gradient_clip_algorithm: norm +gradient_clip_val: 1.0 + +# if `num_sanity_val_steps` is > 0, then specifically that many validation steps will be run during the first call to `trainer.fit` +num_sanity_val_steps: 0 diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000..b238951 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000..1ecf6d5 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git 
a/environments/flowdock_environment.yaml b/environments/flowdock_environment.yaml new file mode 100644 index 0000000..ab55337 --- /dev/null +++ b/environments/flowdock_environment.yaml @@ -0,0 +1,540 @@ +name: FlowDock +channels: + - pytorch + - pyg + - senyan.dev + - nvidia + - conda-forge +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=3_kmp_llvm + - aiohappyeyeballs=2.6.1=pyhd8ed1ab_0 + - aiohttp=3.11.13=py310h89163eb_0 + - aiosignal=1.3.2=pyhd8ed1ab_0 + - alsa-lib=1.2.13=hb9d3cd8_0 + - ambertools=24.8=cuda_None_nompi_py310h834fefc_101 + - annotated-types=0.7.0=pyhd8ed1ab_1 + - anyio=4.8.0=pyhd8ed1ab_0 + - aom=3.9.1=hac33072_0 + - argon2-cffi=23.1.0=pyhd8ed1ab_1 + - argon2-cffi-bindings=21.2.0=py310ha75aee5_5 + - arpack=3.9.1=nompi_hf03ea27_102 + - arrow=1.3.0=pyhd8ed1ab_1 + - asttokens=3.0.0=pyhd8ed1ab_1 + - async-lru=2.0.4=pyhd8ed1ab_1 + - async-timeout=5.0.1=pyhd8ed1ab_1 + - attr=2.5.1=h166bdaf_1 + - attrs=25.3.0=pyh71513ae_0 + - babel=2.17.0=pyhd8ed1ab_0 + - beautifulsoup4=4.13.3=pyha770c72_0 + - blas=2.116=mkl + - blas-devel=3.9.0=16_linux64_mkl + - bleach=6.2.0=pyh29332c3_4 + - bleach-with-css=6.2.0=h82add2a_4 + - blosc=1.21.6=he440d0b_1 + - brotli=1.1.0=hb9d3cd8_2 + - brotli-bin=1.1.0=hb9d3cd8_2 + - brotli-python=1.1.0=py310hf71b8c6_2 + - bson=0.5.10=pyhd8ed1ab_0 + - bzip2=1.0.8=h4bc722e_7 + - c-ares=1.34.4=hb9d3cd8_0 + - c-blosc2=2.15.2=h3122c55_1 + - ca-certificates=2025.1.31=hbcca054_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cachetools=5.5.2=pyhd8ed1ab_0 + - cairo=1.18.4=h3394656_0 + - certifi=2025.1.31=pyhd8ed1ab_0 + - cffi=1.17.1=py310h8deb56e_0 + - chardet=5.2.0=pyhd8ed1ab_3 + - charset-normalizer=3.4.1=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_1 + - comm=0.2.2=pyhd8ed1ab_1 + - contourpy=1.3.1=py310h3788b33_0 + - cpython=3.10.16=py310hd8ed1ab_1 + - cuda-cudart=11.8.89=0 + - cuda-cupti=11.8.87=0 + - cuda-libraries=11.8.0=0 + - cuda-nvrtc=11.8.89=0 + - cuda-nvtx=11.8.86=0 + - 
cuda-runtime=11.8.0=0 + - cuda-version=11.8=h70ddcb2_3 + - cudatoolkit=11.8.0=h4ba93d1_13 + - cudatoolkit-dev=11.8.0=h1fa729e_6 + - cycler=0.12.1=pyhd8ed1ab_1 + - cyrus-sasl=2.1.27=h54b06d7_7 + - dav1d=1.2.1=hd590300_0 + - dbus=1.13.6=h5008d03_3 + - debugpy=1.8.13=py310hf71b8c6_0 + - decorator=5.2.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - deprecated=1.2.18=pyhd8ed1ab_0 + - exceptiongroup=1.2.2=pyhd8ed1ab_1 + - expat=2.6.4=h5888daf_0 + - ffmpeg=7.1.1=gpl_h24e5c1d_701 + - fftw=3.3.10=nompi_hf1063bd_110 + - filelock=3.18.0=pyhd8ed1ab_0 + - flexcache=0.3=pyhd8ed1ab_1 + - flexparser=0.4=pyhd8ed1ab_1 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_3 + - fontconfig=2.15.0=h7e30c49_1 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.56.0=py310h89163eb_0 + - fqdn=1.5.1=pyhd8ed1ab_1 + - freetype=2.13.3=h48d6fc4_0 + - freetype-py=2.3.0=pyhd8ed1ab_0 + - fribidi=1.0.10=h36c2ea0_0 + - frozenlist=1.5.0=py310h89163eb_1 + - fsspec=2025.3.0=pyhd8ed1ab_0 + - gdk-pixbuf=2.42.12=hb9ae30d_0 + - gettext=0.23.1=h5888daf_0 + - gettext-tools=0.23.1=h5888daf_0 + - gmp=6.3.0=hac33072_2 + - gmpy2=2.1.5=py310he8512ff_3 + - graphite2=1.3.13=h59595ed_1003 + - greenlet=3.1.1=py310hf71b8c6_1 + - h11=0.14.0=pyhd8ed1ab_1 + - h2=4.2.0=pyhd8ed1ab_0 + - harfbuzz=10.4.0=h76408a6_0 + - hdf4=4.2.15=h2a13503_7 + - hdf5=1.14.4=nompi_h2d575fe_105 + - hpack=4.1.0=pyhd8ed1ab_0 + - httpcore=1.0.7=pyh29332c3_1 + - httpx=0.28.1=pyhd8ed1ab_0 + - hyperframe=6.1.0=pyhd8ed1ab_0 + - icu=75.1=he02047a_0 + - idna=3.10=pyhd8ed1ab_1 + - importlib-metadata=8.6.1=pyha770c72_0 + - importlib_resources=6.5.2=pyhd8ed1ab_0 + - ipykernel=6.29.5=pyh3099207_0 + - ipython=8.34.0=pyh907856f_0 + - isoduration=20.11.0=pyhd8ed1ab_1 + - jack=1.9.22=h7c63dc7_2 + - jedi=0.19.2=pyhd8ed1ab_1 + - jinja2=3.1.6=pyhd8ed1ab_0 + - joblib=1.4.2=pyhd8ed1ab_1 + - json5=0.10.0=pyhd8ed1ab_1 + - 
jsonpointer=3.0.0=py310hff52083_1 + - jsonschema=4.23.0=pyhd8ed1ab_1 + - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1 + - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1 + - jupyter-lsp=2.2.5=pyhd8ed1ab_1 + - jupyter_client=8.6.3=pyhd8ed1ab_1 + - jupyter_core=5.7.2=pyh31011fe_1 + - jupyter_events=0.12.0=pyh29332c3_0 + - jupyter_server=2.15.0=pyhd8ed1ab_0 + - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1 + - jupyterlab=4.3.6=pyhd8ed1ab_0 + - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2 + - jupyterlab_server=2.27.3=pyhd8ed1ab_1 + - jupyterlab_widgets=3.0.13=pyhd8ed1ab_1 + - kernel-headers_linux-64=3.10.0=he073ed8_18 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.7=py310h3788b33_0 + - krb5=1.21.3=h659f571_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.17=h717163a_0 + - ld_impl_linux-64=2.43=h712a8e2_4 + - lerc=4.0.0=h27087fc_0 + - level-zero=1.21.2=h84d6215_0 + - libabseil=20250127.0=cxx17_hbbce691_0 + - libaec=1.1.3=h59595ed_0 + - libasprintf=0.23.1=h8e693c7_0 + - libasprintf-devel=0.23.1=h8e693c7_0 + - libass=0.17.3=hba53ac1_1 + - libblas=3.9.0=16_linux64_mkl + - libboost=1.86.0=h6c02f8c_3 + - libboost-python=1.86.0=py310ha2bacc8_3 + - libbrotlicommon=1.1.0=hb9d3cd8_2 + - libbrotlidec=1.1.0=hb9d3cd8_2 + - libbrotlienc=1.1.0=hb9d3cd8_2 + - libcap=2.75=h39aace5_0 + - libcblas=3.9.0=16_linux64_mkl + - libcublas=11.11.3.6=0 + - libcufft=10.9.0.58=0 + - libcufile=1.9.1.3=0 + - libcurand=10.3.5.147=0 + - libcurl=8.12.1=h332b0f4_0 + - libcusolver=11.4.1.48=0 + - libcusparse=11.7.5.86=0 + - libdb=6.2.32=h9c3ff4c_0 + - libdeflate=1.23=h4ddbbb0_0 + - libdrm=2.4.124=hb9d3cd8_0 + - libedit=3.1.20250104=pl5321h7949ede_0 + - libegl=1.7.0=ha4b6fd6_2 + - libev=4.33=hd590300_2 + - libexpat=2.6.4=h5888daf_0 + - libffi=3.4.6=h2dba641_0 + - libflac=1.4.3=h59595ed_0 + - libgcc=14.2.0=h767d61c_2 + - libgcc-ng=14.2.0=h69a702a_2 + - libgcrypt-lib=1.11.0=hb9d3cd8_2 + - libgettextpo=0.23.1=h5888daf_0 + - libgettextpo-devel=0.23.1=h5888daf_0 + - libgfortran=14.2.0=h69a702a_2 + - 
libgfortran-ng=14.2.0=h69a702a_2 + - libgfortran5=14.2.0=hf1ad2bd_2 + - libgl=1.7.0=ha4b6fd6_2 + - libglib=2.82.2=h2ff4ddf_1 + - libglvnd=1.7.0=ha4b6fd6_2 + - libglx=1.7.0=ha4b6fd6_2 + - libgomp=14.2.0=h767d61c_2 + - libgpg-error=1.51=hbd13f7d_1 + - libhwloc=2.11.2=default_h0d58e46_1001 + - libiconv=1.18=h4ce23a2_1 + - libjpeg-turbo=3.0.0=hd590300_1 + - liblapack=3.9.0=16_linux64_mkl + - liblapacke=3.9.0=16_linux64_mkl + - liblzma=5.6.4=hb9d3cd8_0 + - libnetcdf=4.9.2=nompi_h5ddbaa4_116 + - libnghttp2=1.64.0=h161d5f1_0 + - libnpp=11.8.0.86=0 + - libnsl=2.0.1=hd590300_0 + - libntlm=1.8=hb9d3cd8_0 + - libnvjpeg=11.9.0.86=0 + - libogg=1.3.5=h4ab18f5_0 + - libopenvino=2025.0.0=hdc3f47d_3 + - libopenvino-auto-batch-plugin=2025.0.0=h4d9b6c2_3 + - libopenvino-auto-plugin=2025.0.0=h4d9b6c2_3 + - libopenvino-hetero-plugin=2025.0.0=h981d57b_3 + - libopenvino-intel-cpu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-intel-gpu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-intel-npu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-ir-frontend=2025.0.0=h981d57b_3 + - libopenvino-onnx-frontend=2025.0.0=h0e684df_3 + - libopenvino-paddle-frontend=2025.0.0=h0e684df_3 + - libopenvino-pytorch-frontend=2025.0.0=h5888daf_3 + - libopenvino-tensorflow-frontend=2025.0.0=h684f15b_3 + - libopenvino-tensorflow-lite-frontend=2025.0.0=h5888daf_3 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.47=h943b412_0 + - libpq=17.4=h27ae623_0 + - libprotobuf=5.29.3=h501fc15_0 + - librdkit=2024.09.6=h84b0b3c_0 + - librsvg=2.58.4=h49af25d_2 + - libsndfile=1.2.2=hc60ed4a_1 + - libsodium=1.0.20=h4ab18f5_0 + - libsqlite=3.49.1=hee588c1_1 + - libssh2=1.11.1=hf672d98_0 + - libstdcxx=14.2.0=h8f9b012_2 + - libstdcxx-ng=14.2.0=h4852527_2 + - libsystemd0=257.4=h4e0b6ca_1 + - libtiff=4.7.0=hd9ff511_3 + - libudev1=257.4=hbe16f8c_1 + - libunwind=1.6.2=h9c3ff4c_0 + - liburing=2.9=h84d6215_0 + - libusb=1.0.27=hb9d3cd8_101 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.22.0=h4f16b4b_2 + - 
libvorbis=1.3.7=h9c3ff4c_0 + - libvpx=1.14.1=hac33072_0 + - libwebp-base=1.5.0=h851e524_0 + - libxcb=1.17.0=h8a09558_0 + - libxcrypt=4.4.36=hd590300_1 + - libxkbcommon=1.8.1=hc4a0caf_0 + - libxml2=2.13.6=h8d12d68_0 + - libxslt=1.1.39=h76b75d6_0 + - libzip=1.11.2=h6991a6a_0 + - libzlib=1.3.1=hb9d3cd8_2 + - llvm-openmp=15.0.7=h0cdce71_0 + - lxml=5.3.1=py310h6ee67d5_0 + - lz4-c=1.10.0=h5888daf_1 + - markupsafe=3.0.2=py310h89163eb_1 + - matplotlib-base=3.10.1=py310h68603db_0 + - matplotlib-inline=0.1.7=pyhd8ed1ab_1 + - mda-xdrlib=0.2.0=pyhd8ed1ab_1 + - mdtraj=1.10.3=py310h4cdbd58_0 + - mendeleev=0.20.1=pymin39_ha308f57_3 + - mistune=3.1.2=pyhd8ed1ab_0 + - mkl=2022.1.0=h84fe81f_915 + - mkl-devel=2022.1.0=ha770c72_916 + - mkl-include=2022.1.0=h84fe81f_915 + - mpc=1.3.1=h24ddda3_1 + - mpfr=4.2.1=h90cbb55_3 + - mpg123=1.32.9=hc50e24c_0 + - mpmath=1.3.0=pyhd8ed1ab_1 + - multidict=6.1.0=py310h89163eb_2 + - munkres=1.1.4=pyh9f0ad1d_0 + - nbclient=0.10.2=pyhd8ed1ab_0 + - nbconvert-core=7.16.6=pyh29332c3_0 + - nbformat=5.10.4=pyhd8ed1ab_1 + - ncurses=6.5=h2d0b736_3 + - nest-asyncio=1.6.0=pyhd8ed1ab_1 + - netcdf-fortran=4.6.1=nompi_ha5d1325_108 + - networkx=3.4.2=pyh267e887_2 + - notebook=7.3.3=pyhd8ed1ab_0 + - notebook-shim=0.2.4=pyhd8ed1ab_1 + - numexpr=2.7.3=py310hb5077e9_1 + - ocl-icd=2.3.2=hb9d3cd8_2 + - ocl-icd-system=1.0.0=1 + - opencl-headers=2024.10.24=h5888daf_0 + - openff-amber-ff-ports=0.0.4=pyhca7485f_0 + - openff-forcefields=2024.09.0=pyhff2d567_0 + - openff-interchange=0.4.2=pyhd8ed1ab_2 + - openff-interchange-base=0.4.2=pyhd8ed1ab_2 + - openff-toolkit=0.16.8=pyhd8ed1ab_2 + - openff-toolkit-base=0.16.8=pyhd8ed1ab_2 + - openff-units=0.3.0=pyhd8ed1ab_1 + - openff-utilities=0.1.15=pyhd8ed1ab_0 + - openh264=2.6.0=hc22cd8d_0 + - openjpeg=2.5.3=h5fbd93e_0 + - openldap=2.6.9=he970967_0 + - openmm=8.2.0=py310h30bdd6a_2 + - openmmforcefields=0.14.2=pyhd8ed1ab_0 + - openssl=3.4.1=h7b32b05_0 + - overrides=7.7.0=pyhd8ed1ab_1 + - packaging=24.2=pyhd8ed1ab_2 + - 
panedr=0.8.0=pyhd8ed1ab_1 + - pango=1.56.2=h861ebed_0 + - parmed=4.3.0=py310h78e4988_1 + - parso=0.8.4=pyhd8ed1ab_1 + - pcre2=10.44=hba22ea6_2 + - pdbfixer=1.11=pyhd8ed1ab_0 + - perl=5.32.1=7_hd590300_perl5 + - pexpect=4.9.0=pyhd8ed1ab_1 + - pickleshare=0.7.5=pyhd8ed1ab_1004 + - pillow=11.1.0=py310h7e6dc6c_0 + - pint=0.24.4=pyhd8ed1ab_1 + - pip=25.0.1=pyh8b19718_0 + - pixman=0.44.2=h29eaf8c_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2 + - platformdirs=4.3.6=pyhd8ed1ab_1 + - prometheus_client=0.21.1=pyhd8ed1ab_0 + - prompt-toolkit=3.0.50=pyha770c72_0 + - psutil=7.0.0=py310ha75aee5_0 + - pthread-stubs=0.4=hb9d3cd8_1002 + - ptyprocess=0.7.0=pyhd8ed1ab_1 + - pugixml=1.15=h3f63f65_0 + - pulseaudio-client=17.0=hb77b528_0 + - pure_eval=0.2.3=pyhd8ed1ab_1 + - py-cpuinfo=9.0.0=pyhd8ed1ab_1 + - pycairo=1.27.0=py310h25ff670_0 + - pycparser=2.22=pyh29332c3_1 + - pydantic=2.10.6=pyh3cfb1c2_0 + - pydantic-core=2.27.2=py310h505e2c1_0 + - pyedr=0.8.0=pyhd8ed1ab_1 + - pyfiglet=0.8.post1=py_0 + - pyg=2.5.2=py310_torch_2.2.0_cu118 + - pygments=2.19.1=pyhd8ed1ab_0 + - pyparsing=3.2.1=pyhd8ed1ab_0 + - pysocks=1.7.1=pyha55dd90_7 + - pytables=3.10.1=py310h431dcdc_4 + - python=3.10.16=he725a3c_1_cpython + - python-constraint=1.4.0=pyhff2d567_1 + - python-dateutil=2.9.0.post0=pyhff2d567_1 + - python-fastjsonschema=2.21.1=pyhd8ed1ab_0 + - python-tzdata=2025.1=pyhd8ed1ab_0 + - python_abi=3.10=5_cp310 + - pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0 + - pytorch-cuda=11.8=h7e8668a_6 + - pytorch-mutex=1.0=cuda + - pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118 + - pytz=2025.1=pyhd8ed1ab_0 + - pyyaml=6.0.2=py310h89163eb_2 + - pyzmq=26.3.0=py310h71f11fc_0 + - qhull=2020.2=h434a139_5 + - rdkit=2024.09.6=py310hcd13295_0 + - readline=8.2=h8c095d6_2 + - referencing=0.36.2=pyh29332c3_0 + - reportlab=4.3.1=py310ha75aee5_0 + - requests=2.32.3=pyhd8ed1ab_1 + - rfc3339-validator=0.1.4=pyhd8ed1ab_1 + - rfc3986-validator=0.1.1=pyh9f0ad1d_0 + - rlpycairo=0.2.0=pyhd8ed1ab_0 + - rpds-py=0.23.1=py310hc1293b2_0 + 
- scikit-learn=1.6.1=py310h27f47ee_0 + - scipy=1.15.2=py310h1d65ade_0 + - sdl2=2.32.50=h9b8e6db_1 + - sdl3=3.2.8=h3083f51_0 + - send2trash=1.8.3=pyh0d859eb_1 + - setuptools=75.8.2=pyhff2d567_0 + - six=1.17.0=pyhd8ed1ab_0 + - smirnoff99frosst=1.1.0=pyh44b312d_0 + - snappy=1.2.1=h8bd8927_1 + - sniffio=1.3.1=pyhd8ed1ab_1 + - sqlalchemy=2.0.39=py310ha75aee5_1 + - stack_data=0.6.3=pyhd8ed1ab_1 + - svt-av1=3.0.1=h5888daf_0 + - sympy=1.13.3=pyh2585a3b_105 + - sysroot_linux-64=2.17=h0157908_18 + - tbb=2021.13.0=hceb3a55_1 + - terminado=0.18.1=pyh0d859eb_0 + - threadpoolctl=3.6.0=pyhecae5ae_0 + - tinycss2=1.4.0=pyhd8ed1ab_0 + - tinydb=4.8.2=pyhd8ed1ab_1 + - tk=8.6.13=noxft_h4845f30_101 + - tomli=2.2.1=pyhd8ed1ab_1 + - torchaudio=2.2.1=py310_cu118 + - torchtriton=2.2.0=py310 + - torchvision=0.17.1=py310_cu118 + - tornado=6.4.2=py310ha75aee5_0 + - tqdm=4.67.1=pyhd8ed1ab_1 + - traitlets=5.14.3=pyhd8ed1ab_1 + - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0 + - typing-extensions=4.12.2=hd8ed1ab_1 + - typing_extensions=4.12.2=pyha770c72_1 + - typing_utils=0.1.0=pyhd8ed1ab_1 + - tzdata=2025a=h78e105d_0 + - unicodedata2=16.0.0=py310ha75aee5_0 + - uri-template=1.3.0=pyhd8ed1ab_1 + - urllib3=2.3.0=pyhd8ed1ab_0 + - validators=0.34.0=pyhd8ed1ab_1 + - wayland=1.23.1=h3e06ad9_0 + - wayland-protocols=1.41=hd8ed1ab_0 + - wcwidth=0.2.13=pyhd8ed1ab_1 + - webcolors=24.11.1=pyhd8ed1ab_0 + - webencodings=0.5.1=pyhd8ed1ab_3 + - websocket-client=1.8.0=pyhd8ed1ab_1 + - wheel=0.45.1=pyhd8ed1ab_1 + - wrapt=1.17.2=py310ha75aee5_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xkeyboard-config=2.43=hb9d3cd8_0 + - xmltodict=0.14.2=pyhd8ed1ab_1 + - xorg-libice=1.1.2=hb9d3cd8_0 + - xorg-libsm=1.2.6=he73a12e_0 + - xorg-libx11=1.8.12=h4f16b4b_0 + - xorg-libxau=1.0.12=hb9d3cd8_0 + - xorg-libxcursor=1.2.3=hb9d3cd8_0 + - xorg-libxdmcp=1.1.5=hb9d3cd8_0 + - xorg-libxext=1.3.6=hb9d3cd8_0 + - xorg-libxfixes=6.0.1=hb9d3cd8_0 + - xorg-libxrender=0.9.12=hb9d3cd8_0 + - 
xorg-libxscrnsaver=1.2.4=hb9d3cd8_0 + - xorg-libxt=1.3.1=hb9d3cd8_0 + - yaml=0.2.5=h7f98852_2 + - yarl=1.18.3=py310h89163eb_1 + - zeromq=4.3.5=h3b0a872_7 + - zipp=3.21.0=pyhd8ed1ab_1 + - zlib=1.3.1=hb9d3cd8_2 + - zlib-ng=2.2.4=h7955e40_0 + - zstandard=0.23.0=py310ha75aee5_1 + - zstd=1.5.7=hb8e6e7a_1 + - pip: + - absl-py==2.1.0 + - alembic==1.15.1 + - amberutils==21.0 + - antlr4-python3-runtime==4.9.3 + - autopage==0.5.2 + - beartype==0.20.0 + - biopandas==0.5.1 + - biopython==1.79 + - biotite==1.1.0 + - biotraj==1.2.2 + - cfgv==3.4.0 + - cftime==1.6.4.post1 + - click==8.1.8 + - cliff==4.9.1 + - cloudpathlib==0.21.0 + - cmaes==0.11.1 + - cmd2==2.5.11 + - colorlog==6.9.0 + - distlib==0.3.9 + - git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769 + - dm-tree==0.1.9 + - docker-pycreds==0.4.0 + - duckdb==1.2.1 + - edgembar==3.0 + - einops==0.8.1 + - eval-type-backport==0.2.2 + - executing==2.2.0 + - fair-esm==2.0.0 + - fairscale==0.4.13 + - fastcore==1.7.29 + - future==1.0.0 + - fvcore==0.1.5.post20221221 + - gcsfs==2025.3.0 + - gemmi==0.7.0 + - gitdb==4.0.12 + - gitpython==3.1.44 + - google-api-core==2.24.2 + - google-auth==2.38.0 + - google-auth-oauthlib==1.2.1 + - google-cloud-core==2.4.3 + - google-cloud-storage==3.1.0 + - google-crc32c==1.6.0 + - google-resumable-media==2.7.2 + - googleapis-common-protos==1.69.1 + - hydra-colorlog==1.2.0 + - hydra-core==1.3.2 + - hydra-optuna-sweeper==1.2.0 + - identify==2.6.9 + - iniconfig==2.0.0 + - iopath==0.1.10 + - ipython-genutils==0.2.0 + - ipywidgets==7.8.5 + - jupyterlab-widgets==1.1.11 + - lightning==2.5.0.post0 + - lightning-utilities==0.14.1 + - looseversion==1.1.2 + - lovely-numpy==0.2.13 + - lovely-tensors==0.1.18 + - mako==1.3.9 + - markdown-it-py==3.0.0 + - mdurl==0.1.2 + - ml-collections==1.0.0 + - mmcif==0.91.0 + - mmpbsa-py==16.0 + - mmtf-python==1.1.3 + - mols2grid==2.0.0 + - msgpack==1.1.0 + - msgpack-numpy==0.4.8 + - narwhals==1.30.0 + - netcdf4==1.7.2 + - nodeenv==1.9.1 + - 
numpy==1.26.4 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64 + - optuna==2.10.1 + - packmol-memgen==2025.1.29 + - pandas==2.2.3 + - pandocfilters==1.5.1 + - pbr==6.1.1 + - pdb4amber==22.0 + - plinder==0.2.24 + - plotly==6.0.0 + - pluggy==1.5.0 + - portalocker==3.1.1 + - posebusters==0.2.13 + - git+https://git@github.com/zrqiao/power_spherical.git@290b1630c5f84e3bb0d61711046edcf6e47200d4 + - pre-commit==4.1.0 + - prettytable==3.15.1 + # - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency + - propcache==0.3.0 + - proto-plus==1.26.1 + - protobuf==5.29.3 + - pyarrow==19.0.1 + - pyasn1==0.6.1 + - pyasn1-modules==0.4.1 + - pymsmt==22.0 + - pyperclip==1.9.0 + - pytest==8.3.5 + - python-dotenv==1.0.1 + - python-json-logger==3.3.0 + - pytorch-lightning==2.5.0.post0 + - git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5 + - pytraj==2.0.6 + - requests-oauthlib==2.0.0 + - rich==13.9.4 + - rootutils==1.0.7 + - rsa==4.9 + - sander==22.0 + - seaborn==0.13.2 + - sentry-sdk==2.22.0 + - setproctitle==1.3.5 + - smmap==5.0.2 + - soupsieve==2.6 + - stevedore==5.4.1 + - tabulate==0.9.0 + - termcolor==2.5.0 + - torchmetrics==1.6.3 + - virtualenv==20.29.3 + - wandb==0.19.8 + - widgetsnbextension==3.6.10 + - yacs==0.1.8 diff --git a/environments/flowdock_environment_docker.yaml b/environments/flowdock_environment_docker.yaml new file mode 100644 index 0000000..10acdca --- /dev/null +++ b/environments/flowdock_environment_docker.yaml @@ -0,0 +1,58 @@ +name: flowdock +channels: + - pyg + - pytorch + - nvidia + - defaults + - conda-forge +dependencies: + - mendeleev=0.20.1=pymin39_ha308f57_3 + - networkx=3.4.2=pyh267e887_2 + - python=3.10.16=he725a3c_1_cpython + - pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0 + - pytorch-cuda=11.8=h7e8668a_6 + - pytorch-mutex=1.0=cuda + - pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118 + - 
import importlib
import os

from typing import Any

# Maps lowercase method identifiers to their display/fork directory titles.
METHOD_TITLE_MAPPING = {
    "diffdock": "DiffDock",
    "flowdock": "FlowDock",
    "neuralplexer": "NeuralPLexer",
}

# Methods whose fork directories follow the standardized layout.
STANDARDIZED_DIR_METHODS = ["diffdock"]


def resolve_omegaconf_variable(variable_path: str) -> Any:
    """Resolve an OmegaConf variable path to its value.

    :param variable_path: Dotted path such as ``"torch.float32"`` or
        ``"pkg.obj.attr"``; everything before the final dot is treated as a
        module path first, then as a ``module.attribute`` pair on fallback.
    :return: The Python object the path points to.
    """
    # split off the final attribute name; everything before it is the module path
    parts = variable_path.rsplit(".", 1)
    module_name = parts[0]

    try:
        module = importlib.import_module(module_name)
        attribute = getattr(module, parts[1])
    except Exception:
        # `module_name` itself may be `parent_module.attribute` (e.g. `torch.float32`
        # hangs off `torch`, not off a `torch.float32` module) — retry one level up
        module = importlib.import_module(".".join(module_name.split(".")[:-1]))
        inner_module = ".".join(module_name.split(".")[-1:])
        attribute = getattr(getattr(module, inner_module), parts[1])

    return attribute


def resolve_dataset_path_dirname(dataset: str) -> str:
    """Resolve the dataset path directory name based on the dataset's name.

    :param dataset: Name of the dataset.
    :return: Directory name for the dataset path.
    """
    # only DockGen uses a cased directory name; all other datasets use their raw name
    return "DockGen" if dataset == "dockgen" else dataset


def resolve_method_input_csv_path(method: str, dataset: str) -> str:
    """Resolve the input CSV path for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :return: The input CSV path for the given method.
    :raises ValueError: If the method is not recognized.
    """
    if method in STANDARDIZED_DIR_METHODS or method in ["flowdock", "neuralplexer"]:
        return os.path.join(
            "forks",
            METHOD_TITLE_MAPPING.get(method, method),
            "inference",
            f"{method}_{dataset}_inputs.csv",
        )
    else:
        raise ValueError(f"Invalid method: {method}")


def resolve_method_title(method: str) -> str:
    """Resolve the method title for a given method.

    :param method: The method name.
    :return: The method title for the given method (the method name itself if unmapped).
    """
    return METHOD_TITLE_MAPPING.get(method, method)


def resolve_method_output_dir(
    method: str,
    dataset: str,
    repeat_index: int,
) -> str:
    """Resolve the output directory for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :param repeat_index: The repeat index for the method.
    :return: The output directory for the given method.
    :raises ValueError: If the method is not recognized.
    """
    if method in STANDARDIZED_DIR_METHODS or method in ["flowdock", "neuralplexer"]:
        return os.path.join(
            "forks",
            METHOD_TITLE_MAPPING.get(method, method),
            "inference",
            # NOTE: flowdock/neuralplexer historically pluralize `outputs`
            f"{method}_{dataset}_output{'s' if method in ['flowdock', 'neuralplexer'] else ''}_{repeat_index}",
        )
    else:
        raise ValueError(f"Invalid method: {method}")


def register_custom_omegaconf_resolvers():
    """Register custom OmegaConf resolvers used throughout the Hydra configs."""
    # imported lazily so the pure path-resolution helpers above remain usable
    # (and importable/testable) in environments without omegaconf installed
    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("resolve_variable", resolve_omegaconf_variable)
    OmegaConf.register_new_resolver("resolve_dataset_path_dirname", resolve_dataset_path_dirname)
    OmegaConf.register_new_resolver("resolve_method_input_csv_path", resolve_method_input_csv_path)
    OmegaConf.register_new_resolver("resolve_method_title", resolve_method_title)
    OmegaConf.register_new_resolver("resolve_method_output_dir", resolve_method_output_dir)
    OmegaConf.register_new_resolver(
        "int_divide", lambda dividend, divisor: int(dividend) // int(divisor)
    )
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from flowdock import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from flowdock.utils import (
    RankedLogger,
    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path, "Please provide a checkpoint path to evaluate!"
    assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"

    # seed all random number generators (pytorch, numpy, python.random) for reproducibility
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="test")

    # normalize the sampling start time before the model consumes it
    with open_dict(cfg):
        start_time = cfg.model.cfg.task.start_time
        cfg.model.cfg.task.start_time = 1.0 if start_time == "auto" else float(start_time)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        if strategy.__dict__.get("mixed_precision") is not None:
            # resolve string dtype names from the config (e.g. `torch.float16`)
            # into real dtype objects on the instantiated strategy
            for dtype_attr in ("param_dtype", "reduce_dtype", "buffer_dtype"):
                requested_dtype = getattr(cfg.strategy.mixed_precision, dtype_attr, None)
                setattr(
                    strategy.mixed_precision,
                    dtype_attr,
                    resolve_omegaconf_variable(requested_dtype)
                    if requested_dtype is not None
                    else None,
                )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only forward `strategy` when one was actually configured
    trainer_kwargs = {"logger": logger, "plugins": plugins}
    if strategy is not None:
        trainer_kwargs["strategy"] = strategy
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, **trainer_kwargs)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    register_custom_omegaconf_resolvers()
    main()
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import os
import threading
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.pytorch import Callback
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info


class EMA(Callback):
    """Implements Exponential Moving Averaging (EMA).

    While training, this callback maintains moving averages of the trained parameters;
    while evaluating, the moving-average copy is swapped in for the trained parameters.
    On save, an additional parameter set prefixed with `ema` is written out.

    Args:
        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
        validate_original_weights: Validate the original weights, as apposed to the EMA weights.
        every_n_steps: Apply EMA every N steps.
        cpu_offload: Offload weights to CPU.
    """

    def __init__(
        self,
        decay: float,
        validate_original_weights: bool = False,
        every_n_steps: int = 1,
        cpu_offload: bool = False,
    ):
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self.decay = decay
        self.validate_original_weights = validate_original_weights
        self.every_n_steps = every_n_steps
        self.cpu_offload = cpu_offload

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Wrap every optimizer on the trainer in an `EMAOptimizer` (skipping any already wrapped)."""
        ema_device = torch.device("cpu") if self.cpu_offload else pl_module.device
        wrapped_optimizers = []
        for optim in trainer.optimizers:
            if isinstance(optim, EMAOptimizer):
                continue
            wrapped_optimizers.append(
                EMAOptimizer(
                    optim,
                    device=ema_device,
                    decay=self.decay,
                    every_n_steps=self.every_n_steps,
                    current_step=trainer.global_step,
                )
            )
        trainer.optimizers = wrapped_optimizers

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights with the EMA weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights back to the original weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights with the EMA weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights back to the original weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
        """EMA weights are evaluated only when configured to AND once EMA has been initialized."""
        return not self.validate_original_weights and self._ema_initialized(trainer)

    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
        """Check whether any of the trainer's optimizers is already an `EMAOptimizer`."""
        return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)

    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
        """Swaps the model weights with the EMA weights on every wrapped optimizer."""
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.switch_main_parameter_weights(saving_ema_model)

    @contextlib.contextmanager
    def save_ema_model(self, trainer: "pl.Trainer"):
        """Saves an EMA copy of the model + EMA optimizer states for resume."""
        self.swap_model_weights(trainer, saving_ema_model=True)
        try:
            yield
        finally:
            self.swap_model_weights(trainer, saving_ema_model=False)

    @contextlib.contextmanager
    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
        """Temporarily flag every wrapped optimizer to save its original (non-EMA) state."""
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.save_original_optimizer_state = True
        try:
            yield
        finally:
            for optimizer in trainer.optimizers:
                optimizer.save_original_optimizer_state = False

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Load the EMA state from the checkpoint if it exists."""
        checkpoint_callback = trainer.checkpoint_callback

        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
        connector = trainer._checkpoint_connector
        # Replace connector._ckpt_path with below to avoid calling into lightning's protected API
        ckpt_path = trainer.ckpt_path

        # nothing to restore unless resuming through an EMA-aware checkpoint callback
        if not (
            ckpt_path
            and checkpoint_callback is not None
            and "EMA" in type(checkpoint_callback).__name__
        ):
            return

        ext = checkpoint_callback.FILE_EXTENSION
        if ckpt_path.endswith(f"-EMA{ext}"):
            rank_zero_info(
                "loading EMA based weights. "
                "The callback will treat the loaded EMA weights as the main weights"
                " and create a new EMA copy when training."
            )
            return

        ema_path = ckpt_path.replace(ext, f"-EMA{ext}")
        if not os.path.exists(ema_path):
            raise MisconfigurationException(
                "Unable to find the associated EMA weights when re-loading, "
                f"training will start with new EMA weights. Expected them to be at: {ema_path}",
            )

        ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu"))
        checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"]
        del ema_state_dict
        rank_zero_info("EMA state has been restored.")
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
    """Fold the current weights into the EMA weights in place (fused foreach ops)."""
    torch._foreach_mul_(ema_model_tuple, decay)
    torch._foreach_add_(
        ema_model_tuple,
        current_model_tuple,
        alpha=(1.0 - decay),
    )


def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
    """Synchronize on the given stream (if any), then apply the EMA update on CPU."""
    if pre_sync_stream is not None:
        pre_sync_stream.synchronize()

    ema_update(ema_model_tuple, current_model_tuple, decay)


class EMAOptimizer(torch.optim.Optimizer):
    r"""EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average
    of parameters registered in the optimizer.

    EMA parameters are automatically updated after every step of the optimizer
    with the following formula:

        ema_weight = decay * ema_weight + (1 - decay) * training_weight

    To access EMA parameters, use ``swap_ema_weights()`` context manager to
    perform a temporary in-place swap of regular parameters with EMA
    parameters.

    Notes:
        - EMAOptimizer is not compatible with APEX AMP O2.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap
        device (torch.device): device for EMA parameters
        decay (float): decay factor
        every_n_steps (int): apply the EMA update every N optimizer steps
        current_step (int): step counter to resume from

    Returns:
        returns an instance of torch.optim.Optimizer that computes EMA of
        parameters

    Example:
        model = Model().to(device)
        opt = torch.optim.Adam(model.parameters())

        opt = EMAOptimizer(opt, device, 0.9999)

        for epoch in range(epochs):
            training_loop(model, opt)

            regular_eval_accuracy = evaluate(model)

            with opt.swap_ema_weights():
                ema_eval_accuracy = evaluate(model)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        current_step: int = 0,
    ):
        # NOTE: deliberately does NOT call super().__init__(); this wrapper
        # delegates all optimizer state to the wrapped instance via __getattr__.
        self.optimizer = optimizer
        self.decay = decay
        self.device = device
        self.current_step = current_step
        self.every_n_steps = every_n_steps
        self.save_original_optimizer_state = False

        self.first_iteration = True
        self.rebuild_ema_params = True
        self.stream = None  # CUDA stream for async EMA copies (None on CPU)
        self.thread = None  # background thread for CPU-side EMA updates

        self.ema_params = ()
        self.in_saving_ema_model_context = False

    def all_parameters(self) -> Iterable[torch.Tensor]:
        """Yield every parameter across all of the wrapped optimizer's param groups."""
        return (p for group in self.param_groups for p in group["params"])

    def step(self, closure=None, grad_scaler=None, **kwargs):
        """Step the wrapped optimizer, then fold the new weights into the EMA copy."""
        self.join()

        if self.first_iteration:
            # lazily create a CUDA stream only when GPU parameters exist
            if any(p.is_cuda for p in self.all_parameters()):
                self.stream = torch.cuda.Stream()
            self.first_iteration = False

        if self.rebuild_ema_params:
            # extend the EMA copies with any parameters added since the last rebuild
            tracked = list(self.all_parameters())
            self.ema_params += tuple(
                copy.deepcopy(p.data.detach()).to(self.device)
                for p in tracked[len(self.ema_params) :]
            )
            self.rebuild_ema_params = False

        amp_capable = getattr(self.optimizer, "_step_supports_amp_scaling", False)
        if amp_capable and grad_scaler is not None:
            loss = self.optimizer.step(closure=closure, grad_scaler=grad_scaler)
        else:
            loss = self.optimizer.step(closure)

        if self._should_update_at_step():
            self.update()
        self.current_step += 1
        return loss

    def _should_update_at_step(self) -> bool:
        """The EMA copy is refreshed only every `every_n_steps` optimizer steps."""
        return self.current_step % self.every_n_steps == 0

    @torch.no_grad()
    def update(self):
        """Push the current parameter values into the EMA buffers (async where possible)."""
        if self.stream is not None:
            self.stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.stream):
            current_model_state = tuple(
                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
            )

            if self.device.type == "cuda":
                ema_update(self.ema_params, current_model_state, self.decay)

        if self.device.type == "cpu":
            # run the CPU-side update off the main thread; join() synchronizes later
            self.thread = threading.Thread(
                target=run_ema_update_cpu,
                args=(self.ema_params, current_model_state, self.decay, self.stream),
            )
            self.thread.start()

    def swap_tensors(self, tensor1, tensor2):
        """Exchange the contents of two tensors in place."""
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
        """Swap every main parameter with its EMA counterpart, in place."""
        self.join()
        self.in_saving_ema_model_context = saving_ema_model
        for param, ema_param in zip(self.all_parameters(), self.ema_params):
            self.swap_tensors(param.data, ema_param)

    @contextlib.contextmanager
    def swap_ema_weights(self, enabled: bool = True):
        r"""A context manager to in-place swap regular parameters with EMA parameters; swaps back
        on exit.

        Args:
            enabled (bool): whether the swap should be performed
        """
        if enabled:
            self.switch_main_parameter_weights()
        try:
            yield
        finally:
            if enabled:
                self.switch_main_parameter_weights()

    def __getattr__(self, name):
        """Delegate unknown attributes (param_groups, state, ...) to the wrapped optimizer."""
        return getattr(self.optimizer, name)

    def join(self):
        """Block until any in-flight EMA update (stream or thread) has finished."""
        if self.stream is not None:
            self.stream.synchronize()

        if self.thread is not None:
            self.thread.join()

    def state_dict(self):
        """Return the wrapped optimizer state together with the EMA copies and counters."""
        self.join()

        if self.save_original_optimizer_state:
            return self.optimizer.state_dict()

        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
        ema_params = (
            list(self.all_parameters()) if self.in_saving_ema_model_context else self.ema_params
        )
        return {
            "opt": self.optimizer.state_dict(),
            "ema": ema_params,
            "current_step": self.current_step,
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
        }

    def load_state_dict(self, state_dict):
        """Restore the wrapped optimizer state, EMA copies, and counters."""
        self.join()

        self.optimizer.load_state_dict(state_dict["opt"])
        self.ema_params = tuple(
            param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
        )
        self.current_step = state_dict["current_step"]
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.rebuild_ema_params = False

    def add_param_group(self, param_group):
        """Forward the new group to the wrapped optimizer and flag an EMA rebuild."""
        self.optimizer.add_param_group(param_group)
        self.rebuild_ema_params = True
class EMAModelCheckpoint(ModelCheckpoint):
    """EMA version of ModelCheckpoint that saves EMA checkpoints as well.

    Whenever a regular checkpoint is written, a companion checkpoint with an
    ``-EMA`` suffix (before the file extension) is saved containing the EMA
    weights; both are removed together as well.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
        """Returns the trainer's EMA callback (the last one, if several), or None."""
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Saves the checkpoint file and the EMA checkpoint file if EMA is active."""
        ema_callback = self._ema_callback(trainer)
        if ema_callback is None:
            super()._save_checkpoint(trainer, filepath)
            return

        # save the main checkpoint with the original (non-EMA) optimizer state
        with ema_callback.save_original_optimizer_state(trainer):
            super()._save_checkpoint(trainer, filepath)

        # save EMA copy of the model as well.
        with ema_callback.save_ema_model(trainer):
            filepath = self._ema_format_filepath(filepath)
            if self.verbose:
                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
            super()._save_checkpoint(trainer, filepath)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Removes the checkpoint file and the EMA checkpoint file if it exists."""
        super()._remove_checkpoint(trainer, filepath)
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # remove EMA copy of the state dict as well.
            filepath = self._ema_format_filepath(filepath)
            super()._remove_checkpoint(trainer, filepath)

    def _ema_format_filepath(self, filepath: str) -> str:
        """Appends '-EMA' to the filepath (before the file extension)."""
        return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        """Checks if any of the checkpoints are EMA checkpoints."""
        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        """Checks if the filepath is an EMA checkpoint."""
        return str(filepath).endswith(f"-EMA{self.FILE_EXTENSION}")

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """Returns all the saved checkpoint paths in the directory.

        Fix: glob by ``FILE_EXTENSION`` rather than the hard-coded ``"*.ckpt"``
        so checkpoints are still discovered when a custom extension is configured
        (consistent with ``_ema_format_filepath``/``_is_ema_filepath``).
        """
        return Path(self.dirpath).rglob(f"*{self.FILE_EXTENSION}")
    def __init__(
        self,
        dim: int,
        pair_dim: int,
        n_blocks: int = 4,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize the ProtFormer model.

        :param dim: Residue (node) feature width.
        :param pair_dim: Pair (edge) feature width.
        :param n_blocks: Number of intertwined graph/triangle attention blocks.
        :param n_heads: Attention head count.
        :param dropout: Dropout rate for the MLPs and transformer layers.
        """
        super().__init__()
        self.dim = dim
        self.pair_dim = pair_dim
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        # diffusion-time embedding (16 Fourier pairs -> 32 features, hence `dim + 32` below)
        self.time_encoding = GaussianFourierEncoding1D(16)
        self.res_in_mlp = GELUMLP(dim + 32, dim, dropout=dropout)
        self.chain_pos_encoding = GaussianFourierEncoding1D(self.pair_dim // 4)

        self.rel_geom_enc = RelativeGeometryEncoding(15, self.pair_dim)
        # 64 binding-site features + 37 atoms x 3 local coords per residue
        self.template_nenc = GELUMLP(64 + 37 * 3, self.dim, n_hidden_feats=128)
        self.template_eenc = RelativeGeometryEncoding(15, self.pair_dim)
        self.template_binding_site_enc = torch.nn.Linear(1, 64, bias=False)
        # input = geometry PE + row/col chain PE + the two endpoint residue reps
        self.pp_edge_embed = GELUMLP(
            pair_dim + self.pair_dim // 4 * 2 + dim * 2,
            self.pair_dim,
            n_hidden_feats=dim,
            dropout=dropout,
        )

        self.graph_stacks = torch.nn.ModuleList(
            [
                TransformerLayer(
                    dim,
                    n_heads,
                    head_dim=pair_dim // n_heads,
                    edge_channels=pair_dim,
                    edge_update=True,
                    dropout=dropout,
                )
                for _ in range(self.n_blocks)
            ]
        )
        # cross-attention between randomly-sampled (ab) and grid (AB) pair reps
        self.ABab_mha = TransformerLayer(
            pair_dim,
            n_heads,
            bidirectional=True,
        )

        self.triangle_stacks = torch.nn.ModuleList(
            [
                BiDirectionalTriangleAttention(pair_dim, pair_dim // n_heads, n_heads)
                for _ in range(self.n_blocks)
            ]
        )
        # relation tuples: (name, src index key, dst index key, src node type, dst node type)
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
        ]

    def compute_chain_pe(
        self,
        residue_index,
        res_chain_index,
        src_rid,
        dst_rid,
    ):
        """Compute chain positional encoding for a pair of residues.

        Encodes the (scaled) sequence-index displacement between the two residues of
        each edge; the division by |disp/8|+1 damps the encoding for distant pairs.
        """
        chain_disp = residue_index[src_rid] - residue_index[dst_rid]
        chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
            chain_disp.div(8).abs().add(1).unsqueeze(-1)
        )
        # Mask cross-chain entries (sequence displacement is meaningless across chains)
        chain_mask = res_chain_index[src_rid] == res_chain_index[dst_rid]
        chain_rope = chain_rope * chain_mask[..., None]
        return chain_rope

    def compute_chain_pair_pe(
        self,
        residue_index,
        res_chain_index,
        AB_broadcasted_rid,
        ab_rid,
        AB_broadcasted_cid,
        ab_cid,
    ):
        """Compute chain positional encoding for a pair of residues.

        Same scheme as `compute_chain_pe` but applied jointly to the row and column
        displacements of an (AB, ab) pair-of-pairs; results are concatenated.
        """
        chain_disp_row = residue_index[AB_broadcasted_rid] - residue_index[ab_rid]
        chain_disp_col = residue_index[AB_broadcasted_cid] - residue_index[ab_cid]
        chain_disp = torch.stack([chain_disp_row, chain_disp_col], dim=-1)
        chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
            chain_disp.div(8).abs().add(1).unsqueeze(-1)
        )
        # Mask cross-chain entries
        chain_mask_row = res_chain_index[AB_broadcasted_rid] == res_chain_index[ab_rid]
        chain_mask_col = res_chain_index[AB_broadcasted_cid] == res_chain_index[ab_cid]
        chain_mask = torch.stack([chain_mask_row, chain_mask_col], dim=-1)
        chain_rope = (chain_rope * chain_mask[..., None]).flatten(-2, -1)
        return chain_rope

    def eval_protein_template_encodings(self, batch, edge_idx, use_plddt=False):
        """Evaluate template encodings for protein residues.

        Builds node features from template local (internal) coordinates plus an
        optional binding-site encoding, and edge features from template backbone
        geometry. During training, side-chain coordinates and per-residue alignment
        entries are randomly masked for augmentation. Runs under no_grad.
        """
        with torch.no_grad():
            # assumes the first 3 atoms per residue are the backbone (N, CA, C) — TODO confirm
            template_bb_coords = batch["features"]["apo_res_atom_positions"][:, :3]
            template_bb_frames = get_frame_matrix(
                template_bb_coords[:, 0, :],
                template_bb_coords[:, 1, :],
                template_bb_coords[:, 2, :],
            )
            # Add template local representations & lddt
            template_local_coords = cartesian_to_internal(
                batch["features"]["apo_res_atom_positions"],
                template_bb_frames.unsqueeze(1),
            )
            template_local_coords[~batch["features"]["apo_res_atom_mask"].bool()] = 0
            # if use_plddt:
            #     template_plddt_enc = F.one_hot(
            #         torch.bucketize(
            #             batch["features"]["apo_pLDDT"],
            #             torch.linspace(0, 1, 65, device=template_bb_coords.device)[:-1],
            #             right=True,
            #         )
            #         - 1,
            #         num_classes=64,
            #     )
            # else:
            #     template_plddt_enc = torch.zeros(
            #         template_local_coords.shape[0], 64, device=template_bb_coords.device
            #     )
            if self.training:
                # randomly drop side-chain coordinates half the time (data augmentation)
                use_sidechain_coords = random.randint(0, 1)  # nosec
                template_local_coords = template_local_coords * use_sidechain_coords
                # use_plddt_input = random.randint(0, 1)
                # template_plddt_enc = template_plddt_enc * use_plddt_input
            if "binding_site_mask" in batch["features"].keys():
                # Externally-specified binding residue list
                binding_site_enc = self.template_binding_site_enc(
                    batch["features"]["binding_site_mask"][:, None].float()
                )
            else:
                binding_site_enc = torch.zeros(
                    template_local_coords.shape[0], 64, device=template_bb_coords.device
                )
            template_nfeat = self.template_nenc(
                torch.cat([template_local_coords.flatten(-2, -1), binding_site_enc], dim=-1)
            )
            template_efeat = self.template_eenc(template_bb_frames, edge_idx)
            template_alignment_mask = batch["features"]["apo_res_alignment_mask"].float()
            if self.training:
                # template_alignment_mask = template_alignment_mask * use_template
                # keep 90% or 100% of aligned residues (rate itself is randomized)
                nomasking_rate = random.randint(9, 10) / 10  # nosec
                template_alignment_mask = template_alignment_mask * (
                    torch.rand_like(template_alignment_mask) < nomasking_rate
                )
            # zero out features for unaligned (or masked-out) residues and their edges
            template_nfeat = template_nfeat * template_alignment_mask.unsqueeze(-1)
            template_efeat = (
                template_efeat
                * template_alignment_mask[edge_idx[0]].unsqueeze(-1)
                * template_alignment_mask[edge_idx[1]].unsqueeze(-1)
            )
        return template_nfeat, template_efeat

    def forward(self, batch, **kwargs):
        """Forward pass of the ProtFormer model."""
        return self.forward_prot_sample(batch, **kwargs)

    def forward_prot_sample(
        self,
        batch,
        embed_coords=True,
        in_attr_suffix="",
        out_attr_suffix="",
        use_template=False,
        use_plddt=False,
        **kwargs,
    ):
        """Forward pass of the ProtFormer model for a single protein sample.

        Mutates `batch` in place: writes refined residue/pair representations under
        `features` (suffixed with `out_attr_suffix`) and patch-grid indexers under
        `indexer`, then returns the same `batch` dict.

        :param batch: Nested dict with `features`, `indexer`, and `metadata`.
        :param embed_coords: If False, zero out time and geometry encodings so the
            output is coordinate-independent.
        :param use_template: If True (and template features are present), add
            template node/edge encodings.
        :param use_plddt: Forwarded to template encoding (currently inert; the
            pLDDT branch is commented out there).
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device

        time_encoding = self.time_encoding(features["timestep_encoding_prot"])
        if not embed_coords:
            time_encoding = torch.zeros_like(time_encoding)

        # input residue embedding + time encoding, with a residual connection
        residue_rep = (
            self.res_in_mlp(
                torch.cat(
                    [
                        features["res_embedding_in"],
                        time_encoding,
                    ],
                    dim=1,
                )
            )
            + features["res_embedding_in"]
        )
        batch_size = metadata["num_structid"]

        # Prepare indexers
        # Use max to ensure segmentation faults are 100% invoked
        # in case there are any bad indices
        # NOTE(review): bare expression — evaluated purely as an eager bounds check
        max(metadata["num_a_per_sample"])
        n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]

        indexer["gather_idx_pid_b"] = indexer["gather_idx_pid_a"]
        # Evaluate gather_idx_AB_a and gather_idx_AB_b
        # Assign a to rows and b to columns
        # Simple broadcasting for single-structure batches
        indexer["gather_idx_AB_a"] = (
            indexer["gather_idx_pid_a"]
            .view(batch_size, n_protein_patches)[:, :, None]
            .expand(-1, -1, n_protein_patches)
            .contiguous()
            .flatten()
        )
        indexer["gather_idx_AB_b"] = (
            indexer["gather_idx_pid_b"]
            .view(batch_size, n_protein_patches)[:, None, :]
            .expand(-1, n_protein_patches, -1)
            .contiguous()
            .flatten()
        )

        # Handle all batch offsets here
        graph_batcher = make_multi_relation_graph_batcher(self.graph_relations, indexer, metadata)
        merged_edge_idx = graph_batcher.collate_idx_list(indexer)

        # backbone frames from the first three atom positions (N, CA, C order assumed)
        input_protein_coords_padded = features["input_protein_coords"]
        backbone_frames = get_frame_matrix(
            input_protein_coords_padded[:, 0, :],
            input_protein_coords_padded[:, 1, :],
            input_protein_coords_padded[:, 2, :],
        )
        batch["features"]["backbone_frames"] = backbone_frames
        # Adding geometrical info to pair representations

        chain_pe = self.compute_chain_pe(
            features["residue_index"],
            features["res_chain_id"],
            merged_edge_idx[0],
            merged_edge_idx[1],
        )
        geometry_pe = self.rel_geom_enc(backbone_frames, merged_edge_idx)
        if not embed_coords:
            geometry_pe = torch.zeros_like(geometry_pe)
        merged_edge_reps = self.pp_edge_embed(
            torch.cat(
                [
                    geometry_pe,
                    chain_pe,
                    residue_rep[merged_edge_idx[0]],
                    residue_rep[merged_edge_idx[1]],
                ],
                dim=-1,
            )
        )
        if use_template and "apo_res_atom_positions" in features.keys():
            (
                template_res_encodings,
                template_geom_encodings,
            ) = self.eval_protein_template_encodings(batch, merged_edge_idx, use_plddt=use_plddt)
            residue_rep = residue_rep + template_res_encodings
            merged_edge_reps = merged_edge_reps + template_geom_encodings
        edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        node_reps = {"prot_res": residue_rep}

        gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
        # Pointer: AB->AB, ab->AB
        # flat grid index: structure offset + row patch * n_patches + column patch
        gather_idx_ab_AB = (
            indexer["gather_idx_ab_structid"] * n_protein_patches**2
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_a"]]
            * n_protein_patches
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_b"]]
        )

        # Intertwine graph iterations and triangle iterations
        for block_id in range(self.n_blocks):
            # Communicate between atomistic and patch resolutions
            # Up-sampling for interface edge embeddings
            rec_pair_rep = edge_reps["residue_to_residue"]
            AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
            # Upper-left block: intra-window visual-attention
            # Cross-attention between random and grid edges
            rec_pair_rep, AB_grid_attr_flat = self.ABab_mha(
                rec_pair_rep,
                AB_grid_attr_flat,
                (
                    torch.arange(metadata["num_ab"], device=device),
                    gather_idx_ab_AB,
                ),
            )
            AB_grid_attr = AB_grid_attr_flat.view(
                batch_size,
                n_protein_patches,
                n_protein_patches,
                self.pair_dim,
            )

            # Inter-patch triangle attentions, refining intermolecular edges
            _, AB_grid_attr = self.triangle_stacks[block_id](
                AB_grid_attr,
                AB_grid_attr,
                AB_grid_attr.unsqueeze(-4),
            )

            # Transfer grid-formatted representations to edges
            edge_reps["residue_to_residue"] = rec_pair_rep
            edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)
            merged_node_reps = graph_batcher.collate_node_attr(node_reps)
            merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)

            # Graph transformer iteration
            _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
                merged_node_reps,
                merged_node_reps,
                merged_edge_idx,
                merged_edge_reps,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        # write refined representations back onto the batch (in place)
        batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
        batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
        batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
            "sampled_residue_to_sampled_residue"
        ]
        batch["indexer"]["gather_idx_AB_a"] = indexer["gather_idx_AB_a"]
        batch["indexer"]["gather_idx_AB_b"] = indexer["gather_idx_AB_b"]
        batch["indexer"]["gather_idx_ab_AB"] = gather_idx_ab_AB
        return batch
"sampled_lig_triplet_to_sampled_residue", + "gather_idx_AJ_J", + "gather_idx_AJ_a", + "lig_trp", + "prot_res", + ), + ( + "residue_to_sampled_lig_triplet", + "gather_idx_aJ_a", + "gather_idx_aJ_J", + "prot_res", + "lig_trp", + ), + ( + "sampled_lig_triplet_to_residue", + "gather_idx_aJ_J", + "gather_idx_aJ_a", + "lig_trp", + "prot_res", + ), + ( + "sampled_lig_triplet_to_sampled_lig_triplet", + "gather_idx_IJ_I", + "gather_idx_IJ_J", + "lig_trp", + "lig_trp", + ), + ] + + def forward( + self, + batch, + observed_block_contacts=None, + in_attr_suffix="", + out_attr_suffix="", + ): + """Forward pass of the BindingFormer model.""" + features = batch["features"] + indexer = batch["indexer"] + metadata = batch["metadata"] + device = features["res_type"].device + # Synchronize with a language model + residue_rep = features[f"rec_res_attr{in_attr_suffix}"] + rec_pair_rep = features[f"res_res_pair_attr{in_attr_suffix}"] + # Inherit the last-layer pair representations from protein encoder + AB_grid_attr_flat = features[f"res_res_grid_attr_flat{in_attr_suffix}"] + + # Prepare indexers + batch_size = metadata["num_structid"] + n_a_per_sample = max(metadata["num_a_per_sample"]) + n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"] + + if not batch["misc"]["protein_only"]: + n_ligand_patches = max(metadata["num_I_per_sample"]) + max(metadata["num_molid_per_sample"]) + lig_frame_rep = features[f"lig_trp_attr{in_attr_suffix}"] + UI_grid_attr = features["lig_af_grid_attr_projected"] + IJ_grid_attr = (UI_grid_attr + UI_grid_attr.transpose(1, 2)) / 2 + + aJ_grid_attr = self.pl_edge_embed( + torch.cat( + [ + residue_rep.view(batch_size, n_a_per_sample, self.dim)[:, :, None].expand( + -1, -1, n_ligand_patches, -1 + ), + lig_frame_rep.view(batch_size, n_ligand_patches, self.dim)[ + :, None, : + ].expand(-1, n_a_per_sample, -1, -1), + ], + dim=-1, + ) + ) + AJ_grid_attr = IJ_grid_attr.new_zeros( + batch_size, n_protein_patches, n_ligand_patches, self.pair_dim + ) + 
gather_idx_I_I = torch.arange( + batch_size * n_ligand_patches, device=AJ_grid_attr.device + ) + gather_idx_a_a = torch.arange(batch_size * n_a_per_sample, device=AJ_grid_attr.device) + # Note: off-diagonal (AJ) blocks are zero-initialized in the prior stack + indexer["gather_idx_IJ_I"] = ( + gather_idx_I_I.view(batch_size, n_ligand_patches)[:, :, None] + .expand(-1, -1, n_ligand_patches) + .contiguous() + .flatten() + ) + indexer["gather_idx_IJ_J"] = ( + gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :] + .expand(-1, n_ligand_patches, -1) + .contiguous() + .flatten() + ) + indexer["gather_idx_AJ_a"] = ( + indexer["gather_idx_pid_a"] + .view(batch_size, n_protein_patches)[:, :, None] + .expand(-1, -1, n_ligand_patches) + .contiguous() + .flatten() + ) + indexer["gather_idx_AJ_J"] = ( + gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :] + .expand(-1, n_protein_patches, -1) + .contiguous() + .flatten() + ) + indexer["gather_idx_aJ_a"] = ( + gather_idx_a_a.view(batch_size, n_a_per_sample)[:, :, None] + .expand(-1, -1, n_ligand_patches) + .contiguous() + .flatten() + ) + indexer["gather_idx_aJ_J"] = ( + gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :] + .expand(-1, n_a_per_sample, -1) + .contiguous() + .flatten() + ) + batch["indexer"] = indexer + + if observed_block_contacts is not None: + # Generative feedback from block one-hot sampling + # AJ_grid_attr = ( + # AJ_grid_attr + # + observed_block_contacts.transpose(1, 2) + # .contiguous() + # .flatten(0, 1)[indexer["gather_idx_I_molid"]] + # .view(batch_size, n_ligand_patches, n_protein_patches, -1) + # .transpose(1, 2) + # .contiguous() + # ) + AJ_grid_attr = AJ_grid_attr + observed_block_contacts + + graph_batcher = make_multi_relation_graph_batcher( + self.graph_relations, indexer, metadata + ) + merged_edge_idx = graph_batcher.collate_idx_list(indexer) + node_reps = { + "prot_res": residue_rep, + "lig_trp": lig_frame_rep, + } + edge_reps = { + "residue_to_residue": rec_pair_rep, 
+ "sampled_residue_to_sampled_residue": AB_grid_attr_flat, + "sampled_lig_triplet_to_sampled_residue": AJ_grid_attr.flatten(0, 2), + "sampled_residue_to_sampled_lig_triplet": AJ_grid_attr.flatten(0, 2), + "sampled_lig_triplet_to_residue": aJ_grid_attr.flatten(0, 2), + "residue_to_sampled_lig_triplet": aJ_grid_attr.flatten(0, 2), + "sampled_lig_triplet_to_sampled_lig_triplet": IJ_grid_attr.flatten(0, 2), + } + edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device) + else: + graph_batcher = make_multi_relation_graph_batcher( + self.graph_relations[:2], indexer, metadata + ) + merged_edge_idx = graph_batcher.collate_idx_list(indexer) + + node_reps = { + "prot_res": residue_rep, + } + edge_reps = { + "residue_to_residue": rec_pair_rep, + "sampled_residue_to_sampled_residue": AB_grid_attr_flat, + } + edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device) + + # Intertwine graph iterations and triangle iterations + gather_idx_res_protpatch = indexer["gather_idx_a_pid"] + for block_id in range(self.n_blocks): + # Communicate between atomistic and patch resolutions + # Up-sampling for interface edge embeddings + rec_pair_rep = edge_reps["residue_to_residue"] + AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"] + AB_grid_attr = AB_grid_attr_flat.view( + batch_size, + n_protein_patches, + n_protein_patches, + self.pair_dim, + ) + + if not batch["misc"]["protein_only"]: + # Symmetrize off-diagonal blocks + AJ_grid_attr_flat_ = ( + edge_reps["sampled_residue_to_sampled_lig_triplet"] + + edge_reps["sampled_lig_triplet_to_sampled_residue"] + ) / 2 + AJ_grid_attr = AJ_grid_attr_flat_.contiguous().view( + batch_size, n_protein_patches, n_ligand_patches, -1 + ) + aJ_grid_attr_flat_ = ( + edge_reps["residue_to_sampled_lig_triplet"] + + edge_reps["sampled_lig_triplet_to_residue"] + ) / 2 + aJ_grid_attr = aJ_grid_attr_flat_.contiguous().view( + batch_size, n_a_per_sample, n_ligand_patches, -1 + ) + IJ_grid_attr = ( + 
edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"] + .contiguous() + .view(batch_size, n_ligand_patches, n_ligand_patches, -1) + ) + AJ_grid_attr_temp_, aJ_grid_attr_temp_ = self.AaJ_mha( + AJ_grid_attr.flatten(0, 1), + aJ_grid_attr.flatten(0, 1), + ( + gather_idx_res_protpatch, + torch.arange(gather_idx_res_protpatch.shape[0], device=device), + ), + ) + AJ_grid_attr = AJ_grid_attr_temp_.contiguous().view( + batch_size, n_protein_patches, n_ligand_patches, -1 + ) + aJ_grid_attr = aJ_grid_attr_temp_.contiguous().view( + batch_size, n_a_per_sample, n_ligand_patches, -1 + ) + merged_grid_rep = torch.cat( + [ + torch.cat([AB_grid_attr, AJ_grid_attr], dim=2), + torch.cat([AJ_grid_attr.transpose(1, 2), IJ_grid_attr], dim=2), + ], + dim=1, + ) + else: + merged_grid_rep = AB_grid_attr + + # Inter-patch triangle attentions + _, merged_grid_rep = self.triangle_stacks[block_id]( + merged_grid_rep, + merged_grid_rep, + merged_grid_rep.unsqueeze(-4), + ) + + # Dis-assemble the grid representation + AB_grid_attr = merged_grid_rep[:, :n_protein_patches, :n_protein_patches] + # Transfer grid-formatted representations to edges + edge_reps["residue_to_residue"] = rec_pair_rep + edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2) + + if not batch["misc"]["protein_only"]: + AJ_grid_attr = merged_grid_rep[ + :, :n_protein_patches, n_protein_patches: + ].contiguous() + IJ_grid_attr = merged_grid_rep[ + :, n_protein_patches:, n_protein_patches: + ].contiguous() + + edge_reps["sampled_residue_to_sampled_lig_triplet"] = AJ_grid_attr.flatten(0, 2) + edge_reps["sampled_lig_triplet_to_sampled_residue"] = AJ_grid_attr.flatten(0, 2) + edge_reps["residue_to_sampled_lig_triplet"] = aJ_grid_attr.flatten(0, 2) + edge_reps["sampled_lig_triplet_to_residue"] = aJ_grid_attr.flatten(0, 2) + edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"] = IJ_grid_attr.flatten( + 0, 2 + ) + merged_node_reps = graph_batcher.collate_node_attr(node_reps) + merged_edge_reps = 
graph_batcher.collate_edge_attr(edge_reps) + + # Graph transformer iteration + _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id]( + merged_node_reps, + merged_node_reps, + merged_edge_idx, + merged_edge_reps, + ) + node_reps = graph_batcher.offload_node_attr(merged_node_reps) + edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps) + + batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"] + batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"] + batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[ + "sampled_residue_to_sampled_residue" + ] + if not batch["misc"]["protein_only"]: + batch["features"][f"lig_trp_attr{out_attr_suffix}"] = node_reps["lig_trp"] + batch["features"][f"res_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[ + "sampled_residue_to_sampled_lig_triplet" + ] + batch["features"][f"res_trp_pair_attr_flat{out_attr_suffix}"] = edge_reps[ + "residue_to_sampled_lig_triplet" + ] + batch["features"][f"trp_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[ + "sampled_lig_triplet_to_sampled_lig_triplet" + ] + batch["metadata"]["n_lig_patches_per_sample"] = n_ligand_patches + return batch + + +def resolve_protein_encoder( + protein_model_cfg: DictConfig, + task_cfg: DictConfig, + state_dict: Optional[STATE_DICT] = None, +) -> Tuple[torch.nn.Module, torch.nn.Module]: + """Instantiates a ProtFormer model for protein encoding. + + :param protein_model_cfg: Protein model configuration. + :param task_cfg: Task configuration. + :param state_dict: Optional (potentially-pretrained) state dictionary. + :return: Protein encoder model and residue input projector. 
+ """ + node_dim = protein_model_cfg.residue_dim + model = ProtFormer( + node_dim, + protein_model_cfg.pair_dim, + n_heads=protein_model_cfg.n_heads, + n_blocks=protein_model_cfg.n_encoder_stacks, + dropout=task_cfg.dropout, + ) + if protein_model_cfg.use_esm_embedding: + # protein sequence language model + res_in_projector = torch.nn.Linear(protein_model_cfg.plm_embed_dim, node_dim, bias=False) + else: + # one-hot amino acid types + res_in_projector = torch.nn.Linear( + protein_model_cfg.n_aa_types, + node_dim, + bias=False, + ) + if protein_model_cfg.from_pretrained and state_dict is not None: + try: + # NOTE: we must avoid enforcing strict key matching + # due to the (new) weights `template_binding_site_enc.weight` + model.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("protein_encoder") + }, + strict=False, + ) + log.info("Successfully loaded pretrained protein encoder weights.") + except Exception as e: + log.warning(f"Skipping loading of pretrained protein encoder weights due to: {e}.") + + try: + res_in_projector.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith( + "plm_adapter" + if protein_model_cfg.use_esm_embedding + else "res_in_projector" + ) + } + ) + log.info("Successfully loaded pretrained protein input projector weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained protein input projector weights due to: {e}." + ) + return model, res_in_projector + + +def resolve_pl_contact_stack( + protein_model_cfg: DictConfig, + ligand_model_cfg: DictConfig, + contact_cfg: DictConfig, + task_cfg: DictConfig, + state_dict: Optional[STATE_DICT] = None, +) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.nn.Module]: + """Instantiates a BindingFormer model for protein-ligand contact prediction. + + :param protein_model_cfg: Protein model configuration. + :param ligand_model_cfg: Ligand model configuration. 
+ :param contact_cfg: Contact prediction configuration. + :param task_cfg: Task configuration. + :param state_dict: Optional (potentially-pretrained) state dictionary. + :return: Protein-ligand contact prediction model, contact code embedding, distance bins, and + distogram head. + """ + pl_contact_stack = BindingFormer( + protein_model_cfg.residue_dim, + protein_model_cfg.pair_dim, + n_heads=protein_model_cfg.n_heads, + n_blocks=contact_cfg.n_stacks, + n_ligand_patches=ligand_model_cfg.n_patches, + dropout=contact_cfg.dropout if contact_cfg.get("dropout") else task_cfg.dropout, + ) + contact_code_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim) + # Distogram heads + dist_bins = torch.nn.Parameter(torch.linspace(2, 22, 32), requires_grad=False) + dgram_head = GELUMLP( + protein_model_cfg.pair_dim, + 32, + n_hidden_feats=protein_model_cfg.pair_dim, + zero_init=True, + ) + if contact_cfg.from_pretrained and state_dict is not None: + try: + # NOTE: we must avoid enforcing strict key matching + # due to the (new) weights `template_binding_site_enc.weight` + pl_contact_stack.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("pl_contact_stack") + }, + strict=False, + ) + log.info("Successfully loaded pretrained protein-ligand contact prediction weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained protein-ligand contact prediction weights due to: {e}." + ) + + try: + contact_code_embed.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("contact_code_embed") + } + ) + log.info("Successfully loaded pretrained contact code embedding weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained contact code embedding weights due to: {e}." 
+ ) + + try: + dgram_head.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("dgram_head") + } + ) + log.info("Successfully loaded pretrained distogram head weights.") + except Exception as e: + log.warning(f"Skipping loading of pretrained distogram head weights due to: {e}.") + return pl_contact_stack, contact_code_embed, dist_bins, dgram_head diff --git a/flowdock/models/components/embedding.py b/flowdock/models/components/embedding.py new file mode 100644 index 0000000..016ec3c --- /dev/null +++ b/flowdock/models/components/embedding.py @@ -0,0 +1,105 @@ +# Adapted from: https://github.com/zrqiao/NeuralPLexer + +import math + +import rootutils +import torch +from beartype.typing import Tuple + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils.frame_utils import RigidTransform + + +class GaussianFourierEncoding1D(torch.nn.Module): + """Gaussian Fourier Encoding for 1D data.""" + + def __init__( + self, + n_basis: int, + eps: float = 1e-2, + ): + """Initialize Gaussian Fourier Encoding.""" + super().__init__() + self.eps = eps + self.fourier_freqs = torch.nn.Parameter(torch.randn(n_basis) * math.pi) + + def forward( + self, + x: torch.Tensor, + ): + """Forward pass of Gaussian Fourier Encoding.""" + encodings = torch.cat( + [ + torch.sin(self.fourier_freqs.mul(x)), + torch.cos(self.fourier_freqs.mul(x)), + ], + dim=-1, + ) + return encodings + + +class GaussianRBFEncoding1D(torch.nn.Module): + """Gaussian RBF Encoding for 1D data.""" + + def __init__( + self, + n_basis: int, + x_max: float, + sigma: float = 1.0, + ): + """Initialize Gaussian RBF Encoding.""" + super().__init__() + self.sigma = sigma + self.rbf_centers = torch.nn.Parameter( + torch.linspace(0, x_max, n_basis), requires_grad=False + ) + + def forward( + self, + x: torch.Tensor, + ): + """Forward pass of Gaussian RBF Encoding.""" + encodings = torch.exp(-((x.unsqueeze(-1) - 
self.rbf_centers).div(self.sigma).square())) + return encodings + + +class RelativeGeometryEncoding(torch.nn.Module): + "Compute radial basis functions and iterresidue/pseudoresidue orientations." + + def __init__(self, n_basis: int, out_dim: int, d_max: float = 20.0): + """Initialize RelativeGeometryEncoding.""" + super().__init__() + self.rbf_encoding = GaussianRBFEncoding1D(n_basis, d_max) + self.rel_geom_projector = torch.nn.Linear(n_basis + 15, out_dim, bias=False) + + def forward(self, frames: RigidTransform, merged_edge_idx: Tuple[torch.Tensor, torch.Tensor]): + """Forward pass of RelativeGeometryEncoding.""" + frame_t, frame_R = frames.t, frames.R + pair_dists = torch.norm( + frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]], + dim=-1, + ) + pair_directions_l = torch.matmul( + (frame_t[merged_edge_idx[1]] - frame_t[merged_edge_idx[0]]).unsqueeze(-2), + frame_R[merged_edge_idx[0]], + ).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1) + pair_directions_r = torch.matmul( + (frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]]).unsqueeze(-2), + frame_R[merged_edge_idx[1]], + ).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1) + pair_orientations = torch.matmul( + frame_R.transpose(-2, -1).contiguous()[merged_edge_idx[0]], + frame_R[merged_edge_idx[1]], + ) + return self.rel_geom_projector( + torch.cat( + [ + self.rbf_encoding(pair_dists), + pair_directions_l, + pair_directions_r, + pair_orientations.flatten(-2, -1), + ], + dim=-1, + ) + ) diff --git a/flowdock/models/components/esdm.py b/flowdock/models/components/esdm.py new file mode 100644 index 0000000..ea14b00 --- /dev/null +++ b/flowdock/models/components/esdm.py @@ -0,0 +1,884 @@ +# Adapted from: https://github.com/zrqiao/NeuralPLexer + +import rootutils +import torch +from beartype.typing import Any, Dict, Optional, Tuple +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from 
flowdock.models.components.embedding import ( + GaussianFourierEncoding1D, + RelativeGeometryEncoding, +) +from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher +from flowdock.models.components.modules import PointSetAttention +from flowdock.utils import RankedLogger +from flowdock.utils.frame_utils import RigidTransform, get_frame_matrix +from flowdock.utils.model_utils import GELUMLP, AveragePooling, SumPooling, segment_mean + +STATE_DICT = Dict[str, Any] + +log = RankedLogger(__name__, rank_zero_only=True) + + +class LocalUpdateUsingReferenceRotations(torch.nn.Module): + """Update local geometric representations using reference rotations.""" + + def __init__( + self, + fiber_dim: int, + extra_feat_dim: int = 0, + eps: float = 1e-4, + dropout: float = 0.0, + hidden_dim: Optional[int] = None, + zero_init: bool = False, + ): + """Initialize the LocalUpdateUsingReferenceRotations module.""" + super().__init__() + self.dim = fiber_dim * 5 + extra_feat_dim + self.fiber_dim = fiber_dim + self.mlp = GELUMLP( + self.dim, + fiber_dim * 4, + dropout=dropout, + zero_init=zero_init, + n_hidden_feats=hidden_dim, + ) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + rotation_mats: torch.Tensor, + extra_feats=None, + ): + """Forward pass of the LocalUpdateUsingReferenceRotations module.""" + # Vector norms are evaluated without applying rigid transform + vecx_local = torch.matmul( + x[:, 1:].transpose(-2, -1), + rotation_mats, + ) + x1_local = torch.cat( + [ + x[:, 0], + vecx_local.flatten(-2, -1), + x[:, 1:].square().sum(dim=-2).add(self.eps).sqrt(), + ], + dim=-1, + ) + if extra_feats is not None: + x1_local = torch.cat([x1_local, extra_feats], dim=-1) + x1_local = self.mlp(x1_local).view(-1, 4, self.fiber_dim) + vecx1_out = torch.matmul( + rotation_mats, + x1_local[:, 1:], + ) + x1_out = torch.cat([x1_local[:, :1], vecx1_out], dim=-2) + return x1_out + + +class LocalUpdateUsingChannelWiseGating(torch.nn.Module): + """Update local 
geometric representations using channel-wise gating.""" + + def __init__( + self, + fiber_dim: int, + eps: float = 1e-4, + dropout: float = 0.0, + hidden_dim: Optional[int] = None, + zero_init: bool = False, + ): + """Initialize the LocalUpdateUsingChannelWiseGating module.""" + super().__init__() + self.dim = fiber_dim * 2 + self.fiber_dim = fiber_dim + self.mlp = GELUMLP( + self.dim, + self.dim, + dropout=dropout, + n_hidden_feats=hidden_dim, + zero_init=zero_init, + ) + self.gate = torch.nn.Sigmoid() + self.lin_out = torch.nn.Linear(fiber_dim, fiber_dim, bias=False) + if zero_init: + self.lin_out.weight.data.fill_(0.0) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + """Forward pass of the LocalUpdateUsingChannelWiseGating module.""" + x1 = torch.cat( + [ + x[:, 0], + x[:, 1:].square().sum(dim=-2).add(self.eps).sqrt(), + ], + dim=-1, + ) + x1 = self.mlp(x1) + # Gated nonlinear operation on l=1 representations + x1_scalar, x1_gatein = torch.split(x1, self.fiber_dim, dim=-1) + x1_gate = self.gate(x1_gatein).unsqueeze(-2) + vecx1_out = self.lin_out(x[:, 1:]).mul(x1_gate) + x1_out = torch.cat([x1_scalar.unsqueeze(-2), vecx1_out], dim=-2) + return x1_out + + +class EquivariantTransformerBlock(torch.nn.Module): + """Equivariant Transformer Block module.""" + + def __init__( + self, + fiber_dim: int, + heads: int = 8, + point_dim: int = 4, + eps: float = 1e-4, + edge_dim: Optional[int] = None, + target_frames: bool = False, + edge_update: bool = False, + dropout: float = 0.0, + ): + """Initialize the EquivariantTransformerBlock module.""" + super().__init__() + self.attn_conv = PointSetAttention( + fiber_dim, + heads=heads, + point_dim=point_dim, + edge_dim=edge_dim, + edge_update=edge_update, + ) + self.fiber_dim = fiber_dim + self.target_frames = target_frames + self.eps = eps + self.edge_update = edge_update + if target_frames: + self.local_update = LocalUpdateUsingReferenceRotations( + fiber_dim, eps=eps, dropout=dropout, zero_init=True + ) + 
else: + self.local_update = LocalUpdateUsingChannelWiseGating( + fiber_dim, eps=eps, dropout=dropout, zero_init=True + ) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.LongTensor, + t: torch.Tensor, + R: torch.Tensor = None, + x_edge: torch.Tensor = None, + ): + """Forward pass of the EquivariantTransformerBlock module.""" + if self.edge_update: + xout, edge_out = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge) + x_edge = x_edge + edge_out + else: + xout = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge) + x = x + xout + if self.target_frames: + x = self.local_update(x, R) + x + else: + x = self.local_update(x) + x + return x, x_edge + + +class EquivariantStructureDenoisingModule(torch.nn.Module): + """Equivariant Structure Denoising Module.""" + + def __init__( + self, + fiber_dim: int, + input_dim: int, + input_pair_dim: int, + hidden_dim: int = 1024, + n_stacks: int = 4, + n_heads: int = 8, + dropout: float = 0.0, + ): + """Initialize the EquivariantStructureDenoisingModule module.""" + super().__init__() + self.input_dim = input_dim + self.input_pair_dim = input_pair_dim + self.fiber_dim = fiber_dim + self.protatm_padding_dim = 37 + self.n_blocks = n_stacks + self.input_node_projector = torch.nn.Linear(input_dim, fiber_dim, bias=False) + self.input_node_vec_projector = torch.nn.Linear(input_dim, fiber_dim * 3, bias=False) + self.input_pair_projector = torch.nn.Linear(input_pair_dim, fiber_dim, bias=False) + # Inherit the residue representations + self.atm_embed = GELUMLP(input_dim + 32, fiber_dim) + self.ipa_modules = torch.nn.ModuleList( + [ + EquivariantTransformerBlock( + fiber_dim, + heads=n_heads, + point_dim=fiber_dim // (n_heads * 2), + edge_dim=fiber_dim, + target_frames=True, + edge_update=True, + dropout=dropout, + ) + for _ in range(n_stacks) + ] + ) + self.res_adapters = torch.nn.ModuleList( + [ + LocalUpdateUsingReferenceRotations( + fiber_dim, + extra_feat_dim=input_dim, + hidden_dim=hidden_dim, + dropout=dropout, + 
zero_init=True, + ) + for _ in range(n_stacks) + ] + ) + self.protatm_type_encoding = GELUMLP(self.protatm_padding_dim + input_dim, input_pair_dim) + self.time_encoding = GaussianFourierEncoding1D(16) + self.rel_geom_enc = RelativeGeometryEncoding(15, fiber_dim) + self.rel_geom_embed = GELUMLP(fiber_dim, fiber_dim, n_hidden_feats=fiber_dim) + # [displacement, scale] + self.out_drift_res = torch.nn.ModuleList( + [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)] + ) + # for i in range(n_stacks): + # self.out_drift_res[i].weight.data.fill_(0.0) + self.out_scale_res = torch.nn.ModuleList( + [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)] + ) + self.out_drift_atm = torch.nn.ModuleList( + [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)] + ) + # for i in range(n_stacks): + # self.out_drift_atm[i].weight.data.fill_(0.0) + self.out_scale_atm = torch.nn.ModuleList( + [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)] + ) + + # Pre-tabulated edges + self.graph_relations = [ + ( + "residue_to_residue", + "gather_idx_ab_a", + "gather_idx_ab_b", + "prot_res", + "prot_res", + ), + ( + "sampled_residue_to_sampled_residue", + "gather_idx_AB_a", + "gather_idx_AB_b", + "prot_res", + "prot_res", + ), + ( + "prot_atm_to_prot_atm_graph", + "protatm_protatm_idx_src", + "protatm_protatm_idx_dst", + "prot_atm", + "prot_atm", + ), + ( + "prot_atm_to_prot_atm_knn", + "knn_idx_protatm_protatm_src", + "knn_idx_protatm_protatm_dst", + "prot_atm", + "prot_atm", + ), + ( + "prot_atm_to_residue", + "protatm_res_idx_protatm", + "protatm_res_idx_res", + "prot_atm", + "prot_res", + ), + ( + "residue_to_prot_atm", + "protatm_res_idx_res", + "protatm_res_idx_protatm", + "prot_res", + "prot_atm", + ), + ( + "sampled_lig_triplet_to_lig_atm", + "gather_idx_UI_I", + "gather_idx_UI_u", + "lig_trp", + "lig_atm", + ), + ( + "lig_atm_to_sampled_lig_triplet", + "gather_idx_UI_u", + "gather_idx_UI_I", + "lig_atm", + "lig_trp", + ), + ( + 
"lig_atm_to_lig_atm_graph", + "gather_idx_uv_u", + "gather_idx_uv_v", + "lig_atm", + "lig_atm", + ), + ( + "sampled_residue_to_sampled_lig_triplet", + "gather_idx_AJ_a", + "gather_idx_AJ_J", + "prot_res", + "lig_trp", + ), + ( + "sampled_lig_triplet_to_sampled_residue", + "gather_idx_AJ_J", + "gather_idx_AJ_a", + "lig_trp", + "prot_res", + ), + ( + "residue_to_sampled_lig_triplet", + "gather_idx_aJ_a", + "gather_idx_aJ_J", + "prot_res", + "lig_trp", + ), + ( + "sampled_lig_triplet_to_residue", + "gather_idx_aJ_J", + "gather_idx_aJ_a", + "lig_trp", + "prot_res", + ), + ( + "sampled_lig_triplet_to_sampled_lig_triplet", + "gather_idx_IJ_I", + "gather_idx_IJ_J", + "lig_trp", + "lig_trp", + ), + ( + "prot_atm_to_lig_atm_knn", + "knn_idx_protatm_ligatm_src", + "knn_idx_protatm_ligatm_dst", + "prot_atm", + "lig_atm", + ), + ( + "lig_atm_to_prot_atm_knn", + "knn_idx_ligatm_protatm_src", + "knn_idx_ligatm_protatm_dst", + "lig_atm", + "prot_atm", + ), + ( + "lig_atm_to_lig_atm_knn", + "knn_idx_ligatm_ligatm_src", + "knn_idx_ligatm_ligatm_dst", + "lig_atm", + "lig_atm", + ), + ] + self.graph_relations_no_ligand = self.graph_relations[:6] + + def init_scalar_vec_rep(self, x, x_v=None, frame=None): + """Initialize scalar and vector representations.""" + if frame is None: + # Zero-initialize the vector channels + vec_shape = (*x.shape[:-1], 3, x.shape[-1]) + res = torch.cat([x.unsqueeze(-2), torch.zeros(vec_shape, device=x.device)], dim=-2) + else: + x_v = x_v.view(*x.shape[:-1], 3, x.shape[-1]) + x_v_glob = torch.matmul(frame.R, x_v) + res = torch.cat([x.unsqueeze(-2), x_v_glob], dim=-2) + return res + + def forward( + self, + batch, + frozen_lig=False, + frozen_prot=False, + **kwargs, + ): + """Forward pass of the EquivariantStructureDenoisingModule module.""" + features = batch["features"] + indexer = batch["indexer"] + metadata = batch["metadata"] + metadata["num_structid"] + max(metadata["num_a_per_sample"]) + + prot_res_rep_in = features["rec_res_attr_decin"] + 
timestep_prot = features["timestep_encoding_prot"] + device = features["res_type"].device + + # Protein all-atom representation initialization + protatm_padding_mask = features["res_atom_mask"] + protatm_atom37_onehot = torch.nn.functional.one_hot( + features["protatm_to_atom37_index"], num_classes=self.protatm_padding_dim + ) + protatm_res_pair_encoding = self.protatm_type_encoding( + torch.cat( + [ + prot_res_rep_in[indexer["protatm_res_idx_res"]], + protatm_atom37_onehot, + ], + dim=-1, + ) + ) + # Gathered AA features from individual graphs + prot_atm_rep_in = features["prot_atom_attr_projected"] + prot_atm_rep_int = self.atm_embed( + torch.cat( + [ + prot_atm_rep_in, + self.time_encoding(timestep_prot)[indexer["protatm_res_idx_res"]], + ], + dim=-1, + ) + ) + + prot_atm_coords_padded = features["input_protein_coords"] + prot_atm_coords_flat = prot_atm_coords_padded[protatm_padding_mask] + + # Embed the rigid body node representations + backbone_frames = get_frame_matrix( + prot_atm_coords_padded[:, 0], + prot_atm_coords_padded[:, 1], + prot_atm_coords_padded[:, 2], + ) + prot_res_rep = self.init_scalar_vec_rep( + self.input_node_projector(prot_res_rep_in), + x_v=self.input_node_vec_projector(prot_res_rep_in), + frame=backbone_frames, + ) + prot_atm_rep = self.init_scalar_vec_rep(prot_atm_rep_int) + # gather AA features from individual graphs + node_reps = { + "prot_res": prot_res_rep, + "prot_atm": prot_atm_rep, + } + # Embed pair representations + edge_reps = { + "residue_to_residue": features["res_res_pair_attr_decin"], + "prot_atm_to_prot_atm_graph": features["prot_atom_pair_attr_projected"], + "prot_atm_to_prot_atm_knn": features["knn_feat_protatm_protatm"], + "prot_atm_to_residue": protatm_res_pair_encoding, + "residue_to_prot_atm": protatm_res_pair_encoding, + "sampled_residue_to_sampled_residue": features["res_res_grid_attr_flat_decin"], + } + + if not batch["misc"]["protein_only"]: + max(metadata["num_i_per_sample"]) + timestep_lig = 
features["timestep_encoding_lig"] + lig_atm_rep_in = features["lig_atom_attr_projected"] + lig_frame_rep_in = features["lig_trp_attr_decin"] + # Ligand atom embedding. Two timescales + lig_atm_rep_int = self.atm_embed( + torch.cat( + [lig_atm_rep_in, self.time_encoding(timestep_lig)], + dim=-1, + ) + ) + lig_atm_rep = self.init_scalar_vec_rep(lig_atm_rep_int) + + # Prepare ligand atom - sidechain atom indexers + # Initialize coordinate features + lig_atm_coords = features["input_ligand_coords"].clone() + + lig_frame_atm_idx = ( + indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]], + indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]], + indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]], + ) + ligand_trp_frames = get_frame_matrix( + lig_atm_coords[lig_frame_atm_idx[0]], + lig_atm_coords[lig_frame_atm_idx[1]], + lig_atm_coords[lig_frame_atm_idx[2]], + ) + lig_frame_rep = self.init_scalar_vec_rep( + self.input_node_projector(lig_frame_rep_in), + x_v=self.input_node_vec_projector(lig_frame_rep_in), + frame=ligand_trp_frames, + ) + node_reps.update( + { + "lig_atm": lig_atm_rep, + "lig_trp": lig_frame_rep, + } + ) + edge_reps.update( + { + "lig_atm_to_lig_atm_graph": features["lig_atom_pair_attr_projected"], + "sampled_lig_triplet_to_lig_atm": features["lig_af_pair_attr_projected"], + "lig_atm_to_sampled_lig_triplet": features["lig_af_pair_attr_projected"], + "sampled_residue_to_sampled_lig_triplet": features[ + "res_trp_grid_attr_flat_decin" + ], + "sampled_lig_triplet_to_sampled_residue": features[ + "res_trp_grid_attr_flat_decin" + ], + "residue_to_sampled_lig_triplet": features["res_trp_pair_attr_flat_decin"], + "sampled_lig_triplet_to_residue": features["res_trp_pair_attr_flat_decin"], + "sampled_lig_triplet_to_sampled_lig_triplet": features[ + "trp_trp_grid_attr_flat_decin" + ], + "prot_atm_to_lig_atm_knn": features["knn_feat_protatm_ligatm"], + "lig_atm_to_prot_atm_knn": features["knn_feat_ligatm_protatm"], + "lig_atm_to_lig_atm_knn": 
features["knn_feat_ligatm_ligatm"], + } + ) + + # Message passing + protatm_res_idx_res = indexer["protatm_res_idx_res"] + if batch["misc"]["protein_only"]: + graph_relations = self.graph_relations_no_ligand + else: + graph_relations = self.graph_relations + graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata) + merged_edge_idx = graph_batcher.collate_idx_list(indexer) + merged_node_reps = graph_batcher.collate_node_attr(node_reps) + merged_edge_reps = graph_batcher.collate_edge_attr( + graph_batcher.zero_pad_edge_attr(edge_reps, self.input_pair_dim, device) + ) + merged_edge_reps = self.input_pair_projector(merged_edge_reps) + assert merged_edge_idx[0].shape[0] == merged_edge_reps.shape[0] + assert merged_edge_idx[1].shape[0] == merged_edge_reps.shape[0] + + dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None) + if not batch["misc"]["protein_only"]: + dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None) + merged_node_t = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.t, + "prot_atm": dummy_prot_atm_frames.t, + "lig_atm": dummy_lig_atm_frames.t, + "lig_trp": ligand_trp_frames.t, + } + ) + merged_node_R = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.R, + "prot_atm": dummy_prot_atm_frames.R, + "lig_atm": dummy_lig_atm_frames.R, + "lig_trp": ligand_trp_frames.R, + } + ) + else: + merged_node_t = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.t, + "prot_atm": dummy_prot_atm_frames.t, + } + ) + merged_node_R = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.R, + "prot_atm": dummy_prot_atm_frames.R, + } + ) + merged_node_frames = RigidTransform(merged_node_t, merged_node_R) + merged_edge_reps = merged_edge_reps + ( + self.rel_geom_embed( + self.rel_geom_enc(merged_node_frames, merged_edge_idx) + merged_edge_reps + ) + ) + + # No need to reassign embeddings but need to update point coordinates & frames + for block_id in 
range(self.n_blocks): + dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None) + if not batch["misc"]["protein_only"]: + dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None) + merged_node_t = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.t, + "prot_atm": dummy_prot_atm_frames.t, + "lig_atm": dummy_lig_atm_frames.t, + "lig_trp": ligand_trp_frames.t, + } + ) + merged_node_R = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.R, + "prot_atm": dummy_prot_atm_frames.R, + "lig_atm": dummy_lig_atm_frames.R, + "lig_trp": ligand_trp_frames.R, + } + ) + else: + merged_node_t = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.t, + "prot_atm": dummy_prot_atm_frames.t, + } + ) + merged_node_R = graph_batcher.collate_node_attr( + { + "prot_res": backbone_frames.R, + "prot_atm": dummy_prot_atm_frames.R, + } + ) + # PredictDrift iteration + merged_node_reps, merged_edge_reps = self.ipa_modules[block_id]( + merged_node_reps, + merged_edge_idx, + t=merged_node_t, + R=merged_node_R, + x_edge=merged_edge_reps, + ) + offloaded_node_reps = graph_batcher.offload_node_attr(merged_node_reps) + if "lig_trp" in offloaded_node_reps.keys(): + lig_frame_rep = offloaded_node_reps["lig_trp"] + offloaded_node_reps["lig_trp"] = lig_frame_rep + self.res_adapters[block_id]( + lig_frame_rep, ligand_trp_frames.R, extra_feats=lig_frame_rep_in + ) + prot_res_rep = offloaded_node_reps["prot_res"] + offloaded_node_reps["prot_res"] = prot_res_rep + self.res_adapters[block_id]( + prot_res_rep, backbone_frames.R, extra_feats=prot_res_rep_in + ) + merged_node_reps = graph_batcher.collate_node_attr(offloaded_node_reps) + # Displacement vectors in the global coordinate system + if not batch["misc"]["protein_only"]: + drift_trp = ( + self.out_drift_res[block_id](offloaded_node_reps["lig_trp"][:, 1:]).squeeze(-1) + * torch.sigmoid( + self.out_scale_res[block_id](offloaded_node_reps["lig_trp"][:, 0]) + ) + * 10 + ) + drift_trp_gathered = 
segment_mean( + drift_trp, + indexer["gather_idx_I_molid"], + metadata["num_molid"], + )[indexer["gather_idx_i_molid"]] + drift_atm = self.out_drift_atm[block_id]( + offloaded_node_reps["lig_atm"][:, 1:] + ).squeeze(-1) * torch.sigmoid( + self.out_scale_atm[block_id](offloaded_node_reps["lig_atm"][:, 0]) + ) + if not frozen_lig: + lig_atm_coords = lig_atm_coords + drift_atm + drift_trp_gathered + ligand_trp_frames = get_frame_matrix( + lig_atm_coords[lig_frame_atm_idx[0]], + lig_atm_coords[lig_frame_atm_idx[1]], + lig_atm_coords[lig_frame_atm_idx[2]], + ) + + drift_bb = ( + self.out_drift_res[block_id](offloaded_node_reps["prot_res"][:, 1:]).squeeze(-1) + * torch.sigmoid( + self.out_scale_res[block_id](offloaded_node_reps["prot_res"][:, 0]) + ) + * 10 + ) + drift_bb_gathered = drift_bb[protatm_res_idx_res] + drift_prot_atm_int = self.out_drift_atm[block_id]( + offloaded_node_reps["prot_atm"][:, 1:] + ).squeeze(-1) * torch.sigmoid( + self.out_scale_atm[block_id](offloaded_node_reps["prot_atm"][:, 0]) + ) + if not frozen_prot: + prot_atm_coords_flat = ( + prot_atm_coords_flat + drift_prot_atm_int + drift_bb_gathered + ) + + prot_atm_coords_padded = torch.zeros_like(features["input_protein_coords"]) + prot_atm_coords_padded[protatm_padding_mask] = prot_atm_coords_flat + backbone_frames = get_frame_matrix( + prot_atm_coords_padded[:, 0], + prot_atm_coords_padded[:, 1], + prot_atm_coords_padded[:, 2], + ) + + ret = { + "final_embedding_prot_atom": offloaded_node_reps["prot_atm"], + "final_embedding_prot_res": offloaded_node_reps["prot_res"], + "final_coords_prot_atom": prot_atm_coords_flat, + "final_coords_prot_atom_padded": prot_atm_coords_padded, + } + if not batch["misc"]["protein_only"]: + ret["final_embedding_lig_atom"] = offloaded_node_reps["lig_atm"] + ret["final_coords_lig_atom"] = lig_atm_coords + else: + ret["final_embedding_lig_atom"] = None + ret["final_coords_lig_atom"] = None + return ret + + +def resolve_score_head( + protein_model_cfg: DictConfig, + 
    score_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
    """Instantiates an EquivariantStructureDenoisingModule model for protein-ligand complex
    structure denoising.

    :param protein_model_cfg: Protein model configuration.
    :param score_cfg: Score configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model.
    """
    model = EquivariantStructureDenoisingModule(
        score_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=score_cfg.hidden_dim,
        n_stacks=score_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    if score_cfg.from_pretrained and state_dict is not None:
        try:
            # Checkpoint keys are namespaced as `score_head.<param>`; strip the
            # leading component so they match the bare module's parameter names.
            model.load_state_dict(
                {
                    ".".join(k.split(".")[1:]): v
                    for k, v in state_dict.items()
                    if k.startswith("score_head")
                }
            )
            log.info("Successfully loaded pretrained score weights.")
        except Exception as e:
            # Best-effort loading: a mismatched checkpoint downgrades to a
            # warning instead of aborting model construction.
            log.warning(f"Skipping loading of pretrained score weights due to: {e}.")
    return model


def resolve_confidence_head(
    protein_model_cfg: DictConfig,
    confidence_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for confidence prediction.

    :param protein_model_cfg: Protein model configuration.
    :param confidence_cfg: Confidence configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model and plDDT gram head weights.
    """
    confidence_head = EquivariantStructureDenoisingModule(
        confidence_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=confidence_cfg.hidden_dim,
        n_stacks=confidence_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    # MLP mapping pair features to an 8-way output (presumably pLDDT bins —
    # TODO confirm against the loss that consumes this head's logits).
    plddt_gram_head = GELUMLP(
        protein_model_cfg.pair_dim,
        8,
        n_hidden_feats=protein_model_cfg.pair_dim,
        zero_init=True,
    )
    if confidence_cfg.from_pretrained and state_dict is not None:
        # Each sub-head loads independently so a failure in one does not block
        # the other; keys are prefix-stripped as in `resolve_score_head`.
        try:
            confidence_head.load_state_dict(
                {
                    ".".join(k.split(".")[1:]): v
                    for k, v in state_dict.items()
                    if k.startswith("confidence_head")
                }
            )
            log.info("Successfully loaded pretrained confidence weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained confidence weights due to: {e}.")

        try:
            plddt_gram_head.load_state_dict(
                {
                    ".".join(k.split(".")[1:]): v
                    for k, v in state_dict.items()
                    if k.startswith("plddt_gram_head")
                }
            )
            log.info("Successfully loaded pretrained pLDDT gram head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained pLDDT gram head weights due to: {e}.")
    return confidence_head, plddt_gram_head


def resolve_affinity_head(
    ligand_model_cfg: DictConfig,
    affinity_cfg: DictConfig,
    task_cfg: DictConfig,
    learnable_pooling: bool = True,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for affinity prediction.

    :param ligand_model_cfg: Ligand model configuration.
    :param affinity_cfg: Affinity configuration.
    :param task_cfg: Task configuration.
    :param learnable_pooling: Whether to use learnable ligand pooling modules.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model as well as a ligand pooling module and
        projection head.
    """
    affinity_head = EquivariantStructureDenoisingModule(
        affinity_cfg.fiber_dim,
        input_dim=ligand_model_cfg.node_channels,
        input_pair_dim=ligand_model_cfg.pair_channels,
        hidden_dim=affinity_cfg.hidden_dim,
        n_stacks=affinity_cfg.n_stacks,
        n_heads=ligand_model_cfg.n_heads,
        # NOTE(review): truthiness check means an explicit `dropout: 0.0` in
        # `affinity_cfg` is ignored in favor of `task_cfg.dropout` — confirm
        # this fallback is intended for zero values.
        dropout=affinity_cfg.dropout if affinity_cfg.get("dropout") else task_cfg.dropout,
    )
    # Select the ligand-atom pooling operator from config aliases.
    if affinity_cfg.ligand_pooling in ["sum", "add", "summation", "addition"]:
        ligand_pooling = SumPooling(learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim)
    elif affinity_cfg.ligand_pooling in ["mean", "avg", "average"]:
        ligand_pooling = AveragePooling(
            learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim
        )
    else:
        raise NotImplementedError(
            f"Unsupported ligand pooling method: {affinity_cfg.ligand_pooling}"
        )
    # Scalar projection from pooled fiber features to a single affinity value.
    affinity_proj_head = GELUMLP(
        affinity_cfg.fiber_dim,
        1,
        n_hidden_feats=affinity_cfg.fiber_dim,
        zero_init=True,
    )
    if affinity_cfg.from_pretrained and state_dict is not None:
        try:
            affinity_head.load_state_dict(
                {
                    ".".join(k.split(".")[1:]): v
                    for k, v in state_dict.items()
                    if k.startswith("affinity_head")
                }
            )
            log.info("Successfully loaded pretrained affinity head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained affinity head weights due to: {e}.")

        # Pooling weights only exist in the checkpoint when pooling is learnable.
        if learnable_pooling:
            try:
                ligand_pooling.load_state_dict(
                    {
                        ".".join(k.split(".")[1:]): v
                        for k, v in state_dict.items()
                        if k.startswith("ligand_pooling")
                    }
                )
                log.info("Successfully loaded pretrained ligand pooling weights.")
            except Exception as e:
                log.warning(f"Skipping loading of pretrained ligand pooling weights due to: {e}.")

        try:
            affinity_proj_head.load_state_dict(
                {
                    ".".join(k.split(".")[1:]): v
                    for k, v in state_dict.items()
                    if k.startswith("affinity_proj_head")
                }
            )
            log.info("Successfully loaded pretrained affinity projection head weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained affinity projection head weights due to: {e}."
            )
    return affinity_head, ligand_pooling, affinity_proj_head
diff --git a/flowdock/models/components/flowdock.py b/flowdock/models/components/flowdock.py
new file mode 100644
index 0000000..0f7605f
--- /dev/null
+++ b/flowdock/models/components/flowdock.py
@@ -0,0 +1,2029 @@
import os
import random
from functools import partial

import rootutils
import torch
import torch.nn.functional as F
import tqdm
from beartype.typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
from omegaconf import DictConfig
from pytorch3d.ops import corresponding_points_alignment

from flowdock.utils.inspect_ode_samplers import clamp_tensor

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from flowdock.data.components.mol_features import collate_numpy_samples
from flowdock.data.components.physical import (
    get_vdw_radii_array,
    get_vdw_radii_array_uff,
)
from flowdock.models.components.cpm import (
    resolve_pl_contact_stack,
    resolve_protein_encoder,
)
from flowdock.models.components.esdm import (
    resolve_affinity_head,
    resolve_confidence_head,
    resolve_score_head,
)
from flowdock.models.components.losses import (
    compute_fape_from_atom37,
    compute_lddt_ca,
    compute_lddt_pli,
    compute_TMscore_lbound,
    compute_TMscore_raw,
)
from flowdock.models.components.mht_encoder import (
    resolve_ligand_encoder,
    resolve_relational_reasoning_module,
)
from flowdock.models.components.noise import (
    sample_complex_harmonic_prior,
    sample_esmfold_prior,
    sample_gaussian_prior,
    sample_ligand_harmonic_prior,
)
from flowdock.models.components.transforms import (
    DefaultPLCoordinateConverter,
    LatentCoordinateConverter,
)
from flowdock.utils import RankedLogger
from flowdock.utils.data_utils import (
    erase_holo_coordinates,
    get_standard_aa_features,
    prepare_batch,
)
from flowdock.utils.frame_utils import apply_similarity_transform
from flowdock.utils.model_utils import (
    distogram_to_gaussian_contact_logits,
    eval_true_contact_maps,
    inplace_to_device,
    inplace_to_torch,
    sample_res_rowmask_from_contacts,
    sample_reslig_contact_matrix,
    segment_argmin,
    segment_mean,
    segment_sum,
    topk_edge_mask_from_logits,
)

# Type alias for the nested batch dictionaries passed between model stages.
MODEL_BATCH = Dict[str, Any]
NANOMETERS_TO_ANGSTROM = 10.0

log = RankedLogger(__name__, rank_zero_only=True)


class FlowDock(torch.nn.Module):
    """A geometric conditional flow matching model for protein-ligand docking."""

    def __init__(
        self,
        cfg: DictConfig,
    ) -> None:
        """Initialize a `FlowDock` module.

        :param cfg: A model configuration dictionary.
        """
        super().__init__()
        # Cache per-component sub-configurations for convenient access.
        self.cfg = cfg
        self.ligand_cfg = cfg.mol_encoder
        self.protein_cfg = cfg.protein_encoder
        self.relational_reasoning_cfg = cfg.relational_reasoning
        self.contact_cfg = cfg.contact_predictor
        self.score_cfg = cfg.score_head
        self.confidence_cfg = cfg.confidence
        self.affinity_cfg = cfg.affinity
        self.global_cfg = cfg.task
        self.protatm_padding_dim = self.protein_cfg.atom_padding_dim  # := 37
        self.max_n_edges = self.global_cfg.edge_crop_size
        self.latent_model = cfg.latent_model
        self.prior_type = cfg.prior_type

        # VDW radius mapping, in Angstrom
        # (registered as non-trainable Parameters so they follow `.to(device)`)
        self.atnum2vdw = torch.nn.Parameter(
            torch.tensor(get_vdw_radii_array() / 100.0),
            requires_grad=False,
        )
        self.atnum2vdw_uff = torch.nn.Parameter(
            torch.tensor(get_vdw_radii_array_uff() / 100.0),
            requires_grad=False,
        )

        # graph hyperparameters
        self.BINDING_SITE_CUTOFF = 6.0
        self.INTERNAL_TARGET_VARIANCE_SCALE = self.global_cfg.internal_max_sigma
        self.GLOBAL_TARGET_VARIANCE_SCALE = self.global_cfg.global_max_sigma
        self.CONTACT_SCALE = 5.0  # fixed hyperparameter

        # Reference feature sets for the 20 standard amino acids, converted to
        # torch tensors once at construction time.
        (
            standard_aa_template_featset,
            standard_aa_graph_featset,
        ) = get_standard_aa_features()
        self.standard_aa_template_featset = inplace_to_torch(standard_aa_template_featset)
        self.standard_aa_molgraph_featset = inplace_to_torch(
            collate_numpy_samples(standard_aa_graph_featset)
        )

        # load pretrained weights as desired
        from_pretrained = (
            self.ligand_cfg.from_pretrained
            or self.protein_cfg.from_pretrained
            or self.relational_reasoning_cfg.from_pretrained
            or self.contact_cfg.from_pretrained
            or self.score_cfg.from_pretrained
            or self.confidence_cfg.from_pretrained
            or self.affinity_cfg.from_pretrained
        )
        if from_pretrained:
            assert cfg.mol_encoder.checkpoint_file is not None and os.path.exists(
                cfg.mol_encoder.checkpoint_file
            ), "Pretrained model weights not found."
        # NOTE(review): `torch.load` unpickles arbitrary objects; consider
        # `weights_only=True` if only tensors are needed from this checkpoint.
        pretrained_state_dict = (
            torch.load(cfg.mol_encoder.checkpoint_file)["state_dict"] if from_pretrained else None
        )

        # ligand encoder
        self.ligand_encoder = resolve_ligand_encoder(
            self.ligand_cfg, self.global_cfg, state_dict=pretrained_state_dict
        )
        self.lig_masking_rate = self.global_cfg.max_masking_rate

        # protein structure encoder
        self.protein_encoder, res_in_projector = resolve_protein_encoder(
            self.protein_cfg, self.global_cfg, state_dict=pretrained_state_dict
        )
        # protein sequence encoder
        if self.protein_cfg.use_esm_embedding:
            # protein sequence language model
            self.plm_adapter = res_in_projector
        else:
            # one-hot amino acid types
            self.res_in_projector = res_in_projector

        # relational reasoning module
        (
            self.molgraph_single_projector,
            self.molgraph_pair_projector,
            self.covalent_embed,
        ) = resolve_relational_reasoning_module(
            self.protein_cfg,
            self.ligand_cfg,
            self.relational_reasoning_cfg,
            state_dict=pretrained_state_dict,
        )

        # contact prediction module
        (
            self.pl_contact_stack,
            self.contact_code_embed,
            self.dist_bins,
            self.dgram_head,
        ) = resolve_pl_contact_stack(
            self.protein_cfg,
            self.ligand_cfg,
            self.contact_cfg,
            self.global_cfg,
            state_dict=pretrained_state_dict,
        )

        # structure denoising module
        self.score_head = resolve_score_head(
            self.protein_cfg,
            self.score_cfg, self.global_cfg, state_dict=pretrained_state_dict
        )

        # confidence prediction module
        if self.confidence_cfg.enabled:
            self.confidence_head, self.plddt_gram_head = resolve_confidence_head(
                self.protein_cfg,
                self.confidence_cfg,
                self.global_cfg,
                state_dict=pretrained_state_dict,
            )

        # affinity prediction module
        if self.affinity_cfg.enabled:
            (
                self.affinity_head,
                self.ligand_pooling,
                self.affinity_proj_head,
            ) = resolve_affinity_head(
                self.ligand_cfg,
                self.affinity_cfg,
                self.global_cfg,
                learnable_pooling=True,
                state_dict=pretrained_state_dict,
            )

        # Apply the config-driven freezing policy once all submodules exist.
        self.freeze_pretraining_params()

    def freeze_pretraining_params(self):
        """Freeze pretraining parameters.

        Each `freeze_*` flag in the task config puts the corresponding
        submodule(s) into eval mode and disables gradients on their parameters.
        """
        if self.global_cfg.freeze_mol_encoder:
            log.info("Freezing ligand encoder parameters.")
            self.ligand_encoder.eval()
            for p in self.ligand_encoder.parameters():
                p.requires_grad = False
        if self.global_cfg.freeze_protein_encoder:
            log.info("Freezing protein encoder parameters.")
            for module in [
                (
                    # Only one of the two input projectors exists, depending on
                    # whether ESM embeddings are used (see `__init__`).
                    self.plm_adapter
                    if self.protein_cfg.use_esm_embedding
                    else self.res_in_projector
                ),
                self.protein_encoder,
            ]:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False
        if self.global_cfg.freeze_relational_reasoning:
            log.info("Freezing relational reasoning module parameters.")
            for module in [
                self.molgraph_single_projector,
                self.molgraph_pair_projector,
                self.covalent_embed,
            ]:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False
        if self.global_cfg.freeze_contact_predictor:
            log.info("Freezing contact prediction module parameters.")
            for module in [
                self.pl_contact_stack,
                self.contact_code_embed,
                self.dgram_head,
            ]:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False
        if self.global_cfg.freeze_score_head:
            log.info("Freezing structure denoising module parameters.")
            self.score_head.eval()
            for p in self.score_head.parameters():
                p.requires_grad = False
        if self.confidence_cfg.enabled and self.global_cfg.freeze_confidence:
            log.info("Freezing confidence prediction module parameters.")
            for module in [self.confidence_head, self.plddt_gram_head]:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False
        if self.affinity_cfg.enabled and self.global_cfg.freeze_affinity:
            log.info("Freezing affinity prediction module parameters.")
            self.affinity_head.eval()
            for p in self.affinity_head.parameters():
                p.requires_grad = False

    @staticmethod
    def assign_timestep_encodings(batch: MODEL_BATCH, t_normalized: Union[float, torch.Tensor]):
        """Assign timestep encodings to the batch.

        Writes `timestep_encoding_prot` (and, for protein-ligand complexes,
        `timestep_encoding_lig`) into `batch["features"]` in place.

        :param batch: A batch dictionary.
        :param t_normalized: The normalized timestep.
        """
        # NOTE: `t_normalized` must be in the range `[0, 1]`
        features = batch["features"]
        indexer = batch["indexer"]
        device = features["res_type"].device
        if not isinstance(t_normalized, torch.Tensor):
            # Scalar input: broadcast to one timestep per structure in the batch.
            t_normalized = torch.full(
                (batch["metadata"]["num_structid"], 1),
                t_normalized,
                device=device,
            )
        if t_normalized.shape != (batch["metadata"]["num_structid"], 1):
            # Only single-element tensors can be shape-coerced; anything else
            # is ambiguous and rejected.
            assert (
                t_normalized.numel() == 1
            ), f"To properly shape-coerce time step tensor of shape {t_normalized.shape}, the input tensor must contain a single value."
            t_normalized = torch.full(
                (batch["metadata"]["num_structid"], 1),
                t_normalized.item(),
                device=device,
            )
        # Scatter the per-structure timestep to per-residue entries.
        t_prot = t_normalized[indexer["gather_idx_a_structid"]]
        batch["features"]["timestep_encoding_prot"] = t_prot

        if not batch["misc"]["protein_only"]:
            # Likewise scatter to per-ligand-atom entries.
            batch["features"]["timestep_encoding_lig"] = t_normalized[
                indexer["gather_idx_i_structid"]
            ]

    def resolve_latent_converter(self, *args):
        """Instantiate the latent coordinate converter selected by `cfg.latent_model`."""
        if self.latent_model == "default":
            return DefaultPLCoordinateConverter(self.global_cfg, *args)
        else:
            raise NotImplementedError

    def forward_interp(
        self,
        batch: MODEL_BATCH,
        x_int_0: torch.Tensor,
        t: torch.Tensor,
        latent_converter: LatentCoordinateConverter,
        umeyama_correction: bool = True,
        erase_data: bool = False,
    ) -> torch.Tensor:
        """Interpolate latent internal coordinates.

        Note that this function adds small amounts of Gaussian noise
        to the ground-truth protein and ligand coordinates, to discourage
        the model from overfitting to experimental noise in the training data.
        Reference: https://www.science.org/doi/10.1126/science.add2187

        :param batch: A batch dictionary.
        :param x_int_0: Dimension-less (ground-truth) internal coordinates.
        :param t: The current normalized timestep.
        :param latent_converter: The latent coordinate converter.
        :param umeyama_correction: Whether to apply the Umeyama correction.
        :param erase_data: Whether to erase data.
        :return: Interpolated latent internal coordinates.
        """
        # Unpack the latent layout: holo/apo Ca, holo/apo other-atom,
        # holo/apo centroid, and ligand heavy-atom segments (dim 1).
        (
            ca_lat,
            apo_ca_lat,
            cother_lat,
            apo_cother_lat,
            ca_lat_centroid_coords,
            apo_ca_lat_centroid_coords,
            lig_lat,
        ) = torch.split(
            x_int_0,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        # Data-side coordinates only (holo Ca + other atoms + ligand).
        x_int_0_ = torch.cat(
            [
                ca_lat,
                cother_lat,
                lig_lat,
            ],
            dim=1,
        )
        try:
            assert self.global_cfg.single_protein_batch, "Only single protein batch is supported."
            if self.prior_type == "gaussian":
                noisy_x_int_0, noisy_x_int_1 = sample_gaussian_prior(
                    x_int_0_, latent_converter, sigma=1.0, x0_sigma=1e-4
                )
            elif self.prior_type == "harmonic":
                noisy_x_int_0, noisy_x_int_1 = sample_complex_harmonic_prior(
                    x_int_0_, latent_converter, batch, x0_sigma=1e-4
                )
            elif self.prior_type == "esmfold":
                # NOTE: the following unnormalization step assumes that `self.latent_model == "default"`
                apo_lig_lat_ = sample_ligand_harmonic_prior(
                    lig_lat, apo_ca_lat * latent_converter.ca_scale, batch
                )
                x_int_1_ = torch.cat(
                    [
                        apo_ca_lat,  # NOTE: already normalized
                        apo_cother_lat,  # NOTE: already normalized
                        apo_lig_lat_ / latent_converter.other_scale,
                    ],
                    dim=1,
                )
                noisy_x_int_0, noisy_x_int_1 = sample_esmfold_prior(
                    x_int_0_, x_int_1_, sigma=1e-4, x0_sigma=1e-4
                )
            else:
                raise NotImplementedError(f"Unsupported prior type: {self.prior_type}")
        except Exception as e:
            log.error(
                f"Failed to converge within `{self.prior_type}` noise function of `forward_interp()` due to: {e}."
            )
            raise e

        # NOTE: by this point, both `noisy_x_int_0` and `noisy_x_int_1` are normalized
        if umeyama_correction and not erase_data:
            try:
                # align the complex structure based solely the optimal Ca atom alignment
                # NOTE: we do not perform such alignments during the initial (`t=1`) sampling timestep
                noisy_x_ca_int_1 = noisy_x_int_1.split(
                    [
                        latent_converter._n_res_per_sample,
                        latent_converter._n_cother_per_sample,
                        latent_converter._n_ligha_per_sample,
                    ],
                    dim=1,
                )[0]
                noisy_x_ca_int_0 = noisy_x_int_0.split(
                    [
                        latent_converter._n_res_per_sample,
                        latent_converter._n_cother_per_sample,
                        latent_converter._n_ligha_per_sample,
                    ],
                    dim=1,
                )[0]
                similarity_transform = corresponding_points_alignment(
                    X=noisy_x_ca_int_1, Y=noisy_x_ca_int_0, estimate_scale=False
                )
                noisy_x_int_1 = apply_similarity_transform(noisy_x_int_1, *similarity_transform)
            except Exception as e:
                # NOTE(review): the message says the alignment is being skipped,
                # but `raise e` below propagates the exception instead — one of
                # the two is stale; confirm which behavior is intended.
                log.warning(
                    f"Failed optimal noise alignment within `forward_interp()` due to: {e}. Skipping optimal noise alignment..."
                )
                raise e

        # interpolate between target and prior distributions
        noisy_x_int_t = (1 - t) * noisy_x_int_0 + t * noisy_x_int_1

        # recalculate (and renormalize) centroid coordinates
        (
            x_int_t_ca_lat,
            x_int_t_cother_lat,
            x_int_t_lig_lat,
        ) = torch.split(
            noisy_x_int_t,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        (
            x_int_1_ca_lat,
            x_int_1_cother_lat,
            _,
        ) = torch.split(
            noisy_x_int_1,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        # Reassemble into the converter's full latent layout (interpolated and
        # prior-side segments interleaved, centroid segments passed through).
        x_int = torch.cat(
            [
                x_int_t_ca_lat,
                x_int_1_ca_lat,
                x_int_t_cother_lat,
                x_int_1_cother_lat,
                ca_lat_centroid_coords,
                apo_ca_lat_centroid_coords,
                x_int_t_lig_lat,
            ],
            dim=1,
        )
        return x_int

    def forward_interp_plcomplex_latinp(
        self,
        batch: MODEL_BATCH,
        t: torch.Tensor,
        latent_converter: LatentCoordinateConverter,
        umeyama_correction: bool = True,
        erase_data: bool = False,
    ) -> MODEL_BATCH:
        """Noise-interpolate protein-ligand complex latent internal coordinates.

        :param batch: A batch dictionary.
        :param t: The current normalized timestep.
        :param latent_converter: The latent coordinate converter.
        :param umeyama_correction: Whether to apply the Umeyama correction.
        :param erase_data: Whether to erase data.
        :return: Batch dictionary with interpolated latent internal coordinates.
        """
        # Dimension-less internal coordinates
        # [B, N, 3]
        x_int = latent_converter.to_latent(batch)
        if erase_data:
            x_int = erase_holo_coordinates(batch, x_int, latent_converter)
        x_int_t = self.forward_interp(
            batch,
            x_int,
            t,
            latent_converter,
            umeyama_correction=umeyama_correction,
            erase_data=erase_data,
        )
        # Write the interpolated latents back into the batch dictionary.
        return latent_converter.assign_to_batch(batch, x_int_t)

    def prepare_protein_patch_indexers(
        self, batch: MODEL_BATCH, randomize_anchors: bool = False
    ) -> MODEL_BATCH:
        """Prepare protein patch indexers for the batch.

        :param batch: A batch dictionary.
        :param randomize_anchors: Whether to randomize the anchors.
        :return: Batch dictionary with protein patch indexers.
        """
        features = batch["features"]
        metadata = batch["metadata"]
        indexer = batch["indexer"]
        batch_size = metadata["num_structid"]
        device = features["res_type"].device

        # Prepare indexers
        # Use max to ensure segmentation faults are 100% invoked
        # in case there are any bad indices
        max(metadata["num_a_per_sample"])
        n_a_per_sample = max(metadata["num_a_per_sample"])
        assert (
            n_a_per_sample * batch_size == metadata["num_a"]
        ), "Invalid (batched) number of residues"
        n_protein_patches = min(self.protein_cfg.n_patches, n_a_per_sample)
        batch["metadata"]["n_prot_patches_per_sample"] = n_protein_patches

        # Uniform segmentation
        # Map each residue to a patch id: structure offset plus an evenly
        # spaced within-structure patch index.
        res_idx_in_batch = torch.arange(metadata["num_a"], device=device)
        batch["indexer"]["gather_idx_a_pid"] = (
            res_idx_in_batch // n_a_per_sample
        ) * n_protein_patches + (
            ((res_idx_in_batch % n_a_per_sample) * n_protein_patches) // n_a_per_sample
        )

        if randomize_anchors:
            # Random down-sampling, assigning residues to the patch grid
            # This maps grid row/column idx to sampled residue idx
            batch["indexer"]["gather_idx_pid_a"] = segment_argmin(
                batch["features"]["res_type"].new_zeros(n_a_per_sample * batch_size),
                indexer["gather_idx_a_pid"],
                n_protein_patches * batch_size,
                randomize=True,
            )
        else:
            # Deterministic anchors: the (rounded) mean residue index per patch.
            batch["indexer"]["gather_idx_pid_a"] = segment_mean(
                res_idx_in_batch,
                indexer["gather_idx_a_pid"],
                n_protein_patches * batch_size,
            ).long()

        return batch

    def prepare_protein_backbone_indexers(
        self, batch: MODEL_BATCH, **kwargs: Dict[str, Any]
    ) -> MODEL_BATCH:
        """Prepare protein backbone indexers for the batch.

        :param batch: A batch dictionary.
        :param kwargs: Additional keyword arguments.
        :return: Batch dictionary with protein backbone indexers.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device

        protatm_coords_padded = features["input_protein_coords"]
        batch_size = metadata["num_structid"]

        assert self.global_cfg.single_protein_batch, "Only single protein batch is supported."
        num_res_per_struct = max(metadata["num_a_per_sample"])
        # Check that the samples are clones of the same complex
        assert (
            batch_size * num_res_per_struct == protatm_coords_padded.shape[0]
        ), "Invalid number of residues."

        input_prot_coords_folded = protatm_coords_padded.unflatten(
            0, (batch_size, num_res_per_struct)
        )
        # Since all samples clone one complex, chain/residue ids are taken from
        # the first structure only.
        single_struct_chain_id = indexer["gather_idx_a_chainid"][:num_res_per_struct]
        single_struct_res_id = features["residue_index"][:num_res_per_struct]
        # Pairwise Ca-Ca distances (atom index 1 is Ca in the padded layout
        # used here — consistent with the backbone-frame code elsewhere).
        ca_ca_dist = (
            input_prot_coords_folded[:, :, None, 1] - input_prot_coords_folded[:, None, :, 1]
        ).norm(dim=-1)
        ca_ca_knn_mask = topk_edge_mask_from_logits(
            -ca_ca_dist / self.CONTACT_SCALE,
            self.protein_cfg.max_residue_degree,
            randomize=True,
        )
        # Always keep sequence-local edges (|i-j| <= 4 within a chain) in
        # addition to the spatial kNN edges.
        chain_mask = single_struct_chain_id[None, :, None] == single_struct_chain_id[None, None, :]
        sequence_dist = single_struct_res_id[None, :, None] - single_struct_res_id[None, None, :]
        sequence_proximity_mask = (torch.abs(sequence_dist) <= 4) & chain_mask
        prot_res_res_edge_mask = ca_ca_knn_mask | sequence_proximity_mask

        # Convert the dense edge mask into flat source/target index lists.
        dense_row_idx_3D = (
            torch.arange(batch_size * num_res_per_struct, device=device)
            .view(batch_size, num_res_per_struct)[:, :, None]
            .expand(-1, -1, num_res_per_struct)
        ).contiguous()
        dense_col_idx_3D = dense_row_idx_3D.transpose(1, 2).contiguous()
        batch["metadata"]["num_prot_res"] = metadata["num_a"]
        batch["indexer"]["gather_idx_ab_a"] = dense_row_idx_3D[prot_res_res_edge_mask]
        batch["indexer"]["gather_idx_ab_b"] = dense_col_idx_3D[prot_res_res_edge_mask]
        batch["indexer"]["gather_idx_ab_structid"] = indexer["gather_idx_a_structid"][
            indexer["gather_idx_ab_a"]
        ]
        batch["metadata"]["num_ab"] = batch["indexer"]["gather_idx_ab_a"].shape[0]

        if self.global_cfg.constrained_inpainting:
            # Diversified spherical cropping scheme
            assert self.global_cfg.single_protein_batch, "Only single protein batch is supported."
            batch_size = metadata["num_structid"]
            # Assert single ligand samples
            assert batch_size == metadata["num_molid"], "Invalid number of ligands."
            ligand_coords = batch["features"]["sdf_coordinates"].reshape(batch_size, -1, 3)
            ligand_centroids = torch.mean(ligand_coords, dim=1)
            # NOTE(review): `kwargs["training"]` raises KeyError when callers
            # omit the flag — confirm all call sites pass `training=...`.
            if kwargs["training"]:
                # 3A perturbations around the ligand centroid
                perturbed_centroids = ligand_centroids + torch.rand_like(ligand_centroids) * 1.73
                site_radius = torch.amax(
                    torch.norm(ligand_coords - perturbed_centroids[:, None, :], dim=-1),
                    dim=1,
                )
                perturbed_site_radius = (
                    site_radius + (0.5 + torch.rand_like(site_radius)) * self.BINDING_SITE_CUTOFF
                )
            else:
                perturbed_centroids = ligand_centroids
                site_radius = torch.amax(
                    torch.norm(ligand_coords - perturbed_centroids[:, None, :], dim=-1),
                    dim=1,
                )
                perturbed_site_radius = site_radius + self.BINDING_SITE_CUTOFF
            # Residues whose Ca lies inside the (perturbed) binding sphere.
            centroid_ca_dist = (
                batch["features"]["res_atom_positions"][:, 1].contiguous().view(batch_size, -1, 3)
                - perturbed_centroids[:, None, :]
            ).norm(dim=-1)
            binding_site_mask = (centroid_ca_dist < perturbed_site_radius[:, None]).flatten(0, 1)
            batch["features"]["binding_site_mask"] = binding_site_mask
            # Exclude binding-site residues from template alignment.
            batch["features"]["template_alignment_mask"] = (~binding_site_mask) & batch[
                "features"
            ]["template_alignment_mask"].bool()

        return batch

    def initialize_protein_embeddings(self, batch: MODEL_BATCH):
        """Initialize protein embeddings.

        :param batch: A batch dictionary.
        """
        features = batch["features"]

        # Protein residue and residue-pair embeddings
        if "res_embedding_in" not in features:
            if self.protein_cfg.use_esm_embedding:
                assert (
                    self.global_cfg.single_protein_batch
                ), "Only single protein batch is supported."
                # Prefer apo-state language-model embeddings when available.
                features["res_embedding_in"] = self.plm_adapter(
                    batch["features"]["apo_lm_embeddings"]
                    if "apo_lm_embeddings" in batch["features"]
                    else batch["features"]["lm_embeddings"]
                )
                assert (
                    features["res_embedding_in"].shape[0] == features["res_atom_types"].shape[0]
                ), "Invalid number of residues."
+ else: + features["res_embedding_in"] = self.res_in_projector( + F.one_hot( + features["res_type"].long(), + num_classes=self.protein_cfg.n_aa_types, + ).float() + ) + + def initialize_protatm_indexer_and_embeddings(self, batch: MODEL_BATCH): + """Assign coordinate-independent edges and protein features from PIFormer. + + :param batch: A batch dictionary. + """ + features = batch["features"] + device = features["res_type"].device + assert self.global_cfg.single_protein_batch, "Only single protein batch is supported." + self.standard_aa_molgraph_featset = inplace_to_device( + self.standard_aa_molgraph_featset, device + ) + self.standard_aa_template_featset = inplace_to_device( + self.standard_aa_template_featset, device + ) + self.standard_aa_molgraph_featset = self.ligand_encoder(self.standard_aa_molgraph_featset) + with torch.no_grad(): + assert ( + self.standard_aa_molgraph_featset["metadata"]["num_i"] == 167 + ), "Invalid number of atoms." + template_atom_idx_in_batch_padded = torch.full( + (20, 37), fill_value=-1, dtype=torch.long, device=device + ) + template_atom37_mask = self.standard_aa_template_featset["features"][ + "res_atom_mask" + ].bool() + template_atom_idx_in_batch_padded[template_atom37_mask] = torch.arange( + 167, device=device + ) + atom_idx_in_batch_to_restype_idx = ( + torch.arange(20, device=device)[:, None] + .expand(-1, 37) + .contiguous()[template_atom37_mask] + ) + atom_idx_in_batch_to_atom37_idx = ( + torch.arange(37, device=device)[None, :] + .expand(20, -1) + .contiguous()[template_atom37_mask] + ) + template_padded_edge_mask_per_aa = torch.zeros( + (20, 37, 37), dtype=torch.bool, device=device + ) + template_aa_graph_indexer = self.standard_aa_molgraph_featset["indexer"] + template_padded_edge_mask_per_aa[ + atom_idx_in_batch_to_restype_idx[template_aa_graph_indexer["gather_idx_uv_u"]], + atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_u"]], + 
    def initialize_protatm_indexer_and_embeddings(self, batch: MODEL_BATCH):
        """Assign coordinate-independent edges and protein features from PIFormer.

        Builds, once per batch, the protein-atom-level graph by transferring the
        per-amino-acid template molecular graphs (20 residue types x 37 atom37 slots,
        167 heavy atoms total) onto the actual residues of the input protein:

        1. Encode the standard-AA molecular graphs with the ligand encoder.
        2. Construct per-residue-type intra-residue adjacency masks in atom37 space.
        3. Scatter those masks onto the first structure's atoms, then tile the
           resulting edge indices across the batch (single-protein-batch assumption:
           every sample shares the same protein).
        4. Gather template node/edge features into
           ``prot_atom_attr_projected`` / ``prot_atom_pair_attr_projected``.

        :param batch: A batch dictionary.
        :return: The updated batch dictionary.
        """
        features = batch["features"]
        device = features["res_type"].device
        assert self.global_cfg.single_protein_batch, "Only single protein batch is supported."
        self.standard_aa_molgraph_featset = inplace_to_device(
            self.standard_aa_molgraph_featset, device
        )
        self.standard_aa_template_featset = inplace_to_device(
            self.standard_aa_template_featset, device
        )
        # Encode the 20 standard amino-acid molecular graphs with the ligand encoder.
        self.standard_aa_molgraph_featset = self.ligand_encoder(self.standard_aa_molgraph_featset)
        with torch.no_grad():
            # 167 = total heavy-atom count over the 20 standard amino acids.
            assert (
                self.standard_aa_molgraph_featset["metadata"]["num_i"] == 167
            ), "Invalid number of atoms."
            template_atom_idx_in_batch_padded = torch.full(
                (20, 37), fill_value=-1, dtype=torch.long, device=device
            )
            template_atom37_mask = self.standard_aa_template_featset["features"][
                "res_atom_mask"
            ].bool()
            template_atom_idx_in_batch_padded[template_atom37_mask] = torch.arange(
                167, device=device
            )
            # Map each of the 167 template atoms back to (residue type, atom37 slot).
            atom_idx_in_batch_to_restype_idx = (
                torch.arange(20, device=device)[:, None]
                .expand(-1, 37)
                .contiguous()[template_atom37_mask]
            )
            atom_idx_in_batch_to_atom37_idx = (
                torch.arange(37, device=device)[None, :]
                .expand(20, -1)
                .contiguous()[template_atom37_mask]
            )
            # Intra-residue adjacency per residue type, in padded atom37 space.
            template_padded_edge_mask_per_aa = torch.zeros(
                (20, 37, 37), dtype=torch.bool, device=device
            )
            template_aa_graph_indexer = self.standard_aa_molgraph_featset["indexer"]
            template_padded_edge_mask_per_aa[
                atom_idx_in_batch_to_restype_idx[template_aa_graph_indexer["gather_idx_uv_u"]],
                atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_u"]],
                atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_v"]],
            ] = True

        # Gather adjacency matrix to the input protein
        features = batch["features"]
        metadata = batch["metadata"]
        # Prepare intra-residue protein atom - protein atom indexers
        # ("first" = the first structure in the batch; identical for all samples
        # under the single-protein-batch assumption).
        n_res_first = max(metadata["num_a_per_sample"])
        batch["features"]["res_atom_mask"] = features["res_atom_mask"].bool()
        protatm_padding_mask = batch["features"]["res_atom_mask"][:n_res_first]
        n_protatm_first = int(protatm_padding_mask.sum())
        protatm_res_idx_res_first = (
            torch.arange(n_res_first, device=device)[:, None]
            .expand(-1, 37)
            .contiguous()[protatm_padding_mask]
        )
        protatm_to_atom37_idx_first = (
            torch.arange(37, device=device)[None, :]
            .expand(n_res_first, -1)
            .contiguous()[protatm_padding_mask]
        )
        same_residue_mask = (
            protatm_res_idx_res_first[:, None] == protatm_res_idx_res_first[None, :]
        ).contiguous()
        aa_graph_edge_mask = torch.zeros(
            (n_protatm_first, n_protatm_first), dtype=torch.bool, device=device
        )
        src_idx_sameres = (
            torch.arange(n_protatm_first, device=device)[:, None]
            .expand(-1, n_protatm_first)
            .contiguous()[same_residue_mask]
        )
        dst_idx_sameres = (
            torch.arange(n_protatm_first, device=device)[None, :]
            .expand(n_protatm_first, -1)
            .contiguous()[same_residue_mask]
        )
        # Copy template adjacency onto same-residue atom pairs of the real protein.
        aa_graph_edge_mask[
            src_idx_sameres, dst_idx_sameres
        ] = template_padded_edge_mask_per_aa[
            features["res_type"].long()[protatm_res_idx_res_first[src_idx_sameres]],
            protatm_to_atom37_idx_first[src_idx_sameres],
            protatm_to_atom37_idx_first[dst_idx_sameres],
        ]
        src_idx_first = (
            torch.arange(n_protatm_first, device=device)[:, None]
            .expand(-1, n_protatm_first)
            .contiguous()[aa_graph_edge_mask]
        )
        dst_idx_first = (
            torch.arange(n_protatm_first, device=device)[None, :]
            .expand(n_protatm_first, -1)
            .contiguous()[aa_graph_edge_mask]
        )
        batch_size = metadata["num_structid"]
        # Tile edge indices across the batch with a per-sample atom offset.
        src_idx = (
            (
                src_idx_first[None, :].expand(batch_size, -1)
                + torch.arange(batch_size, device=device)[:, None] * n_protatm_first
            )
            .contiguous()
            .flatten()
        )
        dst_idx = (
            (
                dst_idx_first[None, :].expand(batch_size, -1)
                + torch.arange(batch_size, device=device)[:, None] * n_protatm_first
            )
            .contiguous()
            .flatten()
        )
        batch["metadata"]["num_protatm_per_sample"] = n_protatm_first
        batch["indexer"]["protatm_protatm_idx_src"] = src_idx
        batch["indexer"]["protatm_protatm_idx_dst"] = dst_idx
        batch["metadata"]["num_prot_atm"] = n_protatm_first * batch_size
        batch["indexer"]["protatm_res_idx_res"] = (
            (
                protatm_res_idx_res_first[None, :].expand(batch_size, -1)
                + torch.arange(batch_size, device=device)[:, None] * n_res_first
            )
            .contiguous()
            .flatten()
        )
        batch["indexer"]["protatm_res_idx_protatm"] = torch.arange(
            batch["metadata"]["num_prot_atm"], device=device
        )
        # Gather graph features to the protein feature set
        template_padded_node_feat_per_aa = torch.zeros(
            (20, 37, self.protein_cfg.residue_dim), device=device
        )
        template_padded_node_feat_per_aa[template_atom37_mask] = self.molgraph_single_projector(
            self.standard_aa_molgraph_featset["features"]["lig_atom_attr"]
        )
        protatm_padding_mask = batch["features"]["res_atom_mask"]
        protatm_to_atom37_idx = (
            protatm_to_atom37_idx_first[None, :].expand(batch_size, -1).contiguous().flatten(0, 1)
        )
        batch["features"]["protatm_to_atom37_index"] = protatm_to_atom37_idx
        batch["features"]["protatm_to_atomic_number"] = features["res_atom_types"].long()[
            protatm_padding_mask
        ]
        batch["features"]["prot_atom_attr_projected"] = (
            template_padded_node_feat_per_aa[
                features["res_type"].long()[protatm_res_idx_res_first],
                protatm_to_atom37_idx_first,
            ][None, :]
            .expand(batch_size, -1, -1)
            .contiguous()
            .flatten(0, 1)
        )
        template_padded_edge_feat_per_aa = torch.zeros(
            (20, 37, 37, self.protein_cfg.pair_dim), device=device
        )
        template_padded_edge_feat_per_aa[
            atom_idx_in_batch_to_restype_idx[template_aa_graph_indexer["gather_idx_uv_u"]],
            atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_u"]],
            atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_v"]],
        ] = self.molgraph_pair_projector(
            self.standard_aa_molgraph_featset["features"]["lig_atom_pair_attr"]
        )

        batch["features"]["prot_atom_pair_attr_projected"] = (
            template_padded_edge_feat_per_aa[
                features["res_type"].long()[protatm_res_idx_res_first[src_idx_first]],
                protatm_to_atom37_idx_first[src_idx_first],
                protatm_to_atom37_idx_first[dst_idx_first],
            ][None, :, :]
            .expand(batch_size, -1, -1)
            .contiguous()
            .flatten(0, 1)
        )
        return batch
atom_idx_in_batch_to_restype_idx[template_aa_graph_indexer["gather_idx_uv_u"]], + atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_u"]], + atom_idx_in_batch_to_atom37_idx[template_aa_graph_indexer["gather_idx_uv_v"]], + ] = self.molgraph_pair_projector( + self.standard_aa_molgraph_featset["features"]["lig_atom_pair_attr"] + ) + + batch["features"]["prot_atom_pair_attr_projected"] = ( + template_padded_edge_feat_per_aa[ + features["res_type"].long()[protatm_res_idx_res_first[src_idx_first]], + protatm_to_atom37_idx_first[src_idx_first], + protatm_to_atom37_idx_first[dst_idx_first], + ][None, :, :] + .expand(batch_size, -1, -1) + .contiguous() + .flatten(0, 1) + ) + return batch + + def initialize_ligand_embeddings(self, batch: MODEL_BATCH, **kwargs: Dict[str, Any]): + """Initialize ligand embeddings. + + :param batch: A batch dictionary. + :param kwargs: Additional keyword arguments. + """ + metadata = batch["metadata"] + batch["features"] + indexer = batch["indexer"] + batch_size = metadata["num_structid"] + + # Ligand atom, frame and pair embeddings + if kwargs["training"]: + masking_rate = random.uniform(0, self.lig_masking_rate) # nosec + else: + masking_rate = 0 + batch = self.ligand_encoder(batch, masking_rate=masking_rate) + batch["features"]["lig_atom_attr_projected"] = self.molgraph_single_projector( + batch["features"]["lig_atom_attr"] + ) + # Downsampled ligand frames + batch["features"]["lig_trp_attr_projected"] = self.molgraph_single_projector( + batch["features"]["lig_trp_attr"] + ) + batch["features"]["lig_atom_pair_attr_projected"] = self.molgraph_pair_projector( + batch["features"]["lig_atom_pair_attr"] + ) + lig_af_pair_attr_flat_ = self.molgraph_pair_projector( + batch["features"]["lig_af_pair_attr"] + ) + batch["features"]["lig_af_pair_attr_projected"] = lig_af_pair_attr_flat_ + + if self.global_cfg.single_protein_batch: + lig_af_pair_attr = lig_af_pair_attr_flat_.new_zeros( + batch_size, + 
    def run_encoder_stack(
        self,
        batch: MODEL_BATCH,
        **kwargs: Dict[str, Any],
    ) -> MODEL_BATCH:
        """Run the encoder stack.

        Prepares (gradient-free) patch/backbone indexers and initial embeddings,
        runs the protein encoder, and — unless the batch is protein-only —
        lazily initializes ligand embeddings on first use.

        :param batch: A batch dictionary.
        :param kwargs: Additional keyword arguments; must contain ``training``.
        :return: A batch dictionary.
        """
        # Indexer construction and input embeddings carry no trainable state here.
        with torch.no_grad():
            batch = self.prepare_protein_patch_indexers(
                batch, randomize_anchors=kwargs["training"]
            )
            self.prepare_protein_backbone_indexers(batch, **kwargs)
            self.initialize_protein_embeddings(batch)
            self.initialize_protatm_indexer_and_embeddings(batch)

        batch = self.protein_encoder(
            batch,
            in_attr_suffix="",
            out_attr_suffix="_projected",
            **kwargs,
        )
        if batch["misc"]["protein_only"]:
            return batch

        # NOTE: here, we are assuming a static ligand graph
        if "lig_atom_attr" not in batch["features"]:
            self.initialize_ligand_embeddings(batch, **kwargs)
        return batch
    def run_contact_map_stack(
        self,
        batch: MODEL_BATCH,
        iter_id: Union[int, str],
        observed_block_contacts: Optional[torch.Tensor] = None,
        **kwargs: Dict[str, Any],
    ) -> MODEL_BATCH:
        """Run the contact map stack.

        Optionally conditions on previously sampled residue-ligand block
        contacts (pooled into 8-residue blocks, then gathered to protein
        patches) and, for protein-ligand batches, emits a 32-bin
        residue-ligand log-distogram under ``res_lig_distogram_out_{iter_id}``.

        :param batch: A batch dictionary.
        :param iter_id: The current iteration ID (suffix of the output keys).
        :param observed_block_contacts: Optional observed block contacts.
        :param kwargs: Additional keyword arguments.
        :return: A batch dictionary.
        """
        features = batch["features"]
        device = features["res_type"].device
        if observed_block_contacts is not None:
            # Merge into 8AA blocks and gather to patches
            patch8_idx = (
                torch.arange(
                    observed_block_contacts.shape[1],
                    device=device,
                )
                // 8
            )
            # Sum contacts within each 8-residue block, binarize, and broadcast
            # the block value back to every residue in the block.
            merged_contacts_reswise = (
                segment_sum(
                    observed_block_contacts.transpose(0, 1).contiguous(),
                    patch8_idx,
                    max(patch8_idx) + 1,
                )
                .bool()[patch8_idx]
                .transpose(0, 1)
                .contiguous()
            )
            merged_contacts_gathered = (
                merged_contacts_reswise.flatten(0, 1)[batch["indexer"]["gather_idx_pid_a"]]
                .contiguous()
                .view(
                    observed_block_contacts.shape[0],
                    -1,
                    observed_block_contacts.shape[2],
                )
            )
            block_contact_embedding = self.contact_code_embed(merged_contacts_gathered.long())
        else:
            block_contact_embedding = None
        batch = self.pl_contact_stack(
            batch,
            in_attr_suffix="_projected",
            out_attr_suffix=f"_out_{iter_id}",
            observed_block_contacts=block_contact_embedding,
        )

        if batch["misc"]["protein_only"]:
            return batch

        metadata = batch["metadata"]
        batch_size = metadata["num_structid"]
        n_a_per_sample = max(metadata["num_a_per_sample"])
        n_I_per_sample = metadata["n_lig_patches_per_sample"]
        res_lig_pair_attr = batch["features"][f"res_trp_pair_attr_flat_out_{iter_id}"]
        # 32 distance bins per residue-ligand-frame pair.
        raw_dgram_logits = self.dgram_head(res_lig_pair_attr).view(
            batch_size, n_a_per_sample, n_I_per_sample, 32
        )
        batch["outputs"][f"res_lig_distogram_out_{iter_id}"] = F.log_softmax(
            raw_dgram_logits, dim=-1
        )
        return batch
    def infer_geometry_prior(
        self,
        batch: MODEL_BATCH,
        cached_block_contacts: Optional[torch.Tensor] = None,
        binding_site_mask: Optional[torch.Tensor] = None,
        logit_clamp_value: Optional[torch.Tensor] = None,
        **kwargs: Dict[str, Any],
    ):
        """Infer a geometry prior over residue-ligand contacts.

        Autoregressively samples one block contact per ligand frame (beam
        scheme), re-running the contact stack after each draw, then converts
        the final distogram to contact logits stored flat in
        ``batch["outputs"]["geometry_prior_L"]``. Masked-out residues and
        (at inference) sub-threshold logits are suppressed by subtracting 1e9.

        :param batch: A batch dictionary.
        :param cached_block_contacts: Cached block contacts; skips resampling.
        :param binding_site_mask: Binding site mask.
        :param logit_clamp_value: Logit clamp value.
        :param kwargs: Additional keyword arguments; must contain ``training``.
        """
        # Parse self.task_cfg.block_contact_decoding_scheme
        assert (
            self.global_cfg.block_contact_decoding_scheme == "beam"
        ), "Only beam search is supported."
        n_lig_frames = max(batch["metadata"]["num_I_per_sample"])
        # Autoregressive block-contact sampling
        if cached_block_contacts is None:
            # Start from the prior distribution
            sampled_block_contacts = None
            last_distogram = batch["outputs"]["res_lig_distogram_out_0"]
            for iter_id in tqdm.tqdm(range(n_lig_frames), desc="Block contact sampling"):
                last_contact_map = distogram_to_gaussian_contact_logits(
                    last_distogram, self.dist_bins, self.CONTACT_SCALE
                )
                sampled_block_contacts = sample_reslig_contact_matrix(
                    batch, last_contact_map, last=sampled_block_contacts
                ).detach()
                # Re-condition the contact stack on everything sampled so far.
                self.run_contact_map_stack(
                    batch, iter_id, observed_block_contacts=sampled_block_contacts
                )
                last_distogram = batch["outputs"][f"res_lig_distogram_out_{iter_id}"]
            batch["outputs"]["sampled_block_contacts_last"] = sampled_block_contacts
            # Check that all ligands are assigned to one protein chain segment
            num_assigned_per_lig = segment_sum(
                torch.sum(sampled_block_contacts, dim=1).contiguous().flatten(0, 1),
                batch["indexer"]["gather_idx_I_molid"],
                batch["metadata"]["num_molid"],
            )
            assert torch.all(num_assigned_per_lig >= 1)
        else:
            sampled_block_contacts = cached_block_contacts

        # Use the cached contacts and only sample once
        self.run_contact_map_stack(
            batch, n_lig_frames, observed_block_contacts=sampled_block_contacts
        )
        last_distogram = batch["outputs"][f"res_lig_distogram_out_{n_lig_frames}"]
        res_lig_contact_logit_pred = distogram_to_gaussian_contact_logits(
            last_distogram, self.dist_bins, self.CONTACT_SCALE
        )

        if binding_site_mask is not None:
            # Push non-binding-site logits to effectively -inf.
            res_lig_contact_logit_pred = res_lig_contact_logit_pred - (
                ~binding_site_mask[:, :, None] * 1e9
            )
        if not kwargs["training"] and logit_clamp_value is not None:
            res_lig_contact_logit_pred = (
                res_lig_contact_logit_pred
                - (res_lig_contact_logit_pred < logit_clamp_value) * 1e9
            )
        batch["outputs"]["geometry_prior_L"] = res_lig_contact_logit_pred.flatten()
    def init_randexp_kNN_edges_and_covmask(
        self, batch: MODEL_BATCH, detect_covalent: bool = False
    ):
        """Initialize random expansion kNN edges and covalent mask.

        Builds randomized k-nearest-neighbor edges over all protein (and, when
        present, ligand) atoms, optionally augmented with heuristic
        inter-molecular covalent-bond edges, and writes the four directed edge
        sets (prot-prot, prot-lig, lig-prot, lig-lig) plus their covalent-flag
        embeddings into ``batch["indexer"]`` / ``batch["features"]``.

        :param batch: A batch dictionary.
        :param detect_covalent: Whether to detect covalent bonds.
        """
        device = batch["features"]["res_type"].device
        batch_size = batch["metadata"]["num_structid"]
        protatm_padding_mask = batch["features"]["res_atom_mask"]
        prot_atm_coords_padded = batch["features"]["input_protein_coords"]
        protatm_coords = prot_atm_coords_padded[protatm_padding_mask].contiguous()
        n_protatm_per_sample = batch["metadata"]["num_protatm_per_sample"]
        protatm_coords = protatm_coords.view(batch_size, n_protatm_per_sample, 3)
        if not batch["misc"]["protein_only"]:
            n_ligatm_per_sample = max(batch["metadata"]["num_i_per_sample"])
            ligatm_coords = batch["features"]["input_ligand_coords"]
            ligatm_coords = ligatm_coords.view(batch_size, n_ligatm_per_sample, 3)
            # Concatenate so protein atoms occupy indices [0, n_protatm) and
            # ligand atoms [n_protatm, n_protatm + n_ligatm) per sample.
            atm_coords = torch.cat([protatm_coords, ligatm_coords], dim=1)
        else:
            atm_coords = protatm_coords
        distance_mat = torch.norm(atm_coords[:, :, None] - atm_coords[:, None, :], dim=-1)
        # Exclude self-pairs (and exactly coincident atoms) from kNN selection.
        distance_mat[distance_mat == 0] = 1e9
        knn_edge_mask = topk_edge_mask_from_logits(
            -distance_mat / self.CONTACT_SCALE,
            self.cfg.score_head.max_atom_degree,
            randomize=True,
        )
        if (not batch["misc"]["protein_only"]) and detect_covalent:
            prot_atomic_numbers = batch["features"]["protatm_to_atomic_number"].view(
                batch_size, n_protatm_per_sample
            )
            lig_atomic_numbers = (
                batch["features"]["atomic_numbers"].long().view(batch_size, n_ligatm_per_sample)
            )
            atom_vdw = self.atnum2vdw[torch.cat([prot_atomic_numbers, lig_atomic_numbers], dim=1)]
            average_vdw = (atom_vdw[:, :, None] + atom_vdw[:, None, :]) / 2
            # Heuristic: pairs closer than 1.3x the mean vdW radius are treated
            # as covalently bonded — NOTE(review): threshold factor is empirical.
            intermol_iscov_mask = distance_mat < average_vdw * 1.3
            # Protein-internal bonds are handled elsewhere; drop them here.
            intermol_iscov_mask[:, :n_protatm_per_sample, :n_protatm_per_sample] = False
            gather_idx_i_molid = batch["indexer"]["gather_idx_i_molid"].view(
                batch_size, n_ligatm_per_sample
            )
            lig_samemol_mask = gather_idx_i_molid[:, :, None] == gather_idx_i_molid[:, None, :]
            # Keep only covalent contacts between distinct ligand molecules.
            intermol_iscov_mask[
                :, n_protatm_per_sample:, n_protatm_per_sample:
            ] = intermol_iscov_mask[:, n_protatm_per_sample:, n_protatm_per_sample:] & (
                ~lig_samemol_mask
            )
            knn_edge_mask = knn_edge_mask | intermol_iscov_mask
        else:
            intermol_iscov_mask = torch.zeros_like(distance_mat, dtype=torch.bool)
        p_idx = torch.arange(batch_size * n_protatm_per_sample, device=device).view(
            batch_size, n_protatm_per_sample
        )
        pp_edge_mask = knn_edge_mask[:, :n_protatm_per_sample, :n_protatm_per_sample]
        batch["indexer"]["knn_idx_protatm_protatm_src"] = (
            p_idx[:, :, None].expand(-1, -1, n_protatm_per_sample).contiguous()[pp_edge_mask]
        )
        batch["indexer"]["knn_idx_protatm_protatm_dst"] = (
            p_idx[:, None, :].expand(-1, n_protatm_per_sample, -1).contiguous()[pp_edge_mask]
        )
        batch["features"]["knn_feat_protatm_protatm"] = self.covalent_embed(
            intermol_iscov_mask[:, :n_protatm_per_sample, :n_protatm_per_sample][
                pp_edge_mask
            ].long()
        )
        if not batch["misc"]["protein_only"]:
            l_idx = torch.arange(batch_size * n_ligatm_per_sample, device=device).view(
                batch_size, n_ligatm_per_sample
            )
            pl_edge_mask = knn_edge_mask[:, :n_protatm_per_sample, n_protatm_per_sample:]
            batch["indexer"]["knn_idx_protatm_ligatm_src"] = (
                p_idx[:, :, None].expand(-1, -1, n_ligatm_per_sample).contiguous()[pl_edge_mask]
            )
            batch["indexer"]["knn_idx_protatm_ligatm_dst"] = (
                l_idx[:, None, :].expand(-1, n_protatm_per_sample, -1).contiguous()[pl_edge_mask]
            )
            batch["features"]["knn_feat_protatm_ligatm"] = self.covalent_embed(
                intermol_iscov_mask[:, :n_protatm_per_sample, n_protatm_per_sample:][
                    pl_edge_mask
                ].long()
            )
            lp_edge_mask = knn_edge_mask[:, n_protatm_per_sample:, :n_protatm_per_sample]
            batch["indexer"]["knn_idx_ligatm_protatm_src"] = (
                l_idx[:, :, None].expand(-1, -1, n_protatm_per_sample).contiguous()[lp_edge_mask]
            )
            batch["indexer"]["knn_idx_ligatm_protatm_dst"] = (
                p_idx[:, None, :].expand(-1, n_ligatm_per_sample, -1).contiguous()[lp_edge_mask]
            )
            batch["features"]["knn_feat_ligatm_protatm"] = self.covalent_embed(
                intermol_iscov_mask[:, n_protatm_per_sample:, :n_protatm_per_sample][
                    lp_edge_mask
                ].long()
            )
            ll_edge_mask = knn_edge_mask[:, n_protatm_per_sample:, n_protatm_per_sample:]
            batch["indexer"]["knn_idx_ligatm_ligatm_src"] = (
                l_idx[:, :, None].expand(-1, -1, n_ligatm_per_sample).contiguous()[ll_edge_mask]
            )
            batch["indexer"]["knn_idx_ligatm_ligatm_dst"] = (
                l_idx[:, None, :].expand(-1, n_ligatm_per_sample, -1).contiguous()[ll_edge_mask]
            )
            batch["features"]["knn_feat_ligatm_ligatm"] = self.covalent_embed(
                intermol_iscov_mask[:, n_protatm_per_sample:, n_protatm_per_sample:][
                    ll_edge_mask
                ].long()
            )
    def init_esdm_inputs(
        self, batch: MODEL_BATCH, embedding_iter_id: Union[int, str]
    ) -> MODEL_BATCH:
        """Initialize the inputs for the ESDM.

        Rebuilds the (gradient-free) randomized kNN / covalent edge sets, then
        copies the encoder outputs of iteration ``embedding_iter_id`` into the
        ``*_decin`` feature slots consumed by the decoder heads. Ligand-side
        features are skipped for protein-only batches.

        :param batch: A batch dictionary.
        :param embedding_iter_id: The embedding iteration ID (key suffix).
        :return: A batch dictionary.
        """
        with torch.no_grad():
            self.init_randexp_kNN_edges_and_covmask(
                batch,
                detect_covalent=self.global_cfg.detect_covalent,
            )
        batch["features"]["rec_res_attr_decin"] = batch["features"][
            f"rec_res_attr_out_{embedding_iter_id}"
        ]
        batch["features"]["res_res_pair_attr_decin"] = batch["features"][
            f"res_res_pair_attr_out_{embedding_iter_id}"
        ]
        batch["features"]["res_res_grid_attr_flat_decin"] = batch["features"][
            f"res_res_grid_attr_flat_out_{embedding_iter_id}"
        ]
        if batch["misc"]["protein_only"]:
            return batch
        batch["features"]["lig_trp_attr_decin"] = batch["features"][
            f"lig_trp_attr_out_{embedding_iter_id}"
        ]
        # Use protein-ligand edges from the contact predictor
        batch["features"]["res_trp_grid_attr_flat_decin"] = batch["features"][
            f"res_trp_grid_attr_flat_out_{embedding_iter_id}"
        ]
        batch["features"]["res_trp_pair_attr_flat_decin"] = batch["features"][
            f"res_trp_pair_attr_flat_out_{embedding_iter_id}"
        ]
        batch["features"]["trp_trp_grid_attr_flat_decin"] = batch["features"][
            f"trp_trp_grid_attr_flat_out_{embedding_iter_id}"
        ]
        return batch
    def run_score_head(
        self,
        batch: MODEL_BATCH,
        embedding_iter_id: Union[int, str],
        frozen_lig: Optional[bool] = None,
        frozen_prot: Optional[bool] = None,
        **kwargs: Dict[str, Any],
    ) -> MODEL_BATCH:
        """Run the score head.

        Explicit ``frozen_lig`` / ``frozen_prot`` arguments override the
        corresponding global-config defaults when not ``None``.

        :param batch: A batch dictionary.
        :param embedding_iter_id: The embedding iteration ID.
        :param frozen_lig: Whether to freeze the ligand backbone.
        :param frozen_prot: Whether to freeze the protein backbone.
        :param kwargs: Additional keyword arguments.
        :return: A batch dictionary with the score head output.
        """
        batch = self.init_esdm_inputs(batch, embedding_iter_id)
        return self.score_head(
            batch,
            frozen_lig=frozen_lig
            if frozen_lig is not None
            else self.global_cfg.frozen_ligand_backbone,
            frozen_prot=frozen_prot
            if frozen_prot is not None
            else self.global_cfg.frozen_protein_backbone,
            **kwargs,
        )

    def run_confidence_head(self, batch: MODEL_BATCH, **kwargs: Dict[str, Any]) -> MODEL_BATCH:
        """Run the confidence head (both backbones unfrozen).

        :param batch: A batch dictionary.
        :param kwargs: Additional keyword arguments.
        :return: A batch dictionary with the confidence head output.
        """
        return self.confidence_head(batch, frozen_lig=False, frozen_prot=False, **kwargs)

    def run_affinity_head(self, batch: MODEL_BATCH, **kwargs: Dict[str, Any]) -> MODEL_BATCH:
        """Run the affinity head.

        Pools the scalar (index-0) ligand-atom embeddings per molecule, then
        projects them to one affinity logit per molecule.

        :param batch: A batch dictionary.
        :param kwargs: Additional keyword arguments.
        :return: Affinity head output.
        """
        aff_out = self.affinity_head(batch, frozen_lig=False, frozen_prot=False, **kwargs)
        aff_pooled = self.ligand_pooling(
            aff_out["final_embedding_lig_atom"][:, 0],
            batch["indexer"]["gather_idx_i_molid"],
            batch["metadata"]["num_molid"],
        )
        return self.affinity_proj_head(aff_pooled).squeeze(1)
    def run_auxiliary_estimation(
        self,
        batch: MODEL_BATCH,
        struct: MODEL_BATCH,
        return_avg_stats: bool = False,
        **kwargs: Dict[str, Any],
    ) -> Union[MODEL_BATCH, Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]:
        """Run auxiliary estimations (affinity, confidence/plDDT) on a sampled structure.

        Installs ``struct``'s coordinates as model inputs at timestep 0, reruns the
        encoder + contact stacks, then evaluates the enabled auxiliary heads.

        :param batch: A batch dictionary.
        :param struct: A structure dictionary with ``receptor_padded`` and ``ligands``.
        :param return_avg_stats: Whether to return average statistics instead of the batch.
        :param kwargs: Additional keyword arguments.
        :return: A batch dictionary, or a tuple of
            (per-sample mean plDDT, per-sample ligand-only mean plDDT or None,
            per-molecule mean plDDT or None).
        """
        batch_size = batch["metadata"]["num_structid"]
        batch["features"]["input_protein_coords"] = struct["receptor_padded"].clone()
        if struct["ligands"] is not None:
            batch["features"]["input_ligand_coords"] = struct["ligands"].clone()
        else:
            batch["features"]["input_ligand_coords"] = None
        # Evaluate at t=0 (fully denoised state).
        self.assign_timestep_encodings(batch, 0.0)
        batch = self.run_encoder_stack(batch, use_template=False, use_plddt=False, **kwargs)
        self.run_contact_map_stack(batch, iter_id="auxiliary")
        batch = self.init_esdm_inputs(batch, "auxiliary")
        if self.affinity_cfg.enabled:
            batch["outputs"]["affinity_logits"] = self.run_affinity_head(batch)
        if not self.confidence_cfg.enabled:
            return batch
        conf_out = self.run_confidence_head(batch)
        conf_rep = (
            conf_out["final_embedding_prot_res"][:, 0]
            .contiguous()
            .view(batch_size, -1, self.cfg.confidence.fiber_dim)
        )
        if struct["ligands"] is not None:
            conf_rep_lig = (
                conf_out["final_embedding_lig_atom"][:, 0]
                .contiguous()
                .view(batch_size, -1, self.cfg.confidence.fiber_dim)
            )
            conf_rep = torch.cat([conf_rep, conf_rep_lig], dim=1)
        plddt_logits = F.log_softmax(self.plddt_gram_head(conf_rep), dim=-1)
        batch["outputs"]["plddt_logits"] = plddt_logits
        plddt_gram = torch.exp(plddt_logits)
        # Mean of the CDF over the first 4 bins — NOTE(review): bin semantics
        # (error thresholds) are defined by plddt_gram_head's training targets.
        batch["outputs"]["plddt"] = torch.cumsum(plddt_gram[:, :, :4], dim=-1).mean(dim=-1)

        if return_avg_stats:
            plddt_avg = (batch["outputs"]["plddt"].view(batch_size, -1).mean(dim=1)).detach()
            if struct["ligands"] is not None:
                # Ligand entries follow the protein residues in the flattened plDDT.
                plddt_avg_lig = (
                    batch["outputs"]["plddt"]
                    .view(batch_size, -1)[:, batch["metadata"]["num_a_per_sample"][0] :]
                    .mean(dim=1)
                    .detach()
                )
                plddt_avg_ligs = segment_mean(
                    batch["outputs"]["plddt"]
                    .view(batch_size, -1)[:, batch["metadata"]["num_a_per_sample"][0] :]
                    .reshape(-1),
                    batch["indexer"]["gather_idx_i_molid"],
                    batch["metadata"]["num_molid"],
                ).detach()
            else:
                plddt_avg_lig = None
                plddt_avg_ligs = None
            return plddt_avg, plddt_avg_lig, plddt_avg_ligs

        return batch
return batch + + @staticmethod + def reverse_interp_ode_step( + x0_hat: torch.Tensor, xt: torch.Tensor, t: torch.Tensor, s: torch.Tensor + ) -> torch.Tensor: + """Reverse process sampling using an Euler ODE solver. + + :param x0_hat: The denoised state. + :param xt: The intermediate (noisy) state. + :param t: The current timestep. + :param s: The next timestep after `t`. + :return: The interpolated state. + """ + step_size = t - s + return xt + step_size * x0_hat + + @staticmethod + def reverse_interp_vdode_step( + x0_hat: torch.Tensor, + xt: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + eta: float = 1.0, + ) -> torch.Tensor: + """Reverse process sampling using an Euler Variance Diminishing ODE (VD-ODE) solver. + + Note that the LHS and RHS time step scaling factors are clamped to the range + `[1e-6, 1 - 1e-6]` by default. + + :param x0_hat: The denoised state. + :param xt: The intermediate (noisy) state. + :param t: The current timestep. + :param s: The next timestep after `t`. + :param eta: The variance diminishing factor to employ. + :return: The interpolated state. + """ + return clamp_tensor(1 - ((s / t) * eta)) * x0_hat + clamp_tensor((s / t) * eta) * xt + + def reverse_interp_plcomplex_latinp( + self, + batch: MODEL_BATCH, + t: torch.Tensor, + s: torch.Tensor, + latent_converter: LatentCoordinateConverter, + score_converter: LatentCoordinateConverter, + sampler_step_fn: Callable, + umeyama_correction: bool = True, + use_template: bool = False, + ) -> MODEL_BATCH: + """Reverse-interpolate protein-ligand complex latent internal coordinates. + + :param batch: A batch dictionary. + :param t: The current timestep. + :param s: The next timestep after `t`. + :param latent_converter: The latent coordinate converter. + :param score_converter: The latent score converter. + :param sampler_step_fn: The sampling step function to use. + :param umeyama_correction: Whether to apply the Umeyama correction. 
    def reverse_interp_plcomplex_latinp(
        self,
        batch: MODEL_BATCH,
        t: torch.Tensor,
        s: torch.Tensor,
        latent_converter: LatentCoordinateConverter,
        score_converter: LatentCoordinateConverter,
        sampler_step_fn: Callable,
        umeyama_correction: bool = True,
        use_template: bool = False,
    ) -> MODEL_BATCH:
        """Reverse-interpolate protein-ligand complex latent internal coordinates.

        One reverse-diffusion step: split the latent state into its seven
        segments (Ca / apo-Ca / other-backbone / apo-other / centroid pairs /
        ligand atoms), run the denoiser, optionally rigid-align the prediction
        to the previous step's Ca trace (Umeyama), interpolate only the three
        noisy segments with ``sampler_step_fn``, and reassemble the latent
        vector with the apo/centroid segments passed through unchanged.

        :param batch: A batch dictionary.
        :param t: The current timestep.
        :param s: The next timestep after `t`.
        :param latent_converter: The latent coordinate converter.
        :param score_converter: The latent score converter.
        :param sampler_step_fn: The sampling step function to use.
        :param umeyama_correction: Whether to apply the Umeyama correction.
        :param use_template: Whether to use a given protein structure template.
        :return: A batch dictionary.
        """
        batch_size = batch["metadata"]["num_structid"]

        # derive dimension-less internal coordinates
        x_int_t = latent_converter.to_latent(batch)
        (
            x_int_t_ca_lat,
            x_int_t_apo_ca_lat,
            x_int_t_cother_lat,
            x_int_t_apo_cother_lat,
            x_int_t_ca_lat_centroid_coords,
            x_int_t_apo_ca_lat_centroid_coords,
            x_int_t_lig_lat,
        ) = torch.split(
            x_int_t,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        # Only the holo Ca, holo non-Ca, and ligand segments are interpolated.
        x_int_t_ = torch.cat(
            [
                x_int_t_ca_lat,
                x_int_t_cother_lat,
                x_int_t_lig_lat,
            ],
            dim=1,
        )
        self.assign_timestep_encodings(batch, t)

        batch = self.forward(
            batch,
            iter_id=score_converter.iter_id,
            contact_prediction=True,
            score=True,
            observed_block_contacts=score_converter.sampled_block_contacts,
            use_template=use_template,
            training=False,
        )
        if umeyama_correction:
            # Rigid-align this step's predicted Ca trace to the previous step's,
            # then apply the same transform to all predicted atom coordinates,
            # to keep consecutive denoised frames in a common reference frame.
            _last_pred_ca_trace = (
                batch["outputs"]["denoised_prediction"]["final_coords_prot_atom_padded"][:, 1]
                .view(batch_size, -1, 3)
                .detach()
            )
            if score_converter._last_pred_ca_trace is not None:
                similarity_transform = corresponding_points_alignment(
                    _last_pred_ca_trace,
                    score_converter._last_pred_ca_trace,
                    estimate_scale=False,
                )
                _last_pred_ca_trace = apply_similarity_transform(
                    _last_pred_ca_trace, *similarity_transform
                )
                protatm_padding_mask = batch["features"]["res_atom_mask"]
                pred_protatm_coords = (
                    batch["outputs"]["denoised_prediction"]["final_coords_prot_atom_padded"][
                        protatm_padding_mask
                    ]
                    .contiguous()
                    .view(batch_size, -1, 3)
                )
                aligned_pred_protatm_coords = (
                    apply_similarity_transform(pred_protatm_coords, *similarity_transform)
                    .contiguous()
                    .flatten(0, 1)
                )
                batch["outputs"]["denoised_prediction"]["final_coords_prot_atom_padded"][
                    protatm_padding_mask
                ] = aligned_pred_protatm_coords
                batch["outputs"]["denoised_prediction"][
                    "final_coords_prot_atom"
                ] = aligned_pred_protatm_coords
                if not batch["misc"]["protein_only"]:
                    pred_ligatm_coords = batch["outputs"]["denoised_prediction"][
                        "final_coords_lig_atom"
                    ].view(batch_size, -1, 3)
                    aligned_pred_ligatm_coords = (
                        apply_similarity_transform(pred_ligatm_coords, *similarity_transform)
                        .contiguous()
                        .flatten(0, 1)
                    )
                    batch["outputs"]["denoised_prediction"][
                        "final_coords_lig_atom"
                    ] = aligned_pred_ligatm_coords
            score_converter._last_pred_ca_trace = _last_pred_ca_trace

        # interpolate in the latent (reduced) coordinates space
        x_int_hat_t = score_converter.to_latent(batch)
        (
            x_int_hat_t_ca_lat,
            _,
            x_int_hat_t_cother_lat,
            _,
            _,
            _,
            x_int_hat_t_lig_lat,
        ) = torch.split(
            x_int_hat_t,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_molid_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        x_int_hat_t_ = torch.cat(
            [
                x_int_hat_t_ca_lat,
                x_int_hat_t_cother_lat,
                x_int_hat_t_lig_lat,
            ],
            dim=1,
        )

        # only interpolate using Ca atom, non-Ca atom, and ligand atom coordinates
        x_int_tm = sampler_step_fn(x_int_hat_t_, x_int_t_, t, s)

        # reassemble outputs
        (
            x_int_tm_ca_lat,
            x_int_tm_cother_lat,
            x_int_tm_lig_lat,
        ) = torch.split(
            x_int_tm,
            [
                latent_converter._n_res_per_sample,
                latent_converter._n_cother_per_sample,
                latent_converter._n_ligha_per_sample,
            ],
            dim=1,
        )
        # Apo and centroid segments are carried over from the input unchanged.
        x_int_tm_ = torch.cat(
            [
                x_int_tm_ca_lat,
                x_int_t_apo_ca_lat,
                x_int_tm_cother_lat,
                x_int_t_apo_cother_lat,
                x_int_t_ca_lat_centroid_coords,
                x_int_t_apo_ca_lat_centroid_coords,
                x_int_tm_lig_lat,
            ],
            dim=1,
        )
        return latent_converter.assign_to_batch(batch, x_int_tm_)
    def sample_pl_complex_structures(
        self,
        batch: MODEL_BATCH,
        num_steps: int = 100,
        return_summary_stats: bool = False,
        return_all_states: bool = False,
        sampler: Literal["ODE", "VDODE"] = "VDODE",
        sampler_eta: float = 1.0,
        umeyama_correction: bool = True,
        start_time: float = 1.0,
        exact_prior: bool = False,
        eval_input_protein: bool = False,
        align_to_ground_truth: bool = True,
        use_template: Optional[bool] = None,
        **kwargs,
    ) -> MODEL_BATCH:
        """Sample protein-ligand complex structures.

        :param batch: A batch dictionary.
        :param num_steps: The number of reverse-process integration steps (must be > 0).
        :param return_summary_stats: Whether to return only summary statistics.
        :param return_all_states: Whether to return all states along with sampling metrics.
        :param sampler: The reverse process sampler to use.
        :param sampler_eta: The variance diminishing factor to employ for the `VDODE` sampler,
            which offers a trade-off between exploration (1.0) and exploitation (> 1.0).
        :param umeyama_correction: Apply optimal alignment between the denoised structure and
            previous step outputs.
        :param start_time: The start time (in [0, 1]).
        :param exact_prior: Whether to use the exact prior.
        :param eval_input_protein: Whether to evaluate the input protein structure.
        :param align_to_ground_truth: Whether to align predictions to the ground truth.
        :param use_template: Whether to use a given protein structure template; defaults to
            the global config value when None.
        :param kwargs: Additional keyword arguments.
        :return: A batch dictionary (or a metrics dictionary when `return_summary_stats`).
        """
        assert num_steps > 0, "Invalid number of steps."
        assert 0.0 <= start_time <= 1.0, "Invalid start time."

        if use_template is None:
            use_template = self.global_cfg.use_template

        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        res_atom_mask = batch["features"]["res_atom_mask"].bool()
        device = features["res_type"].device
        batch_size = metadata["num_structid"]

        # A batch with no ligand molecules is treated as protein-only throughout sampling
        if "num_molid" in batch["metadata"].keys() and batch["metadata"]["num_molid"] > 0:
            batch["misc"]["protein_only"] = False
        else:
            batch["misc"]["protein_only"] = True

        # Latent converters map between (protein, ligand) coordinate entries of the batch
        # and the reduced latent representation used by the samplers
        forward_lat_converter = self.resolve_latent_converter(
            [
                ("features", "res_atom_positions"),
                ("features", "input_protein_coords"),
            ],
            [("features", "sdf_coordinates"), ("features", "input_ligand_coords")],
        )
        reverse_lat_converter = self.resolve_latent_converter(
            [
                ("features", "input_protein_coords"),
                ("features", "input_protein_coords"),
            ],
            [
                ("features", "input_ligand_coords"),
                ("features", "input_ligand_coords"),
            ],
        )
        reverse_score_converter = self.resolve_latent_converter(
            [
                (
                    "outputs",
                    "denoised_prediction",
                    "final_coords_prot_atom_padded",
                ),
                None,
            ],
            [
                (
                    "outputs",
                    "denoised_prediction",
                    "final_coords_lig_atom",
                ),
                None,
            ],
        )

        with torch.no_grad():
            if not batch["misc"]["protein_only"]:
                # Autoregressive block contact map prior
                if exact_prior:
                    batch = self.prepare_protein_patch_indexers(batch)
                    _, contact_logit_matrix = eval_true_contact_maps(
                        batch, self.CONTACT_SCALE, **kwargs
                    )
                else:
                    batch = self.forward_interp_plcomplex_latinp(
                        batch,
                        start_time,
                        forward_lat_converter,
                        erase_data=(start_time >= 1.0),
                    )
                    self.assign_timestep_encodings(batch, start_time)
                    # Sample the categorical contact encodings under the hood
                    batch = self.forward(
                        batch,
                        contact_prediction=True,
                        infer_geometry_prior=True,
                        use_template=use_template,
                        training=False,
                    )
                    # Sample initial ligand coordinates from the geometry prior
                    contact_logit_matrix = batch["outputs"]["geometry_prior_L"]

                sampled_lig_res_anchor_mask = sample_res_rowmask_from_contacts(
                    batch, contact_logit_matrix, self.global_cfg.single_protein_batch
                )
                # Autoregressively draw one block contact per ligand frame
                num_cont_to_sample = max(metadata["num_I_per_sample"])
                sampled_block_contacts = None
                for _ in range(num_cont_to_sample):
                    sampled_block_contacts = sample_reslig_contact_matrix(
                        batch, contact_logit_matrix, last=sampled_block_contacts
                    )
                forward_lat_converter.lig_res_anchor_mask = sampled_lig_res_anchor_mask
                reverse_lat_converter.lig_res_anchor_mask = sampled_lig_res_anchor_mask
                reverse_score_converter.lig_res_anchor_mask = sampled_lig_res_anchor_mask
                reverse_score_converter.iter_id = num_cont_to_sample
                reverse_score_converter.sampled_block_contacts = sampled_block_contacts
            else:
                reverse_score_converter.iter_id = 0
                reverse_score_converter.sampled_block_contacts = None

        if sampler == "ODE":
            sampler_step_fn = self.reverse_interp_ode_step
        elif sampler == "VDODE":
            sampler_step_fn = partial(self.reverse_interp_vdode_step, eta=sampler_eta)
        else:
            raise NotImplementedError(f"Reverse process sampler {sampler} not implemented.")

        with torch.no_grad():
            # NOTE: Here, we assume the predicted contacts are robust to a resampling of geometric noise
            batch = self.forward_interp_plcomplex_latinp(
                batch,
                start_time,
                forward_lat_converter,
                erase_data=(start_time >= 1.0),
            )

            if return_all_states:
                all_frames = [
                    {
                        "ligands": batch["features"]["input_ligand_coords"].cpu(),
                        "receptor": batch["features"]["input_protein_coords"][res_atom_mask].cpu(),
                        "receptor_padded": batch["features"]["input_protein_coords"].cpu(),
                    }
                ]
            if eval_input_protein:
                # Score the (unmodified) input protein structure against the ground truth
                protein_fape_input, _ = compute_fape_from_atom37(
                    batch,
                    device,
                    batch["features"]["input_protein_coords"],
                    batch["features"]["res_atom_positions"],
                )
                tm_lbound_input = compute_TMscore_lbound(
                    batch,
                    batch["features"]["input_protein_coords"],
                    batch["features"]["res_atom_positions"],
                )
                tm_lbound_mirrored_input = compute_TMscore_lbound(
                    batch,
                    -batch["features"]["input_protein_coords"],
                    batch["features"]["res_atom_positions"],
                )
                tm_aligned_ca_input = compute_TMscore_raw(
                    batch,
                    batch["features"]["input_protein_coords"][:, 1],
                    batch["features"]["res_atom_positions"][:, 1],
                )
                lddt_ca_input = compute_lddt_ca(
                    batch,
                    batch["features"]["input_protein_coords"],
                    batch["features"]["res_atom_positions"],
                )
                input_ret = {
                    "FAPE_protein_input": protein_fape_input,
                    "TM_aligned_ca_input": tm_aligned_ca_input,
                    "TM_lbound_input": tm_lbound_input,
                    "TM_lbound_mirrored_input": tm_lbound_mirrored_input,
                    "lDDT-Ca_input": lddt_ca_input,
                }
            # NOTE: We follow https://arxiv.org/pdf/2402.04845.pdf for symbolic conventions
            # NOTE: In the FlowDock paper, `t=0` and `t=1` are the start and end times, respectively,
            # yet here we reverse the notation such that `t=1` and `t=0` are the start and end times, correspondingly,
            # to simplify the integration of this implementation with the original NeuralPLexer code
            schedule = torch.linspace(start_time, 0, num_steps + 1, device=device)
            for t, s in tqdm.tqdm(
                zip(schedule[:-1], schedule[1:]), desc=f"Structure generation using {sampler}"
            ):
                batch = self.reverse_interp_plcomplex_latinp(
                    batch,
                    t[None, None],
                    s[None, None],
                    reverse_lat_converter,
                    reverse_score_converter,
                    sampler_step_fn,
                    umeyama_correction=umeyama_correction,
                    use_template=use_template,
                )
                if return_all_states:
                    # NOTE(review): when the batch is protein-only, `final_coords_lig_atom`
                    # may be absent/None here — confirm before enabling `return_all_states`
                    # for apo sampling
                    all_frames.append(
                        {
                            "ligands": batch["outputs"]["denoised_prediction"][
                                "final_coords_lig_atom"
                            ].cpu(),
                            "receptor": batch["outputs"]["denoised_prediction"][
                                "final_coords_prot_atom"
                            ].cpu(),
                            "receptor_padded": batch["outputs"]["denoised_prediction"][
                                "final_coords_prot_atom_padded"
                            ].cpu(),
                        }
                    )

            mean_x1 = batch["outputs"]["denoised_prediction"]["final_coords_lig_atom"]
            mean_x2_padded = batch["outputs"]["denoised_prediction"][
                "final_coords_prot_atom_padded"
            ]
            protatm_padding_mask = batch["features"]["res_atom_mask"]
            mean_x2 = mean_x2_padded[protatm_padding_mask]
            if align_to_ground_truth:
                # Rigid (no-scale) Umeyama alignment on Ca atoms; the same transform is
                # then applied to all protein atoms and to the ligand atoms
                similarity_transform = corresponding_points_alignment(
                    mean_x2_padded[:, 1].view(batch_size, -1, 3),
                    batch["features"]["res_atom_positions"][:, 1].view(batch_size, -1, 3),
                    estimate_scale=False,
                )
                mean_x2 = (
                    apply_similarity_transform(
                        mean_x2.view(batch_size, -1, 3), *similarity_transform
                    )
                    .contiguous()
                    .flatten(0, 1)
                )
                mean_x2_padded[protatm_padding_mask] = mean_x2
                if mean_x1 is not None:
                    mean_x1 = (
                        apply_similarity_transform(
                            mean_x1.view(batch_size, -1, 3), *similarity_transform
                        )
                        .contiguous()
                        .flatten(0, 1)
                    )

            if return_all_states:
                all_frames.append(
                    {
                        "ligands": mean_x1.cpu(),
                        "receptor": mean_x2.cpu(),
                        "receptor_padded": mean_x2_padded.cpu(),
                    }
                )
            # Protein-structure quality metrics for the final denoised prediction
            protein_fape, _ = compute_fape_from_atom37(
                batch,
                device,
                mean_x2_padded,
                batch["features"]["res_atom_positions"],
            )
            tm_lbound = compute_TMscore_lbound(
                batch,
                mean_x2_padded,
                batch["features"]["res_atom_positions"],
            )
            tm_lbound_mirrored = compute_TMscore_lbound(
                batch,
                -mean_x2_padded,
                batch["features"]["res_atom_positions"],
            )
            tm_aligned_ca = compute_TMscore_raw(
                batch,
                mean_x2_padded[:, 1],
                batch["features"]["res_atom_positions"][:, 1],
            )
            lddt_ca = compute_lddt_ca(
                batch,
                mean_x2_padded,
                batch["features"]["res_atom_positions"],
            )
            ret = {
                "FAPE_protein": protein_fape,
                "TM_aligned_ca": tm_aligned_ca,
                "TM_lbound": tm_lbound,
                "TM_lbound_mirrored": tm_lbound_mirrored,
                "lDDT-Ca": lddt_ca,
            }
            if eval_input_protein:
                ret.update(input_ret)
            if mean_x1 is not None:
                # Ligand metrics: triplet frames per ligand frame index, then FAPE,
                # centroid-aligned RMSD, centroid displacement, hit rates, and lDDT-PLI
                n_I_per_sample = max(metadata["num_I_per_sample"])
                lig_frame_atm_idx = torch.stack(
                    [
                        indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
                        indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
                        indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
                    ],
                    dim=0,
                )
                _, lig_fape, _ = compute_fape_from_atom37(
                    batch,
                    device,
                    mean_x2_padded,
                    batch["features"]["res_atom_positions"],
                    pred_lig_coords=mean_x1,
                    target_lig_coords=batch["features"]["sdf_coordinates"],
                    lig_frame_atm_idx=lig_frame_atm_idx,
                    split_pl_views=True,
                )
                coords_pred_prot = mean_x2_padded[res_atom_mask].view(
                    metadata["num_structid"], -1, 3
                )
                coords_ref_prot = batch["features"]["res_atom_positions"][res_atom_mask].view(
                    metadata["num_structid"], -1, 3
                )
                coords_pred_lig = mean_x1.view(metadata["num_structid"], -1, 3)
                coords_ref_lig = batch["features"]["sdf_coordinates"].view(
                    metadata["num_structid"], -1, 3
                )
                # RMSD computed in the protein-centroid frame, averaged per molecule id
                lig_rmsd = segment_mean(
                    (
                        (coords_pred_lig - coords_pred_prot.mean(dim=1, keepdim=True))
                        - (coords_ref_lig - coords_ref_prot.mean(dim=1, keepdim=True))
                    )
                    .square()
                    .sum(dim=-1)
                    .flatten(0, 1),
                    indexer["gather_idx_i_molid"],
                    metadata["num_molid"],
                ).sqrt()
                lig_centroid_distance = (
                    segment_mean(
                        (coords_pred_lig - coords_pred_prot.mean(dim=1, keepdim=True)).flatten(
                            0, 1
                        ),
                        indexer["gather_idx_i_molid"],
                        metadata["num_molid"],
                    )
                    - segment_mean(
                        (coords_ref_lig - coords_ref_prot.mean(dim=1, keepdim=True)).flatten(0, 1),
                        indexer["gather_idx_i_molid"],
                        metadata["num_molid"],
                    )
                ).norm(dim=-1)
                lig_hit_score_1A = (lig_rmsd < 1.0).float()
                lig_hit_score_2A = (lig_rmsd < 2.0).float()
                lig_hit_score_4A = (lig_rmsd < 4.0).float()
                lddt_pli = compute_lddt_pli(
                    batch,
                    mean_x2_padded,
                    batch["features"]["res_atom_positions"],
                    mean_x1,
                    batch["features"]["sdf_coordinates"],
                )
                ret.update(
                    {
                        "ligand_RMSD": lig_rmsd,
                        "ligand_centroid_distance": lig_centroid_distance,
                        "lDDT-pli": lddt_pli,
                        "FAPE_ligview": lig_fape,
                        "ligand_hit_score_1A": lig_hit_score_1A,
                        "ligand_hit_score_2A": lig_hit_score_2A,
                        "ligand_hit_score_4A": lig_hit_score_4A,
                    }
                )

        if return_summary_stats:
            return ret

        if return_all_states:
            ret.update({"all_frames": all_frames})

        ret.update(
            {
                "ligands": mean_x1,
                "receptor": mean_x2,
                "receptor_padded": mean_x2_padded,
            }
        )
        return ret
    def forward(
        self,
        batch: MODEL_BATCH,
        iter_id: Union[int, str] = 0,
        observed_block_contacts: Optional[torch.Tensor] = None,
        contact_prediction: bool = True,
        infer_geometry_prior: bool = False,
        score: bool = False,
        use_template: bool = False,
        **kwargs: Any,
    ) -> Union[MODEL_BATCH, torch.Tensor]:
        """Perform a forward pass through the model.

        Runs the encoder stack, then optionally the contact-map stack, the geometry
        prior, and the denoising score head, in that order.

        :param batch: A batch dictionary.
        :param iter_id: The current iteration ID.
        :param observed_block_contacts: Observed block contacts.
        :param contact_prediction: Whether to predict contacts.
        :param infer_geometry_prior: Whether to predict using a geometry prior.
        :param score: Whether to predict a denoised complex structure.
        :param use_template: Whether to use a template protein structure.
        :param kwargs: Additional keyword arguments forwarded to the sub-stacks
            (e.g. `training` is passed through here by the samplers).
        :return: Batch dictionary with outputs or ligand binding affinity.
        """
        prepare_batch(batch)

        batch = self.run_encoder_stack(
            batch,
            use_template=use_template,
            use_plddt=self.global_cfg.use_plddt,
            **kwargs,
        )

        if contact_prediction:
            self.run_contact_map_stack(
                batch,
                iter_id,
                observed_block_contacts=observed_block_contacts,
                **kwargs,
            )

        if infer_geometry_prior:
            # The geometry prior is only defined over protein-ligand contacts
            assert (
                batch["misc"]["protein_only"] is False
            ), "Only protein-ligand complexes are supported for a geometry prior."
            self.infer_geometry_prior(batch, **kwargs)

        if score:
            batch["outputs"]["denoised_prediction"] = self.run_score_head(
                batch, embedding_iter_id=iter_id, **kwargs
            )

        return batch
@dataclass(frozen=True)
class Relation:
    """Static description of one edge type in a multi-relation graph.

    Holds the edge-type name, the indexer keys for the reverse/forward endpoint
    index arrays, the endpoint node types, and the number of edges.
    """

    edge_type: str
    edge_rev_name: str
    edge_frd_name: str
    src_node_type: str
    dst_node_type: str
    num_edges: int


class MultiRelationGraphBatcher:
    """Collate sub-graphs of different node/edge types into a single instance.

    Returned multi-relation edge indices are stored in LongTensor of shape [2, N_edges].
    """

    def __init__(
        self,
        relation_forms: List[Relation],
        graph_metadata: Dict[str, int],
    ):
        """Initialize the batcher.

        :param relation_forms: One `Relation` per edge type to collate.
        :param graph_metadata: Mapping containing `num_<node_type>` counts for every
            node type referenced by the relations.
        """
        self._relation_forms = relation_forms
        self._make_offset_dict(graph_metadata)

    def _make_offset_dict(self, graph_metadata):
        """Create offset dictionaries for node and edge types."""
        self._node_chunk_sizes = {}
        self._edge_chunk_sizes = {}
        self._offsets_lower = {}
        self._offsets_upper = {}
        # Use a dict (insertion-ordered) rather than a set so that the node-type
        # ordering is deterministic across processes; sets of strings iterate in a
        # hash-randomized order, which made collation order non-reproducible.
        all_node_types = {}
        for relation in self._relation_forms:
            assert (
                f"num_{relation.src_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.src_node_type}"
            # BUGFIX: the message previously printed the *source* node type here,
            # obscuring which metadata key was actually missing.
            assert (
                f"num_{relation.dst_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.dst_node_type}"
            all_node_types[relation.src_node_type] = None
            all_node_types[relation.dst_node_type] = None
        offset = 0
        # Fix node type ordering (first-seen order over the relation list)
        self.all_node_types = list(all_node_types)
        for node_type in self.all_node_types:
            self._offsets_lower[node_type] = offset
            self._node_chunk_sizes[node_type] = graph_metadata[f"num_{node_type}"]
            new_offset = offset + self._node_chunk_sizes[node_type]
            self._offsets_upper[node_type] = new_offset
            offset = new_offset

    def collate_single_relation_graphs(self, indexer, node_attr_dict, edge_attr_dict):
        """Collate sub-graphs of different node/edge types into a single instance."""
        return {
            "node_attr": self.collate_node_attr(node_attr_dict),
            "edge_attr": self.collate_edge_attr(edge_attr_dict),
            "edge_index": self.collate_idx_list(indexer),
        }

    def collate_idx_list(
        self,
        indexer: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Collate edge indices for all relations.

        :param indexer: Mapping from edge index names to 1D node-index tensors.
        :return: LongTensor of shape [2, N_edges] with offset global node indices.
        """
        ret_eidxs_rev, ret_eidxs_frd = [], []
        for relation in self._relation_forms:
            assert relation.edge_rev_name in indexer.keys()
            assert relation.edge_frd_name in indexer.keys()
            assert indexer[relation.edge_rev_name].dim() == 1
            assert indexer[relation.edge_frd_name].dim() == 1
            assert torch.all(
                indexer[relation.edge_rev_name] < self._node_chunk_sizes[relation.src_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            assert torch.all(
                indexer[relation.edge_frd_name] < self._node_chunk_sizes[relation.dst_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            # Shift each relation's local node indices into the global index space
            ret_eidxs_rev.append(
                indexer[relation.edge_rev_name] + self._offsets_lower[relation.src_node_type]
            )
            ret_eidxs_frd.append(
                indexer[relation.edge_frd_name] + self._offsets_lower[relation.dst_node_type]
            )
        ret_eidxs_rev = torch.cat(ret_eidxs_rev, dim=0)
        ret_eidxs_frd = torch.cat(ret_eidxs_frd, dim=0)
        return torch.stack([ret_eidxs_rev, ret_eidxs_frd], dim=0)

    def collate_node_attr(self, node_attr_dict: Dict[str, torch.Tensor]):
        """Collate node attributes for all node types."""
        for node_type in self.all_node_types:
            assert (
                node_attr_dict[node_type].shape[0] == self._node_chunk_sizes[node_type]
            ), f"Node count mismatch: {node_type}, {node_attr_dict[node_type].shape[0]}, {self._node_chunk_sizes[node_type]}"
        return torch.cat([node_attr_dict[node_type] for node_type in self.all_node_types], dim=0)

    def collate_edge_attr(self, edge_attr_dict: Dict[str, torch.Tensor]):
        """Collate edge attributes for all relations."""
        return torch.cat(
            [edge_attr_dict[relation.edge_type] for relation in self._relation_forms],
            dim=0,
        )

    def zero_pad_edge_attr(
        self,
        edge_attr_dict: Dict[str, torch.Tensor],
        embedding_dim: int,
        device: torch.device,
    ):
        """Zero pad edge attributes for all relations.

        Relations with no attributes (`None`) are filled with zero embeddings so the
        subsequent concatenation in `collate_edge_attr` is well-defined.
        """
        for relation in self._relation_forms:
            if edge_attr_dict[relation.edge_type] is None:
                edge_attr_dict[relation.edge_type] = torch.zeros(
                    (relation.num_edges, embedding_dim),
                    device=device,
                )
        return edge_attr_dict

    def offload_node_attr(self, cat_node_attr: torch.Tensor):
        """Offload (split) concatenated node attributes back per node type."""
        node_chunk_sizes = [self._node_chunk_sizes[node_type] for node_type in self.all_node_types]
        node_attr_split = torch.split(cat_node_attr, node_chunk_sizes)
        return {
            self.all_node_types[i]: node_attr_split[i] for i in range(len(self.all_node_types))
        }

    def offload_edge_attr(self, cat_edge_attr: torch.Tensor):
        """Offload (split) concatenated edge attributes back per relation."""
        edge_chunk_sizes = [relation.num_edges for relation in self._relation_forms]
        edge_attr_split = torch.split(cat_edge_attr, edge_chunk_sizes)
        return {
            self._relation_forms[i].edge_type: edge_attr_split[i]
            for i in range(len(self._relation_forms))
        }
def make_multi_relation_graph_batcher(
    list_of_relations: List[Tuple[str, str, str, str, str]],
    indexer,
    metadata,
):
    """Make a multi-relation graph batcher.

    Edge counts are read off the supplied indexer, so the returned batcher is
    tailored to the current sample's graph sizes.
    """
    relation_forms = []
    for edge_type, rev_name, frd_name, src_type, dst_type in list_of_relations:
        relation_forms.append(
            Relation(
                edge_type=edge_type,
                edge_rev_name=rev_name,
                edge_frd_name=frd_name,
                src_node_type=src_type,
                dst_node_type=dst_type,
                num_edges=indexer[rev_name].shape[0],
            )
        )
    return MultiRelationGraphBatcher(relation_forms, metadata)
def compute_contact_prediction_losses(
    pred_distograms: torch.Tensor,
    ref_dist_mat: torch.Tensor,
    dist_bins: torch.Tensor,
    contact_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the contact prediction losses for a given batch.

    :param pred_distograms: The predicted distograms.
    :param ref_dist_mat: The reference distance matrix.
    :param dist_bins: The distance bins.
    :param contact_scale: The contact scale.
    :return: The distogram and forward KL losses.
    """
    # Cross-entropy against the binned (one-hot) reference distances
    target_bin_indices = torch.bucketize(ref_dist_mat, dist_bins[:-1], right=True)
    distogram_loss = F.cross_entropy(
        pred_distograms.flatten(0, -2), target_bin_indices.flatten()
    )
    # Evaluate contact logits via log(\sum_j p_j \exp(-\alpha*r_j^2)) on both sides
    ref_contact_logits = distance_to_gaussian_contact_logits(ref_dist_mat, contact_scale)
    pred_contact_logits = distogram_to_gaussian_contact_logits(
        pred_distograms,
        dist_bins,
        contact_scale,
    )
    pred_log_probs = F.log_softmax(pred_contact_logits.flatten(-2, -1), dim=-1)
    ref_log_probs = F.log_softmax(ref_contact_logits.flatten(-2, -1), dim=-1)
    forward_kl_loss = F.kl_div(
        pred_log_probs,
        ref_log_probs,
        log_target=True,
        reduction="batchmean",
    )
    return distogram_loss, forward_kl_loss
def compute_protein_distogram_loss(
    batch: MODEL_BATCH,
    target_coords: torch.Tensor,
    dist_bins: torch.Tensor,
    dgram_head: torch.nn.Module,
    entry: str = "res_res_grid_attr_flat",
) -> torch.Tensor:
    """Compute the protein distogram loss for a given batch.

    :param batch: A batch dictionary.
    :param target_coords: The target coordinates.
    :param dist_bins: The distance bins.
    :param dgram_head: The distogram head to use for loss calculation.
    :param entry: The feature entry holding the sampled pairwise grid features.
    :return: The distogram loss.
    """
    num_structid = batch["metadata"]["num_structid"]
    num_patches = batch["metadata"]["n_prot_patches_per_sample"]
    grid_features = batch["features"][entry]
    # Gather the Ca coordinate of each sampled protein patch
    patch_ca_coords = target_coords[batch["indexer"]["gather_idx_pid_a"]].view(
        num_structid, num_patches, 3
    )
    patch_ca_dist = torch.norm(
        patch_ca_coords[:, :, None] - patch_ca_coords[:, None, :], dim=-1
    )
    # Using AF2 parameters: discretize distances and score the head's logits
    bin_indices = torch.bucketize(patch_ca_dist, dist_bins[:-1], right=True)
    return F.cross_entropy(dgram_head(grid_features), bin_indices.flatten())
def compute_fape_from_atom37(
    batch: MODEL_BATCH,
    device: Union[str, torch.device],
    pred_prot_coords: torch.Tensor,  # [N_res, 37, 3]
    target_prot_coords: torch.Tensor,  # [N_res, 37, 3]
    pred_lig_coords: Optional[torch.Tensor] = None,  # [N_atom, 3]
    target_lig_coords: Optional[torch.Tensor] = None,  # [N_atom, 3]
    lig_frame_atm_idx: Optional[torch.Tensor] = None,  # [3, N_atom]
    split_pl_views: bool = False,
    cap_size: int = 8000,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Compute the Frame Aligned Point Error (FAPE) loss from `atom37` coordinates.

    NOTE: points are randomly subsampled to at most ~`cap_size` to bound the
    frames-by-points pairwise cost; the same subsampling mask is used for the
    predicted and target points.

    :param batch: A batch dictionary.
    :param device: The device to use.
    :param pred_prot_coords: The predicted protein coordinates.
    :param target_prot_coords: The target protein coordinates.
    :param pred_lig_coords: The predicted ligand coordinates.
    :param target_lig_coords: The target ligand coordinates (required when
        `pred_lig_coords` is given).
    :param lig_frame_atm_idx: The ligand frame atom indices (required when
        `pred_lig_coords` is given).
    :param split_pl_views: Whether to split the protein-ligand views.
    :param cap_size: The capped size.
    :return: The FAPE loss (protein-frame and ligand-frame views plus the
        normalized error when `split_pl_views`, else the cropped and normalized
        errors).
    """
    features = batch["features"]
    batch_size = batch["metadata"]["num_structid"]
    with torch.no_grad():
        atom_mask = (
            features["res_atom_mask"].bool().view(batch["metadata"]["num_structid"], -1, 37)
        ).clone()
        # Drop the symmetry-ambiguous side-chain atom slots from the comparison
        atom_mask[:, :, [6, 7, 12, 13, 16, 17, 20, 21, 26, 27, 29, 30]] = False
    pred_prot_coords = pred_prot_coords.view(batch_size, -1, 37, 3)
    target_prot_coords = target_prot_coords.view(batch_size, -1, 37, 3)
    # Backbone frames from atom37 slots 0/1/2 (presumably N, CA, C — atom37 convention)
    pred_bb_frames = get_frame_matrix(
        pred_prot_coords[:, :, 0, :],
        pred_prot_coords[:, :, 1, :],
        pred_prot_coords[:, :, 2, :],
    )
    target_bb_frames = get_frame_matrix(
        target_prot_coords[:, :, 0, :],
        target_prot_coords[:, :, 1, :],
        target_prot_coords[:, :, 2, :],
    )
    pred_prot_coords_flat = pred_prot_coords[atom_mask].view(batch_size, -1, 3)
    target_prot_coords_flat = target_prot_coords[atom_mask].view(batch_size, -1, 3)
    if pred_lig_coords is not None:
        # BUGFIX: previously asserted `target_prot_coords is not None`, which is a
        # required positional argument and therefore always true; the optional
        # ligand target is the value that actually needs checking here.
        assert target_lig_coords is not None, "Target ligand coordinates must be provided."
        assert lig_frame_atm_idx is not None, "Ligand frame atom indices must be provided."
        pred_lig_coords = pred_lig_coords.view(batch_size, -1, 3)
        target_lig_coords = target_lig_coords.view(batch_size, -1, 3)
        pred_coords = torch.cat([pred_prot_coords_flat, pred_lig_coords], dim=1)
        target_coords = torch.cat([target_prot_coords_flat, target_lig_coords], dim=1)
        pred_lig_frames = get_frame_matrix(
            pred_lig_coords[:, lig_frame_atm_idx[0]],
            pred_lig_coords[:, lig_frame_atm_idx[1]],
            pred_lig_coords[:, lig_frame_atm_idx[2]],
        )
        pred_frames = pred_bb_frames.concatenate(pred_lig_frames, dim=1)
        target_lig_frames = get_frame_matrix(
            target_lig_coords[:, lig_frame_atm_idx[0]],
            target_lig_coords[:, lig_frame_atm_idx[1]],
            target_lig_coords[:, lig_frame_atm_idx[2]],
        )
        target_frames = target_bb_frames.concatenate(target_lig_frames, dim=1)
    else:
        pred_coords = pred_prot_coords_flat
        target_coords = target_prot_coords_flat
        pred_frames = pred_bb_frames
        target_frames = target_bb_frames
    # Columns-frames, rows-points
    # [B, 1, N, 3] - [B, F, 1, 3]
    sampling_rate = cap_size / (batch_size * target_coords.shape[1])
    sampling_mask = torch.rand(target_coords.shape[1], device=device) < sampling_rate
    aligned_pred_points = cartesian_to_internal(
        pred_coords[:, sampling_mask].unsqueeze(1), pred_frames.unsqueeze(2)
    )
    with torch.no_grad():
        aligned_target_points = cartesian_to_internal(
            target_coords[:, sampling_mask].unsqueeze(1), target_frames.unsqueeze(2)
        )
    # Smoothed distances: the 1e-4 floor keeps the sqrt differentiable at zero
    pair_dist_aligned = (
        torch.square(aligned_pred_points - aligned_target_points)
        .sum(-1)
        .add(1e-4)
        .sqrt()
        .sub(1e-2)
    )
    cropped_pair_dists = torch.clamp(pair_dist_aligned, max=10)
    normalized_pair_dists = (
        pair_dist_aligned / aligned_target_points.square().sum(-1).add(1e-4).sqrt()
    )
    if split_pl_views:
        # Split frame axis into protein backbone frames vs ligand frames
        fape_protframe = cropped_pair_dists[:, : target_bb_frames.t.shape[1]].mean((1, 2)) / 10
        fape_ligframe = cropped_pair_dists[:, target_bb_frames.t.shape[1] :].mean((1, 2)) / 10
        return fape_protframe, fape_ligframe, normalized_pair_dists.mean((1, 2))
    return cropped_pair_dists.mean((1, 2)) / 10, normalized_pair_dists.mean((1, 2))
def compute_aa_distance_geometry_loss(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> torch.Tensor:
    """Compute the amino acid distance geometry loss for a given batch.

    :param batch: A batch dictionary.
    :param pred_coords: The predicted coordinates.
    :param target_coords: The target coordinates.
    :return: The distance geometry loss.
    """

    def _with_prev_backbone(per_res: torch.Tensor) -> torch.Tensor:
        # Append the previous residue's first three (backbone) atom slots to each
        # residue's atom axis, then merge batch and residue dimensions
        return torch.cat([per_res[:, 1:], per_res[:, :-1, 0:3]], dim=2).flatten(0, 1)

    def _pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
        # Smoothed pairwise distances (1e-4 keeps the sqrt differentiable at 0)
        return (coords[:, None, :] - coords[:, :, None]).square().sum(-1).add(1e-4).sqrt()

    num_structid = batch["metadata"]["num_structid"]
    atom_mask = _with_prev_backbone(
        batch["features"]["res_atom_mask"].bool().view(num_structid, -1, 37)
    )
    pred = _with_prev_backbone(pred_coords.view(num_structid, -1, 37, 3))
    target = _with_prev_backbone(target_coords.view(num_structid, -1, 37, 3))
    target_dist = _pairwise_dist(target)
    pred_dist = _pairwise_dist(pred)
    # Only valid atom pairs that are locally close (< 3 Å) in the reference count
    pair_mask = atom_mask[:, None, :] & atom_mask[:, :, None] & (target_dist < 3.0)
    abs_errors = (target_dist - pred_dist).abs()[pair_mask]
    return abs_errors.view(num_structid, -1).mean(dim=1)
def compute_sm_distance_geometry_loss(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> torch.Tensor:
    """Compute the small molecule distance geometry loss for a given batch.

    :param batch: A batch dictionary.
    :param pred_coords: The predicted coordinates.
    :param target_coords: The target coordinates.
    :return: The distance geometry loss.
    """
    num_structid = batch["metadata"]["num_structid"]
    pred = pred_coords.view(num_structid, -1, 3)
    target = target_coords.view(num_structid, -1, 3)
    # Smoothed pairwise distance matrices (1e-4 keeps the sqrt differentiable)
    target_dist = (target[:, None, :] - target[:, :, None]).square().sum(-1).add(1e-4).sqrt()
    pred_dist = (pred[:, None, :] - pred[:, :, None]).square().sum(-1).add(1e-4).sqrt()
    # Only penalize deviations for reference pairs closer than 3 Å (local geometry)
    close_pairs = target_dist < 3.0
    abs_errors = (target_dist - pred_dist).abs()[close_pairs]
    return abs_errors.view(num_structid, -1).mean(dim=1)
def compute_drmsd_and_clashloss(
    batch: MODEL_BATCH,
    device: Union[str, torch.device],
    pred_prot_coords: torch.Tensor,
    target_prot_coords: torch.Tensor,
    atnum2vdw_uff: torch.nn.Parameter,
    cap_size: int = 4000,
    pred_lig_coords: Optional[torch.Tensor] = None,
    target_lig_coords: Optional[torch.Tensor] = None,
    ligatm_types: Optional[torch.Tensor] = None,
    binding_site: bool = False,
    pl_interface: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Compute the differentiable root-mean-square deviation (dRMSD) and optional clash loss for a
    given batch.

    NOTE: atoms are randomly subsampled to at most ~`cap_size` to bound the
    pairwise cost; the same subsampling mask is applied to predicted and target
    coordinates.

    :param batch: A batch dictionary.
    :param device: The device to use.
    :param pred_prot_coords: The predicted protein coordinates.
    :param target_prot_coords: The target protein coordinates.
    :param atnum2vdw_uff: The atomic number to UFF VDW parameters mapping `Parameter`.
    :param cap_size: The capped size.
    :param pred_lig_coords: The predicted ligand coordinates.
    :param target_lig_coords: The target ligand coordinates (required when
        `pred_lig_coords` is given).
    :param ligatm_types: The ligand atom types.
    :param binding_site: Whether to compute the binding site.
    :param pl_interface: Whether to compute the protein-ligand interface.
    :return: The dRMSD and optional clash loss (None in the interface view).
    """
    features = batch["features"]
    with torch.no_grad():
        if not binding_site:
            atom_mask = features["res_atom_mask"].bool().clone()
        else:
            atom_mask = (
                features["res_atom_mask"].bool() & features["binding_site_mask_clean"][:, None]
            )
        if pl_interface:
            # Removing ambiguous atoms
            atom_mask[:, [6, 7, 12, 13, 16, 17, 20, 21, 26, 27, 29, 30]] = False

    batch_size = batch["metadata"]["num_structid"]
    pred_prot_coords = pred_prot_coords[atom_mask].view(batch_size, -1, 3)
    if pred_lig_coords is not None:
        # BUGFIX: previously asserted `target_prot_coords is not None`, which is a
        # required positional argument and therefore always true; the optional
        # ligand target is the value that actually needs checking here.
        assert target_lig_coords is not None, "Target ligand coordinates must be provided."
        assert ligatm_types is not None, "Ligand atom types must be provided."
        pred_lig_coords = pred_lig_coords.view(batch_size, -1, 3)
        pred_coords = torch.cat([pred_prot_coords, pred_lig_coords], dim=1)
    else:
        pred_coords = pred_prot_coords
    sampling_rate = cap_size / pred_coords.shape[1]
    sampling_mask = torch.rand(pred_coords.shape[1], device=device) < sampling_rate
    pred_coords = pred_coords[:, sampling_mask]
    if pl_interface:
        pred_dist = (
            torch.square(pred_coords[:, :, None] - pred_lig_coords[:, None, :])
            .sum(-1)
            .add(1e-4)
            .sqrt()
            .sub(1e-2)
        )
    else:
        pred_dist = (
            torch.square(pred_coords[:, :, None] - pred_coords[:, None, :])
            .sum(-1)
            .add(1e-4)
            .sqrt()
            .sub(1e-2)
        )
    with torch.no_grad():
        target_prot_coords = target_prot_coords[atom_mask].view(batch_size, -1, 3)
        if pred_lig_coords is not None:
            target_lig_coords = target_lig_coords.view(batch_size, -1, 3)
            target_coords = torch.cat([target_prot_coords, target_lig_coords], dim=1)
        else:
            target_coords = target_prot_coords
        target_coords = target_coords[:, sampling_mask]
        if pl_interface:
            target_dist = (
                torch.square(target_coords[:, :, None] - target_lig_coords[:, None, :])
                .sum(-1)
                .add(1e-4)
                .sqrt()
                .sub(1e-2)
            )
        else:
            target_dist = (
                torch.square(target_coords[:, :, None] - target_coords[:, None, :])
                .sum(-1)
                .add(1e-4)
                .sqrt()
                .sub(1e-2)
            )
        # In Angstrom, using UFF params to compute clash loss
        protatm_types = features["res_atom_types"].long()[atom_mask]
        protatm_vdw = atnum2vdw_uff[protatm_types].view(batch_size, -1)
        if pred_lig_coords is not None:
            ligatm_vdw = atnum2vdw_uff[ligatm_types].view(batch_size, -1)
            atm_vdw = torch.cat([protatm_vdw, ligatm_vdw], dim=1)
        else:
            atm_vdw = protatm_vdw
        atm_vdw = atm_vdw[:, sampling_mask]
        average_vdw = (atm_vdw[:, :, None] + atm_vdw[:, None, :]) / 2
        # Use conservative cutoffs to avoid mis-penalization

    dist_errors = (pred_dist - target_dist).square()
    drmsd = dist_errors.add(1e-2).sqrt().sub(1e-1).mean(dim=(1, 2))
    if pl_interface:
        return drmsd, None

    covalent_like = target_dist < (average_vdw * 1.2)
    # Alphafold supplementary Eq. 46, modified
    clash_pairwise = torch.clamp(average_vdw * 1.1 - pred_dist.add(1e-6), min=0.0)
    clash_loss = clash_pairwise.mul(~covalent_like).sum(dim=2).mean(dim=1)
    return drmsd, clash_loss
def compute_template_weighted_centroid_drmsd(
    batch: MODEL_BATCH,
    pred_prot_coords: torch.Tensor,
) -> torch.Tensor:
    """Compute the template-weighted centroid dRMSD for a given batch.

    :param batch: A batch dictionary.
    :param pred_prot_coords: The predicted protein coordinates.
    :return: The dRMSD.
    """
    batch_size = batch["metadata"]["num_structid"]
    features = batch["features"]

    def _residue_centroids(coords: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Masked mean over each residue's atom slots, reshaped per structure
        mask = mask.bool()
        centroids = (
            coords.mul(mask[:, :, None]).sum(dim=1).div(mask.sum(dim=1)[:, None] + 1e-9)
        )
        return centroids.view(batch_size, -1, 3)

    def _centroid_dist(centroids: torch.Tensor) -> torch.Tensor:
        # Smoothed pairwise centroid distances
        return (
            torch.square(centroids[:, :, None] - centroids[:, None, :])
            .sum(-1)
            .add(1e-4)
            .sqrt()
            .sub(1e-2)
        )

    pred_dist = _centroid_dist(
        _residue_centroids(pred_prot_coords, features["res_atom_mask"])
    )
    with torch.no_grad():
        target_dist = _centroid_dist(
            _residue_centroids(features["res_atom_positions"], features["res_atom_mask"])
        )
        template_dist = _centroid_dist(
            _residue_centroids(
                features["apo_res_atom_positions"], features["apo_res_atom_mask"]
            )
        )
        aligned = features["apo_res_alignment_mask"].bool().view(batch_size, -1)
        # Weight only aligned residue pairs that moved > 2 Å between template and target
        motion_mask = (
            ((target_dist - template_dist).abs() > 2.0)
            * aligned[:, None, :]
            * aligned[:, :, None]
        )

    dist_errors = (pred_dist - target_dist).square()
    return (dist_errors.add(1e-4).sqrt().sub(1e-2).mul(motion_mask).sum(dim=(1, 2))) / (
        motion_mask.long().sum(dim=(1, 2)) + 1
    )
+ """ + batch_size = batch["metadata"]["num_structid"] + + pred_cent_coords = ( + pred_prot_coords.mul(batch["features"]["res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ).view(batch_size, -1, 3) + pred_dist = ( + torch.square(pred_cent_coords[:, :, None] - pred_cent_coords[:, None, :]) + .sum(-1) + .add(1e-4) + .sqrt() + .sub(1e-2) + ) + with torch.no_grad(): + target_cent_coords = ( + batch["features"]["res_atom_positions"] + .mul(batch["features"]["res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ).view(batch_size, -1, 3) + template_cent_coords = ( + batch["features"]["apo_res_atom_positions"] + .mul(batch["features"]["apo_res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["apo_res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ).view(batch_size, -1, 3) + target_dist = ( + torch.square(target_cent_coords[:, :, None] - target_cent_coords[:, None, :]) + .sum(-1) + .add(1e-4) + .sqrt() + .sub(1e-2) + ) + template_dist = ( + torch.square(template_cent_coords[:, :, None] - template_cent_coords[:, None, :]) + .sum(-1) + .add(1e-4) + .sqrt() + .sub(1e-2) + ) + template_alignment_mask = ( + batch["features"]["apo_res_alignment_mask"].bool().view(batch_size, -1) + ) + motion_mask = ( + ((target_dist - template_dist).abs() > 2.0) + * template_alignment_mask[:, None, :] + * template_alignment_mask[:, :, None] + ) + + dist_errors = (pred_dist - target_dist).square() + drmsd = (dist_errors.add(1e-4).sqrt().sub(1e-2).mul(motion_mask).sum(dim=(1, 2))) / ( + motion_mask.long().sum(dim=(1, 2)) + 1 + ) + return drmsd + + +def compute_TMscore_lbound( + batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor +) -> torch.Tensor: + """Compute the TM-score lower bound for a given batch. + + :param batch: A batch dictionary. + :param pred_coords: The predicted coordinates. 
def compute_TMscore_raw(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> torch.Tensor:
    """Compute the raw TM-score for a given batch.

    Assumes both coordinate sets are already in the same reference frame; no
    alignment search is performed.

    :param batch: A batch dictionary.
    :param pred_coords: The predicted coordinates.
    :param target_coords: The target coordinates.
    :return: The raw TM-score.
    """
    n_struct = batch["metadata"]["num_structid"]
    pred = pred_coords.view(n_struct, -1, 3)
    target = target_coords.view(n_struct, -1, 3)
    deviations = (pred - target).norm(dim=-1)
    # Standard TM-score length normalization, clamped for very short chains.
    d0 = 1.24 * (max(target.shape[1], 19) - 15) ** (1 / 3) - 1.8
    return (1.0 / (1.0 + (deviations / d0) ** 2)).mean(dim=1)
+ """ + pred_coords = pred_coords.view(batch["metadata"]["num_structid"], -1, 3) + target_coords = target_coords.view(batch["metadata"]["num_structid"], -1, 3) + pair_dist_aligned = (pred_coords - target_coords).norm(dim=-1) + tm_normalizer = 1.24 * (max(target_coords.shape[1], 19) - 15) ** (1 / 3) - 1.8 + per_struct_tm = torch.mean(1 / (1 + (pair_dist_aligned / tm_normalizer) ** 2), dim=1) + return per_struct_tm + + +def compute_lddt_ca( + batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor +) -> torch.Tensor: + """Compute the local distance difference test (lDDT) for C-alpha atoms for a given batch. + + :param batch: A batch dictionary. + :param pred_coords: The predicted coordinates. + :param target_coords: The target coordinates. + :return: The lDDT for C-alpha atoms. + """ + pred_coords = pred_coords.view(batch["metadata"]["num_structid"], -1, 37, 3) + target_coords = target_coords.view(batch["metadata"]["num_structid"], -1, 37, 3) + pred_ca_flat = pred_coords[:, :, 1] + target_ca_flat = target_coords[:, :, 1] + target_dist = (target_ca_flat[:, :, None] - target_ca_flat[:, None, :]).norm(dim=-1) + pred_dist = (pred_ca_flat[:, :, None] - pred_ca_flat[:, None, :]).norm(dim=-1) + conserved_mask = target_dist < 15.0 + lddt = 0 + for threshold in [0.5, 1, 2, 4]: + below_threshold = (pred_dist - target_dist).abs() < threshold + lddt = lddt + below_threshold.mul(conserved_mask).sum((1, 2)) / conserved_mask.sum((1, 2)) + return lddt / 4 + + +def compute_lddt_pli( + batch: MODEL_BATCH, + pred_prot_coords: torch.Tensor, + target_prot_coords: torch.Tensor, + pred_lig_coords: torch.Tensor, + target_lig_coords: torch.Tensor, +) -> torch.Tensor: + """Compute the local distance difference test (lDDT) for protein-ligand interface atoms for a + given batch. + + :param batch: A batch dictionary. + :param pred_prot_coords: The predicted protein coordinates. + :param target_prot_coords: The target protein coordinates. 
def compute_lddt_pli(
    batch: MODEL_BATCH,
    pred_prot_coords: torch.Tensor,
    target_prot_coords: torch.Tensor,
    pred_lig_coords: torch.Tensor,
    target_lig_coords: torch.Tensor,
) -> torch.Tensor:
    """Compute the local distance difference test (lDDT) for protein-ligand interface atoms for a
    given batch.

    :param batch: A batch dictionary.
    :param pred_prot_coords: The predicted protein coordinates.
    :param target_prot_coords: The target protein coordinates.
    :param pred_lig_coords: The predicted ligand coordinates.
    :param target_lig_coords: The target ligand coordinates.
    :return: The lDDT for protein-ligand interface atoms.
    """
    n_struct = batch["metadata"]["num_structid"]
    present = batch["features"]["res_atom_mask"].bool()
    prot_pred = pred_prot_coords[present].view(n_struct, -1, 3)
    prot_target = target_prot_coords[present].view(n_struct, -1, 3)
    lig_pred = pred_lig_coords.view(n_struct, -1, 3)
    lig_target = target_lig_coords.view(n_struct, -1, 3)
    # Cross protein-ligand distance matrices.
    dist_target = (prot_target[:, :, None] - lig_target[:, None, :]).norm(dim=-1)
    dist_pred = (prot_pred[:, :, None] - lig_pred[:, None, :]).norm(dim=-1)
    # Interface pairs: protein-ligand atom pairs within 6 Angstrom in the target.
    interface = dist_target < 6.0
    deviation = (dist_pred - dist_target).abs()
    per_cutoff = [
        (deviation < cutoff).mul(interface).sum((1, 2)) / interface.sum((1, 2))
        for cutoff in (0.5, 1, 2, 4)
    ]
    return sum(per_cutoff) / 4
+ batch_size = batch["metadata"]["num_structid"] + max(batch["metadata"]["num_a_per_sample"]) + + if "num_molid" in batch["metadata"].keys() and batch["metadata"]["num_molid"] > 0: + batch["misc"]["protein_only"] = False + else: + batch["misc"]["protein_only"] = True + + if "augmented_coordinates" in batch["features"].keys(): + batch["features"]["sdf_coordinates"] = batch["features"]["augmented_coordinates"] + is_native_sample = 0 + else: + is_native_sample = 1 + + # Sample the timestep for each structure + t = torch.rand((batch_size, 1), device=device) + + prior_training = int(random.randint(0, 10) == 1) # nosec + if prior_training == 1: + t = torch.full_like(t, t_1) + + if lit_module.training and lit_module.hparams.cfg.task.use_template: + use_template = bool(random.randint(0, 1)) # nosec + else: + use_template = lit_module.hparams.cfg.task.use_template + + lit_module.net.assign_timestep_encodings(batch, t) + features = batch["features"] + indexer = batch["indexer"] + metadata = batch["metadata"] + + loss = 0 + forward_lat_converter = lit_module.net.resolve_latent_converter( + [ + ("features", "res_atom_positions"), + ("features", "input_protein_coords"), + ], + [("features", "sdf_coordinates"), ("features", "input_ligand_coords")], + ) + batch = lit_module.net.prepare_protein_patch_indexers(batch) + if not batch["misc"]["protein_only"]: + max(metadata["num_i_per_sample"]) + + # Evaluate the contact map + ref_dist_mat, contact_logit_matrix = eval_true_contact_maps( + batch, lit_module.net.CONTACT_SCALE + ) + num_cont_to_sample = max(metadata["num_I_per_sample"]) + sampled_block_contacts = [ + None, + ] + # Onehot contact code sampling + with torch.no_grad(): + for _ in range(num_cont_to_sample): + sampled_block_contacts.append( + sample_reslig_contact_matrix( + batch, contact_logit_matrix, last=sampled_block_contacts[-1] + ) + ) + forward_lat_converter.lig_res_anchor_mask = sample_res_rowmask_from_contacts( + batch, + contact_logit_matrix, + 
lit_module.hparams.cfg.task.single_protein_batch, + ) + with torch.no_grad(): + batch = lit_module.net.forward_interp_plcomplex_latinp( + batch, t[:, :, None], forward_lat_converter + ) + if prior_training == 1: + iter_id = random.randint(0, num_cont_to_sample) # nosec + else: + iter_id = num_cont_to_sample + batch = lit_module.forward( + batch, contact_prediction=False, score=False, use_template=use_template + ) + batch = lit_module.net.run_contact_map_stack( + batch, + iter_id=iter_id, + observed_block_contacts=sampled_block_contacts[iter_id], + ) + pred_distogram = batch["outputs"][f"res_lig_distogram_out_{iter_id}"] + ( + pl_distogram_loss, + pl_contact_loss_forward, + ) = compute_contact_prediction_losses( + pred_distogram, ref_dist_mat, lit_module.net.dist_bins, lit_module.net.CONTACT_SCALE + ) + cont_loss = 0 + cont_loss = ( + cont_loss + + pl_distogram_loss + * lit_module.hparams.cfg.task.contact_loss_weight + * is_native_sample + ) + cont_loss = ( + cont_loss + + pl_contact_loss_forward + * lit_module.hparams.cfg.task.contact_loss_weight + * is_native_sample + ) + lit_module.log( + f"{stage}_contact/contact_loss_distogram", + pl_distogram_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_contact/contact_loss_forwardKL", + pl_contact_loss_forward.detach(), + on_epoch=True, + batch_size=batch_size, + ) + if lit_module.hparams.cfg.task.freeze_contact_predictor: + # Keep the contact prediction parameters in the computational graph but with zero gradients + cont_loss *= 0.0 + else: + with torch.no_grad(): + batch = lit_module.net.forward_interp_plcomplex_latinp( + batch, t[:, :, None], forward_lat_converter + ) + iter_id = 0 + batch = lit_module.forward( + batch, + iter_id=0, + contact_prediction=True, + score=False, + use_template=use_template, + ) + protein_distogram_loss = compute_protein_distogram_loss( + batch, + batch["features"]["res_atom_positions"][:, 1], + lit_module.net.dist_bins, + lit_module.net.dgram_head, 
+ entry=f"res_res_grid_attr_flat_out_{iter_id}", + ) + lit_module.log( + f"{stage}_contact/prot_distogram_loss", + protein_distogram_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + if lit_module.hparams.cfg.task.freeze_contact_predictor: + # Keep the distogram prediction parameters in the computational graph but with zero gradients + protein_distogram_loss *= 0.0 + + # NOTE: we keep the loss weighting time-independent since `sigma=1` for all prior distributions (where relevant) + lambda_weighting = t.new_ones(batch_size) + + # Run score head and evaluate structure prediction losses + res_atom_mask = features["res_atom_mask"].bool() + + scores = lit_module.net.run_score_head(batch, embedding_iter_id=iter_id) + + if lit_module.training: + # # Sigmoid scaling + # violation_loss_ratio = 1 / ( + # 1 + # + math.exp(10 - 12 * lit_module.current_epoch / lit_module.trainer.max_epochs) + # ) + # violation_loss_ratio = (lit_module.current_epoch / lit_module.trainer.max_epochs) + violation_loss_ratio = 1.0 + else: + violation_loss_ratio = 1.0 + + if not batch["misc"]["protein_only"]: + if "binding_site_mask_clean" not in batch["features"]: + with torch.no_grad(): + min_lig_res_dist_clean = ( + ( + batch["features"]["res_atom_positions"][:, 1].view(batch_size, -1, 3)[ + :, :, None + ] + - batch["features"]["sdf_coordinates"].view(batch_size, -1, 3)[:, None, :] + ) + .norm(dim=-1) + .amin(dim=2) + ).flatten(0, 1) + binding_site_mask_clean = ( + min_lig_res_dist_clean < lit_module.net.BINDING_SITE_CUTOFF + ) + batch["features"]["binding_site_mask_clean"] = binding_site_mask_clean + coords_pred_prot = scores["final_coords_prot_atom_padded"][res_atom_mask].view( + batch_size, -1, 3 + ) + coords_ref_prot = batch["features"]["res_atom_positions"][res_atom_mask].view( + batch_size, -1, 3 + ) + coords_pred_bs_prot = scores["final_coords_prot_atom_padded"][ + res_atom_mask & batch["features"]["binding_site_mask_clean"][:, None] + ].view(batch_size, -1, 3) + 
coords_ref_bs_prot = batch["features"]["res_atom_positions"][ + res_atom_mask & batch["features"]["binding_site_mask_clean"][:, None] + ].view(batch_size, -1, 3) + coords_pred_lig = scores["final_coords_lig_atom"].view(batch_size, -1, 3) + coords_ref_lig = batch["features"]["sdf_coordinates"].view(batch_size, -1, 3) + coords_pred = torch.cat([coords_pred_prot, coords_pred_lig], dim=1) + coords_ref = torch.cat([coords_ref_prot, coords_ref_lig], dim=1) + coords_pred_bs = torch.cat([coords_pred_bs_prot, coords_pred_lig], dim=1) + coords_ref_bs = torch.cat([coords_ref_bs_prot, coords_ref_lig], dim=1) + n_I_per_sample = max(metadata["num_I_per_sample"]) + lig_frame_atm_idx = torch.stack( + [ + indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]][:n_I_per_sample], + indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]][:n_I_per_sample], + indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]][:n_I_per_sample], + ], + dim=0, + ) + ( + global_fape_pview, + global_fape_lview, + normalized_fape, + ) = compute_fape_from_atom37( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + pred_lig_coords=scores["final_coords_lig_atom"], + target_lig_coords=batch["features"]["sdf_coordinates"], + lig_frame_atm_idx=lig_frame_atm_idx, + split_pl_views=True, + ) + aa_distgeom_error = compute_aa_distance_geometry_loss( + batch, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + ) + lig_distgeom_error = compute_sm_distance_geometry_loss( + batch, + scores["final_coords_lig_atom"], + batch["features"]["sdf_coordinates"], + ) + glob_drmsd, _ = compute_drmsd_and_clashloss( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + lit_module.net.atnum2vdw_uff, + pred_lig_coords=scores["final_coords_lig_atom"], + target_lig_coords=batch["features"]["sdf_coordinates"], + ligatm_types=batch["features"]["atomic_numbers"].long(), + ) + bs_drmsd, 
clash_error = compute_drmsd_and_clashloss( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + lit_module.net.atnum2vdw_uff, + pred_lig_coords=scores["final_coords_lig_atom"], + target_lig_coords=batch["features"]["sdf_coordinates"], + ligatm_types=batch["features"]["atomic_numbers"].long(), + binding_site=True, + ) + pli_drmsd, _ = compute_drmsd_and_clashloss( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + lit_module.net.atnum2vdw_uff, + pred_lig_coords=scores["final_coords_lig_atom"], + target_lig_coords=batch["features"]["sdf_coordinates"], + ligatm_types=batch["features"]["atomic_numbers"].long(), + pl_interface=True, + ) + distgeom_loss = ( + aa_distgeom_error.mul(lambda_weighting) * max(metadata["num_a_per_sample"]) + + lig_distgeom_error.mul(lambda_weighting) * max(metadata["num_i_per_sample"]) + ).mean() / max(metadata["num_a_per_sample"]) + + fape_loss = ( + ( + global_fape_pview + + global_fape_lview + * ( + lit_module.hparams.cfg.task.ligand_score_loss_weight + / lit_module.hparams.cfg.task.global_score_loss_weight + ) + + normalized_fape + ) + .mul(lambda_weighting) + .mean() + ) + + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + fape_loss + * lit_module.hparams.cfg.task.global_score_loss_weight + * is_native_sample + ) + loss = ( + loss + + glob_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + if use_template: + twe_drmsd = compute_template_weighted_centroid_drmsd( + batch, scores["final_coords_prot_atom_padded"] + ) + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + twe_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + lit_module.log( + f"{stage}/drmsd_loss_weighted", + twe_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + 
f"{stage}/drmsd_weighted", + twe_drmsd.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + bs_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + loss = ( + loss + + pli_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + + loss = ( + loss + + distgeom_loss + * lit_module.hparams.cfg.task.local_distgeom_loss_weight + * violation_loss_ratio + ) + loss = ( + loss + + clash_error.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.clash_loss_weight + * violation_loss_ratio + ) + if not lit_module.hparams.cfg.task.freeze_contact_predictor: + loss = (0.1 + 0.9 * prior_training) * cont_loss + (1 - prior_training * 0.99) * loss + loss = ( + loss + protein_distogram_loss * lit_module.hparams.cfg.task.distogram_loss_weight + ) + with torch.no_grad(): + tm_lbound = compute_TMscore_lbound( + batch, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + ) + lig_rmsd = segment_mean( + ( + (coords_pred_lig - coords_pred_prot.mean(dim=1, keepdim=True)) + - (coords_ref_lig - coords_ref_prot.mean(dim=1, keepdim=True)) + ) + .square() + .sum(dim=-1) + .flatten(0, 1), + indexer["gather_idx_i_molid"], + metadata["num_molid"], + ).sqrt() + lit_module.log( + f"{stage}/tm_lbound", + tm_lbound.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/ligand_rmsd_ubound", + lig_rmsd.mean().detach(), + on_epoch=True, + batch_size=lig_rmsd.shape[0], + ) + # L1 score matching loss + dsm_loss_global = ( + ( + (coords_pred - coords_pred_prot.mean(dim=1, keepdim=True)) + - (coords_ref - coords_ref_prot.mean(dim=1, keepdim=True)) + ) + .square() + .sum(dim=-1) + .add(1e-2) + .sqrt() + .sub(1e-1) + .mean(dim=1) + .mul(lambda_weighting) + ) + dsm_loss_site = ( + ( + (coords_pred_bs - coords_pred_bs_prot.mean(dim=1, keepdim=True)) + - (coords_ref_bs - 
coords_ref_bs_prot.mean(dim=1, keepdim=True)) + ) + .square() + .sum(dim=-1) + .add(1e-2) + .sqrt() + .sub(1e-1) + .mean(dim=1) + .mul(lambda_weighting) + ) + dsm_loss_ligand = ( + ( + (coords_pred_lig - coords_pred.mean(dim=1, keepdim=True)) + - (coords_ref_lig - coords_ref.mean(dim=1, keepdim=True)) + ) + .square() + .sum(dim=-1) + .add(1e-2) + .sqrt() + .sub(1e-1) + .mean(dim=1) + .mul(lambda_weighting) + ) + lit_module.log( + f"{stage}/denoising_loss_global", + dsm_loss_global.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/denoising_loss_site", + dsm_loss_site.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/denoising_loss_ligand", + dsm_loss_ligand.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_loss_global", + glob_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_loss_site", + bs_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_loss_pli", + pli_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_global", + glob_drmsd.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_site", + bs_drmsd.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_pli", + pli_drmsd.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_global_proteinview", + global_fape_pview.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_global_ligandview", + global_fape_lview.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_normalized", + normalized_fape.mean().detach(), + on_epoch=True, + 
batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_loss", + fape_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/aa_distgeom_error", + aa_distgeom_error.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/lig_distgeom_error", + lig_distgeom_error.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/clash_error", + clash_error.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/clash_loss", + clash_error.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/distgeom_loss", + distgeom_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + else: + coords_pred = scores["final_coords_prot_atom_padded"][res_atom_mask].view( + batch_size, -1, 3 + ) + coords_ref = batch["features"]["res_atom_positions"][res_atom_mask].view(batch_size, -1, 3) + global_fape_pview, normalized_fape = compute_fape_from_atom37( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + ) + aa_distgeom_error = compute_aa_distance_geometry_loss( + batch, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + ) + glob_drmsd, clash_error = compute_drmsd_and_clashloss( + batch, + device, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + lit_module.net.atnum2vdw_uff, + ) + distgeom_loss = aa_distgeom_error.mul(lambda_weighting).mean() + fape_loss = (global_fape_pview + normalized_fape).mul(lambda_weighting).mean() + + global_fape_pview.detach() + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + distgeom_loss + * lit_module.hparams.cfg.task.local_distgeom_loss_weight + * violation_loss_ratio + ) + loss = loss + fape_loss * 
lit_module.hparams.cfg.task.global_score_loss_weight + loss = ( + loss + + glob_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + if use_template: + twe_drmsd = compute_template_weighted_centroid_drmsd( + batch, scores["final_coords_prot_atom_padded"] + ) + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + twe_drmsd.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.drmsd_loss_weight + ) + lit_module.log( + f"{stage}/drmsd_loss_weighted", + twe_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_weighted", + twe_drmsd.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + if not lit_module.hparams.cfg.task.freeze_score_head: + loss = ( + loss + + clash_error.mul(lambda_weighting).mean() + * lit_module.hparams.cfg.task.clash_loss_weight + * violation_loss_ratio + ) + if not lit_module.hparams.cfg.task.freeze_contact_predictor: + loss = ( + loss + protein_distogram_loss * lit_module.hparams.cfg.task.distogram_loss_weight + ) + + with torch.no_grad(): + dsm_loss_global = ( + ( + (coords_pred - coords_pred.mean(dim=1, keepdim=True)) + - (coords_ref - coords_ref.mean(dim=1, keepdim=True)) + ) + .square() + .sum(dim=-1) + .add(1e-2) + .sqrt() + .sub(1e-1) + .mean(dim=1) + .mul(lambda_weighting) + ) + lit_module.log( + f"{stage}/denoising_loss_global", + dsm_loss_global.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + tm_lbound = compute_TMscore_lbound( + batch, + scores["final_coords_prot_atom_padded"], + batch["features"]["res_atom_positions"], + ) + lit_module.log( + f"{stage}/tm_lbound", + tm_lbound.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_loss_global", + glob_drmsd.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/drmsd_global", + glob_drmsd.mean().detach(), + 
on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_global_proteinview", + global_fape_pview.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_normalized", + normalized_fape.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}/fape_loss", + fape_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/aa_distgeom_error", + aa_distgeom_error.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/clash_error", + clash_error.mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/clash_loss", + clash_error.mul(lambda_weighting).mean().detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_violation/distgeom_loss", + distgeom_loss.detach(), + on_epoch=True, + batch_size=batch_size, + ) + if torch.is_tensor(loss) and not torch.isnan(loss): + lit_module.log( + f"{stage}/loss", + loss.detach(), + on_epoch=True, + batch_size=batch_size, + sync_dist=(stage != "train"), + ) + batch["outputs"]["loss"] = loss + if not torch.is_tensor(batch["outputs"]["loss"]) and batch["outputs"]["loss"] == 0: + batch["outputs"]["loss"] = None + return batch + + +def eval_auxiliary_estimation_losses( + lit_module: LightningModule, + batch: MODEL_BATCH, + stage: MODEL_STAGE, + loss_mode: LOSS_MODES, + **kwargs: Dict[str, Any], +) -> MODEL_BATCH: + """Evaluate the auxiliary estimation losses for a given batch. + + :param lit_module: The LightningModule object to reference. + :param batch: A batch dictionary. + :param stage: The stage of the training. + :param loss_mode: The loss mode to use. + :param kwargs: Additional keyword arguments. + :return: Batch dictionary with losses. 
+ """ + use_template = bool(random.randint(0, 1)) # nosec + if use_template: + # Enable higher ligand diversity when using backbone template + start_time = 1.0 + else: + start_time = random.randint(1, 5) / 5 # nosec + with torch.no_grad(): + if loss_mode == "auxiliary_estimation_without_structure_prediction": + # Sample the structure without using the structure prediction head + # i.e., Provide the holo (ground-truth) protein and ligand structures for affinity estimation + output_struct = { + "receptor": batch["features"]["res_atom_positions"].flatten(0, 1), + "receptor_padded": batch["features"]["res_atom_positions"], + "ligands": batch["features"]["sdf_coordinates"], + } + else: + output_struct = lit_module.net.sample_pl_complex_structures( + batch, + sampler="VDODE", + sampler_eta=1.0, + num_steps=int(5 / start_time), + start_time=start_time, + exact_prior=True, + use_template=use_template, + cutoff=20.0, # Hot logits + ) + batch_size = batch["metadata"]["num_structid"] + batch = lit_module.net.run_auxiliary_estimation(batch, output_struct, **kwargs) + if lit_module.hparams.cfg.confidence.enabled: + with torch.no_grad(): + # Receptor centroids + ref_coords = ( + ( + batch["features"]["res_atom_positions"] + .mul(batch["features"]["res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ) + .contiguous() + .view(batch_size, -1, 3) + ) + pred_coords = ( + ( + output_struct["receptor_padded"] + .mul(batch["features"]["res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ) + .contiguous() + .view(batch_size, -1, 3) + ) + # The number of effective protein atoms used in plddt calculation + n_protatm_per_sample = pred_coords.shape[1] + if output_struct["ligands"] is not None: + ref_lig_coords = ( + batch["features"]["sdf_coordinates"].contiguous().view(batch_size, -1, 3) + ) + ref_coords = torch.cat([ref_coords, 
ref_lig_coords], dim=1) + pred_lig_coords = output_struct["ligands"].contiguous().view(batch_size, -1, 3) + pred_coords = torch.cat([pred_coords, pred_lig_coords], dim=1) + per_atom_lddt, per_atom_lddt_gram = compute_per_atom_lddt( + batch, pred_coords, ref_coords + ) + + plddt_dev = (per_atom_lddt - batch["outputs"]["plddt"]).abs().mean() + confidence_loss = ( + F.cross_entropy( + batch["outputs"]["plddt_logits"].flatten(0, 1), + per_atom_lddt_gram.flatten(0, 1), + reduction="none", + ) + .contiguous() + .view(batch_size, -1) + ) + conf_loss = confidence_loss.mean() + if output_struct["ligands"] is not None: + plddt_dev_lig = ( + ( + per_atom_lddt.view(batch_size, -1)[:, n_protatm_per_sample:] + - batch["outputs"]["plddt"].view(batch_size, -1)[:, n_protatm_per_sample:] + ) + .abs() + .mean() + ) + conf_loss_lig = confidence_loss[:, n_protatm_per_sample:].mean() + conf_loss = conf_loss + conf_loss_lig # + plddt_dev_lig * 0.1 + lit_module.log( + f"{stage}_confidence/plddt_dev_lig", + plddt_dev_lig.detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_confidence/plddt_dev", + plddt_dev.detach(), + on_epoch=True, + batch_size=batch_size, + ) + lit_module.log( + f"{stage}_confidence/loss", + conf_loss.detach(), + on_epoch=True, + batch_size=batch_size, + sync_dist=(stage != "train"), + ) + if lit_module.hparams.cfg.task.freeze_confidence: + # Keep the confidence prediction parameters in the computational graph but with zero gradients + conf_loss *= 0 + else: + conf_loss = 0 + if lit_module.hparams.cfg.affinity.enabled: + num_molid_per_sample = batch["metadata"]["num_molid"] // batch_size + gather_idx_molid_structid = torch.arange( + batch_size, device=batch["outputs"]["affinity_logits"].device + ).repeat_interleave(num_molid_per_sample) + + # Calculate affinity loss as the mean squared error between the predicted affinity logits and the ground-truth affinity values + affinity_logits = batch["outputs"]["affinity_logits"] + # Substitute 
missing ground-truth affinity values with the affinity head's (detached) predicted logits to indicate no learning signal for these examples + affinity = torch.where( + batch["features"]["affinity"].isnan(), + affinity_logits.detach(), + batch["features"]["affinity"], + ) + aff_loss = segment_mean( + # Find the (batched) mean squared error over all ligand chains in the same complex, then calculate the mean of each batch + (affinity_logits - affinity).square(), + gather_idx_molid_structid, + batch_size, + ).mean() + lit_module.log( + f"{stage}_affinity/loss", + aff_loss.detach(), + on_epoch=True, + batch_size=batch_size, + sync_dist=(stage != "train"), + ) + if lit_module.hparams.cfg.task.freeze_affinity: + # Keep the affinity prediction parameters in the computational graph but with zero gradients + aff_loss *= 0 + else: + aff_loss = 0 + plddt_loss = conf_loss * lit_module.hparams.cfg.task.plddt_loss_weight + affinity_loss = aff_loss * lit_module.hparams.cfg.task.affinity_loss_weight + batch["outputs"]["loss"] = plddt_loss + affinity_loss + if not torch.is_tensor(batch["outputs"]["loss"]) and batch["outputs"]["loss"] == 0: + batch["outputs"]["loss"] = None + return batch diff --git a/flowdock/models/components/mht_encoder.py b/flowdock/models/components/mht_encoder.py new file mode 100644 index 0000000..7dca795 --- /dev/null +++ b/flowdock/models/components/mht_encoder.py @@ -0,0 +1,364 @@ +# Adapted from: https://github.com/zrqiao/NeuralPLexer + +import rootutils +import torch +from beartype.typing import Any, Dict, Optional, Tuple +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher +from flowdock.models.components.modules import TransformerLayer +from flowdock.utils import RankedLogger +from flowdock.utils.model_utils import GELUMLP, segment_softmax, segment_sum + +MODEL_BATCH = Dict[str, Any] +STATE_DICT = Dict[str, 
Any] + +log = RankedLogger(__name__, rank_zero_only=True) + + +class PathConvStack(torch.nn.Module): + """Path integral convolution stack for ligand encoding.""" + + def __init__( + self, + pair_channels: int, + n_heads: int = 8, + max_pi_length: int = 8, + dropout: float = 0.0, + ): + """Initialize PathConvStack model.""" + super().__init__() + self.pair_channels = pair_channels + self.max_pi_length = max_pi_length + self.n_heads = n_heads + + self.prop_value_layer = torch.nn.Linear(pair_channels, n_heads, bias=False) + self.triangle_pair_kernel_layer = torch.nn.Linear(pair_channels, n_heads, bias=False) + self.prop_update_mlp = GELUMLP( + n_heads * (max_pi_length + 1), pair_channels, dropout=dropout + ) + + def forward( + self, + prop_attr: torch.Tensor, + stereo_attr: torch.Tensor, + indexer: Dict[str, torch.LongTensor], + metadata: Dict[str, Any], + ) -> torch.Tensor: + """Forward pass for PathConvStack model. + + :param prop_attr: Atom-frame pair attributes. + :param stereo_attr: Stereochemistry attributes. + :param indexer: A dictionary of indices. + :param metadata: A dictionary of metadata. + :return: Updated atom-frame pair attributes. 
+ """ + triangle_pair_kernel = self.triangle_pair_kernel_layer(stereo_attr) + # Segment-wise softmax, normalized by outgoing triangles + triangle_pair_alpha = segment_softmax( + triangle_pair_kernel, indexer["gather_idx_ijkl_jkl"], metadata["num_ijk"] + ) # .div(self.max_pi_length) + # Uijk,ijkl->ujkl pair representation update + kernel = triangle_pair_alpha[indexer["gather_idx_Uijkl_ijkl"]] + out_prop_attr = [self.prop_value_layer(prop_attr)] + for _ in range(self.max_pi_length): + gathered_prop_attr = out_prop_attr[-1][indexer["gather_idx_Uijkl_Uijk"]] + out_prop_attr.append( + segment_sum( + kernel.mul(gathered_prop_attr), + indexer["gather_idx_Uijkl_ujkl"], + metadata["num_Uijk"], + ) + ) + new_prop_attr = torch.cat(out_prop_attr, dim=-1) + new_prop_attr = self.prop_update_mlp(new_prop_attr) + prop_attr + return new_prop_attr + + +class PIFormer(torch.nn.Module): + """PIFormer model for ligand encoding.""" + + def __init__( + self, + node_channels: int, + pair_channels: int, + n_atom_encodings: int, + n_bond_encodings: int, + n_atom_pos_encodings: int, + n_stereo_encodings: int, + heads: int, + head_dim: int, + max_path_length: int = 4, + n_transformer_stacks=4, + hidden_dim: Optional[int] = None, + dropout: float = 0.0, + ): + """Initialize PIFormer model.""" + super().__init__() + self.node_channels = node_channels + self.pair_channels = pair_channels + self.max_pi_length = max_path_length + self.n_transformer_stacks = n_transformer_stacks + self.n_atom_encodings = n_atom_encodings + self.n_bond_encodings = n_bond_encodings + self.n_atom_pair_encodings = n_bond_encodings + 4 + self.n_atom_pos_encodings = n_atom_pos_encodings + + self.input_atom_layer = torch.nn.Linear(n_atom_encodings, node_channels) + self.input_pair_layer = torch.nn.Linear(self.n_atom_pair_encodings, pair_channels) + self.input_stereo_layer = torch.nn.Linear(n_stereo_encodings, pair_channels) + self.input_prop_layer = GELUMLP( + self.n_atom_pair_encodings * 3, + pair_channels, + ) + 
self.path_integral_stacks = torch.nn.ModuleList( + [ + PathConvStack( + pair_channels, + max_pi_length=max_path_length, + dropout=dropout, + ) + for _ in range(n_transformer_stacks) + ] + ) + self.graph_transformer_stacks = torch.nn.ModuleList( + [ + TransformerLayer( + node_channels, + heads, + head_dim=head_dim, + edge_channels=pair_channels, + hidden_dim=hidden_dim, + dropout=dropout, + edge_update=True, + ) + for _ in range(n_transformer_stacks) + ] + ) + + def forward(self, batch: MODEL_BATCH, masking_rate: float = 0.0) -> MODEL_BATCH: + """Forward pass for PIFormer model. + + :param batch: A batch dictionary. + :param masking_rate: Masking rate. + :return: A batch dictionary. + """ + features = batch["features"] + indexer = batch["indexer"] + metadata = batch["metadata"] + features["atom_encodings"] = features["atom_encodings"] + atom_attr = features["atom_encodings"] + atom_pair_attr = features["atom_pair_encodings"] + af_pair_attr = features["atom_frame_pair_encodings"] + stereo_enc = features["stereo_chemistry_encodings"] + batch["features"]["lig_atom_token"] = atom_attr.detach().clone() + batch["features"]["lig_pair_token"] = atom_pair_attr.detach().clone() + + atom_mask = torch.rand(atom_attr.shape[0], device=atom_attr.device) > masking_rate + stereo_mask = torch.rand(stereo_enc.shape[0], device=stereo_enc.device) > masking_rate + atom_pair_mask = ( + torch.rand(atom_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate + ) + af_pair_mask = ( + torch.rand(af_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate + ) + atom_attr = atom_attr * atom_mask[:, None] + stereo_enc = stereo_enc * stereo_mask[:, None] + atom_pair_attr = atom_pair_attr * atom_pair_mask[:, None] + af_pair_attr = af_pair_attr * af_pair_mask[:, None] + + # Embedding blocks + metadata["num_atom"] = metadata["num_u"] + metadata["num_frame"] = metadata["num_ijk"] + atom_attr = self.input_atom_layer(atom_attr) + atom_pair_attr = self.input_pair_layer(atom_pair_attr) + 
triangle_attr = atom_attr.new_zeros(metadata["num_frame"], self.node_channels) + # Initialize atom-frame pair attributes. Reusing uv indices + prop_attr = self.input_prop_layer(af_pair_attr) + stereo_attr = self.input_stereo_layer(stereo_enc) + + graph_relations = [ + ("atom_to_atom", "gather_idx_uv_u", "gather_idx_uv_v", "atom", "atom"), + ( + "atom_to_frame", + "gather_idx_Uijk_u", + "gather_idx_Uijk_ijk", + "atom", + "frame", + ), + ( + "frame_to_atom", + "gather_idx_Uijk_ijk", + "gather_idx_Uijk_u", + "frame", + "atom", + ), + ( + "frame_to_frame", + "gather_idx_ijkl_ijk", + "gather_idx_ijkl_jkl", + "frame", + "frame", + ), + ] + + graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata) + merged_edge_idx = graph_batcher.collate_idx_list(indexer) + node_reps = {"atom": atom_attr, "frame": triangle_attr} + edge_reps = { + "atom_to_atom": atom_pair_attr, + "atom_to_frame": prop_attr, + "frame_to_atom": prop_attr, + "frame_to_frame": stereo_attr, + } + + # Graph path integral recursion + for block_id in range(self.n_transformer_stacks): + merged_node_attr = graph_batcher.collate_node_attr(node_reps) + merged_edge_attr = graph_batcher.collate_edge_attr(edge_reps) + _, merged_node_attr, merged_edge_attr = self.graph_transformer_stacks[block_id]( + merged_node_attr, + merged_node_attr, + merged_edge_idx, + merged_edge_attr, + ) + node_reps = graph_batcher.offload_node_attr(merged_node_attr) + edge_reps = graph_batcher.offload_edge_attr(merged_edge_attr) + prop_attr = edge_reps["atom_to_frame"] + stereo_attr = edge_reps["frame_to_frame"] + prop_attr = prop_attr + self.path_integral_stacks[block_id]( + prop_attr, + stereo_attr, + indexer, + metadata, + ) + edge_reps["atom_to_frame"] = prop_attr + + node_reps["sampled_frame"] = node_reps["frame"][indexer["gather_idx_I_ijk"]] + + batch["metadata"]["num_lig_atm"] = metadata["num_u"] + batch["metadata"]["num_lig_trp"] = metadata["num_I"] + + batch["features"]["lig_atom_attr"] = 
node_reps["atom"] + # Downsampled ligand frames + batch["features"]["lig_trp_attr"] = node_reps["sampled_frame"] + batch["features"]["lig_atom_pair_attr"] = edge_reps["atom_to_atom"] + batch["features"]["lig_prop_attr"] = edge_reps["atom_to_frame"] + edge_reps["sampled_atom_to_sampled_frame"] = edge_reps["atom_to_frame"][ + indexer["gather_idx_UI_Uijk"] + ] + batch["features"]["lig_af_pair_attr"] = edge_reps["sampled_atom_to_sampled_frame"] + return batch + + +def resolve_ligand_encoder( + ligand_model_cfg: DictConfig, + task_cfg: DictConfig, + state_dict: Optional[STATE_DICT] = None, +) -> torch.nn.Module: + """Instantiates a PIFormer model for ligand encoding. + + :param ligand_model_cfg: Ligand model configuration. + :param task_cfg: Task configuration. + :param state_dict: Optional (potentially-pretrained) state dictionary. + :return: Ligand encoder model. + """ + model = PIFormer( + ligand_model_cfg.node_channels, + ligand_model_cfg.pair_channels, + ligand_model_cfg.n_atom_encodings, + ligand_model_cfg.n_bond_encodings, + ligand_model_cfg.n_atom_pos_encodings, + ligand_model_cfg.n_stereo_encodings, + ligand_model_cfg.n_attention_heads, + ligand_model_cfg.attention_head_dim, + hidden_dim=ligand_model_cfg.hidden_dim, + max_path_length=ligand_model_cfg.max_path_integral_length, + n_transformer_stacks=ligand_model_cfg.n_transformer_stacks, + dropout=task_cfg.dropout, + ) + if ligand_model_cfg.from_pretrained and state_dict is not None: + try: + model.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("ligand_encoder") + } + ) + log.info( + "Successfully loaded pretrained ligand Molecular Heat Transformer (MHT) weights." + ) + except Exception as e: + log.warning( + f"Skipping loading of pretrained ligand Molecular Heat Transformer (MHT) weights due to: {e}." 
+ ) + return model + + +def resolve_relational_reasoning_module( + protein_model_cfg: DictConfig, + ligand_model_cfg: DictConfig, + relational_reasoning_cfg: DictConfig, + state_dict: Optional[STATE_DICT] = None, +) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]: + """Instantiates relational reasoning module for ligand encoding. + + :param protein_model_cfg: Protein model configuration. + :param ligand_model_cfg: Ligand model configuration. + :param relational_reasoning_cfg: Relational reasoning configuration. + :param state_dict: Optional (potentially-pretrained) state dictionary. + :return: Relational reasoning modules for ligand encoding. + """ + molgraph_single_projector = torch.nn.Linear( + ligand_model_cfg.node_channels, protein_model_cfg.residue_dim, bias=False + ) + molgraph_pair_projector = torch.nn.Linear( + ligand_model_cfg.pair_channels, protein_model_cfg.pair_dim, bias=False + ) + covalent_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim) + if relational_reasoning_cfg.from_pretrained and state_dict is not None: + try: + molgraph_single_projector.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("molgraph_single_projector") + } + ) + log.info("Successfully loaded pretrained ligand graph single projector weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained ligand graph single projector weights due to: {e}." + ) + + try: + molgraph_pair_projector.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("molgraph_pair_projector") + } + ) + log.info("Successfully loaded pretrained ligand graph pair projector weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained ligand graph pair projector weights due to: {e}." 
+ ) + + try: + covalent_embed.load_state_dict( + { + ".".join(k.split(".")[1:]): v + for k, v in state_dict.items() + if k.startswith("covalent_embed") + } + ) + log.info("Successfully loaded pretrained ligand covalent embedding weights.") + except Exception as e: + log.warning( + f"Skipping loading of pretrained ligand covalent embedding weights due to: {e}." + ) + return molgraph_single_projector, molgraph_pair_projector, covalent_embed diff --git a/flowdock/models/components/modules.py b/flowdock/models/components/modules.py new file mode 100644 index 0000000..7ee190f --- /dev/null +++ b/flowdock/models/components/modules.py @@ -0,0 +1,423 @@ +import math + +import rootutils +import torch +import torch.nn.functional as F +from beartype.typing import Optional, Tuple, Union +from openfold.model.primitives import Attention +from openfold.utils.tensor_utils import permute_final_dims + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils.model_utils import GELUMLP, segment_softmax + + +class MultiHeadAttentionConv(torch.nn.Module): + """Native Pytorch implementation.""" + + def __init__( + self, + dim: Union[int, Tuple[int, int]], + head_dim: int, + edge_dim: int = None, + n_heads: int = 1, + dropout: float = 0.0, + edge_lin: bool = True, + **kwargs, + ): + """Multi-Head Attention Convolution layer.""" + super().__init__() + self.dim = dim + self.head_dim = head_dim + self.n_heads = n_heads + self.dropout = dropout + self.edge_dim = edge_dim + self.edge_lin = edge_lin + self._alpha = None + + if isinstance(dim, int): + dim = (dim, dim) + + self.lin_key = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False) + self.lin_query = torch.nn.Linear(dim[1], n_heads * head_dim, bias=False) + self.lin_value = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False) + if edge_lin is True: + self.lin_edge = torch.nn.Linear(edge_dim, n_heads, bias=False) + else: + self.lin_edge = self.register_parameter("lin_edge", None) + + 
self.reset_parameters() + + def reset_parameters(self): + """Reset the parameters of the layer.""" + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.edge_lin: + self.lin_edge.reset_parameters() + + def forward( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + edge_index: torch.Tensor, + edge_attr: torch.Tensor = None, + return_attention_weights=None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Forward pass of the Multi-Head Attention Convolution layer. + + :param x (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]): The input features. + :param edge_index (torch.Tensor): The edge index tensor. + :param edge_attr (torch.Tensor, optional): The edge attribute tensor. + :param return_attention_weights (bool, optional): If set to `True`, + will additionally return the tuple + `(edge_index, attention_weights)`, holding the computed + attention weights for each edge. Default is `None`. + :return: The output features or the tuple + `(output_features, (edge_index, attention_weights))`. 
        """

        H, C = self.n_heads, self.head_dim

        # Self-attention convenience: a single tensor means source == target nodes.
        if isinstance(x, torch.Tensor):
            x = (x, x)

        # Project per-node features into H heads of width C.
        query = self.lin_query(x[1]).view(*x[1].shape[:-1], H, C)
        key = self.lin_key(x[0]).view(*x[0].shape[:-1], H, C)
        value = self.lin_value(x[0]).view(*x[0].shape[:-1], H, C)

        # Per-edge attention-weighted values, summed into each target node.
        attended_values = self.message(key, query, value, edge_attr, edge_index)
        out = self.aggregate(attended_values, edge_index[1], query.shape[0])

        # `message` stashes the pre-softmax logits on `self._alpha`; consume and clear.
        alpha = self._alpha
        self._alpha = None

        out = out.contiguous().view(*out.shape[:-2], H * C)

        # NOTE(review): weights are returned whenever a bool is passed -- even
        # `return_attention_weights=False` returns them; only `None` suppresses.
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            return out, (edge_index, alpha)
        else:
            return out

    def message(
        self,
        key: torch.Tensor,
        query: torch.Tensor,
        value: torch.Tensor,
        edge_attr: torch.Tensor,
        index: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention-weighted value messages for every edge.

        Scaled dot-product logits are computed per edge, optionally biased by a
        linear projection of the edge attributes, then normalized with a
        segment-wise softmax over each target node's incoming edges.

        :param key (torch.Tensor): Per-source-node keys.
        :param query (torch.Tensor): Per-target-node queries.
        :param value (torch.Tensor): Per-source-node values.
        :param edge_attr (torch.Tensor): The edge attribute tensor (required
            when `self.lin_edge` is set).
        :param index (torch.Tensor): The `[2, E]` edge index tensor.
        :return: Attention-weighted values, one row per edge.
        """
        edge_bias = 0
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_bias = self.lin_edge(edge_attr)

        # Raw (pre-softmax) logits, also kept on `self._alpha` for the caller.
        _alpha_z = (query[index[1]] * key[index[0]]).sum(dim=-1) / math.sqrt(
            self.head_dim
        ) + edge_bias
        self._alpha = _alpha_z
        alpha = segment_softmax(_alpha_z, index[1], query.shape[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Advanced indexing copies, so the in-place multiply does not touch `value`.
        out = value[index[0]]
        out *= alpha.unsqueeze(-1)
        return out

    def aggregate(
        self, src: torch.Tensor, dst_idx: torch.Tensor, dst_size: int
    ) -> torch.Tensor:
        """Aggregate the source tensor to the destination tensor.

        :param src (torch.Tensor): The per-edge source tensor.
        :param dst_idx (torch.Tensor): The destination index tensor.
        :param dst_size (int): Number of destination nodes (first output dim).
        :return: The output tensor.
+ """ + out = torch.zeros( + dst_size, + *src.shape[1:], + dtype=src.dtype, + device=src.device, + ).index_add_(0, dst_idx, src) + return out + + +class TransformerLayer(torch.nn.Module): + """A single layer of a transformer model.""" + + def __init__( + self, + node_dim: int, + n_heads: int, + head_dim: Optional[int] = None, + hidden_dim: Optional[int] = None, + bidirectional: bool = False, + edge_channels: Optional[int] = None, + dropout: float = 0.0, + edge_update: bool = False, + ): + """Initialize the transformer layer.""" + super().__init__() + edge_lin = edge_channels is not None + self.edge_update = edge_update + if head_dim is None: + head_dim = node_dim // n_heads + self.conv = MultiHeadAttentionConv( + node_dim, + head_dim, + edge_dim=edge_channels, + n_heads=n_heads, + edge_lin=edge_lin, + dropout=dropout, + ) + self.bidirectional = bidirectional + self.projector = torch.nn.Linear(head_dim * n_heads, node_dim, bias=False) + self.norm = torch.nn.LayerNorm(node_dim) + self.mlp = GELUMLP( + node_dim, + node_dim, + n_hidden_feats=hidden_dim, + dropout=dropout, + zero_init=True, + ) + if edge_update: + self.mlpe = GELUMLP( + n_heads + edge_channels, edge_channels, dropout=dropout, zero_init=True + ) + + def forward( + self, + x_s: torch.Tensor, + x_a: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Forward pass through the transformer layer. + + :param x_s (torch.Tensor): The source node features. + :param x_a (torch.Tensor): The target node features. + :param edge_index (torch.Tensor): The edge index tensor. :param edge_attr (torch.Tensor, + optional): The edge attribute tensor. + :return: The output source and target node features. 
+ """ + out_a, (edge_index, alpha) = self.conv( + (x_s, x_a), + edge_index, + edge_attr, + return_attention_weights=True, + ) + x_a = x_a + self.projector(out_a) + x_a = self.mlp(self.norm(x_a)) + x_a + if self.bidirectional: + out_s = self.conv((x_a, x_s), (edge_index[1], edge_index[0]), edge_attr) + x_s = x_s + self.projector(out_s) + x_s = self.mlp(self.norm(x_s)) + x_s + if self.edge_update: + edge_attr = edge_attr + self.mlpe(torch.cat([alpha, edge_attr], dim=-1)) + return x_s, x_a, edge_attr + else: + return x_s, x_a + + +class PointSetAttention(torch.nn.Module): + """PointSetAttention module.""" + + def __init__( + self, + fiber_dim: int, + heads: int = 8, + point_dim: int = 4, + edge_dim: Optional[int] = None, + edge_update: bool = False, + dropout: float = 0.0, + ): + """Initialize the PointSetAttention module.""" + super().__init__() + self.fiber_dim = fiber_dim + self.edge_dim = edge_dim + self.heads = heads + self.point_dim = point_dim + self.dropout = dropout + self.edge_update = edge_update + self.distance_scaling = 10 # 1 nm + + # num attention contributions + num_attn_logits = 2 + + self.lin_query = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False) + self.lin_key = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False) + self.lin_value = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False) + if edge_dim is not None: + self.lin_edge = torch.nn.Linear(edge_dim, heads, bias=False) + if edge_update: + self.edge_update_mlp = GELUMLP(heads + edge_dim, edge_dim) + + # qkv projection for scalar attention (normal) + self.scalar_attn_logits_scale = (num_attn_logits * point_dim) ** -0.5 + + # qkv projection for point attention (coordinate and orientation aware) + point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.0)) - 1.0) + self.point_weights = torch.nn.Parameter(point_weight_init_value) + + self.point_attn_logits_scale = ((num_attn_logits * point_dim) * (9 / 2)) ** -0.5 + point_weight_init_value = 
torch.log(torch.exp(torch.full((heads,), 1.0)) - 1.0) + self.point_weights = torch.nn.Parameter(point_weight_init_value) + + # combine out - point dim * 4 + self.to_out = torch.nn.Linear(heads * point_dim, fiber_dim, bias=False) + + def forward( + self, + x_k: torch.Tensor, + x_q: torch.Tensor, + edge_index: torch.LongTensor, + point_centers_k: torch.Tensor, + point_centers_q: torch.Tensor, + x_edge: torch.Tensor = None, + ): + """Forward pass of the PointSetAttention module.""" + H, P = self.heads, self.point_dim + + q = self.lin_query(x_q) + k = self.lin_key(x_k) + v = self.lin_value(x_k) + + scalar_q = q[..., 0, :].view(-1, H, P) + scalar_k = k[..., 0, :].view(-1, H, P) + scalar_v = v[..., 0, :].view(-1, H, P) + + point_q_local = q[..., 1:, :].view(-1, 3, H, P) + point_k_local = k[..., 1:, :].view(-1, 3, H, P) + point_v_local = v[..., 1:, :].view(-1, 3, H, P) + + point_q = point_q_local + point_centers_q[..., None, None] / self.distance_scaling + point_k = point_k_local + point_centers_k[..., None, None] / self.distance_scaling + point_v = point_v_local + point_centers_k[..., None, None] / self.distance_scaling + + if self.edge_dim is not None: + edge_bias = self.lin_edge(x_edge) + else: + edge_bias = 0 + + attn_logits, attentions = self.compute_attention( + scalar_k, scalar_q, point_k, point_q, edge_bias, edge_index + ) + res_scalar = self.aggregate( + attentions[:, :, None] * scalar_v[edge_index[0]], + edge_index[1], + scalar_q.shape[0], + ) + res_points = self.aggregate( + attentions[:, None, :, None] * point_v[edge_index[0]], + edge_index[1], + point_q.shape[0], + ) + res_points_local = res_points - point_centers_q[..., None, None] / self.distance_scaling + + # [N, H, P], [N, 3, H, P] -> [N, 4, C] + res = torch.cat([res_scalar.unsqueeze(-3), res_points_local], dim=-3).flatten(-2, -1) + out = self.to_out(res) # [N, 4, C] + if self.edge_update: + edge_out = self.edge_update_mlp(torch.cat([attn_logits, x_edge], dim=-1)) + return out, edge_out + return out + + 
def compute_attention(self, scalar_k, scalar_q, point_k, point_q, edge_bias, index): + """Compute the attention scores.""" + scalar_q = scalar_q[index[1]] + scalar_k = scalar_k[index[0]] + point_q = point_q[index[1]] + point_k = point_k[index[0]] + + scalar_logits = (scalar_q * scalar_k).sum(dim=-1) * self.scalar_attn_logits_scale + point_weights = F.softplus(self.point_weights).unsqueeze(0) + point_logits = ( + torch.square(point_q - point_k).sum(dim=(-3, -1)) * self.point_attn_logits_scale + ) + + logits = scalar_logits - 1 / 2 * point_logits * point_weights + edge_bias + alpha = segment_softmax(logits, index[1], scalar_q.shape[0]) + alpha = F.dropout(alpha, p=self.dropout, training=self.training) + return logits, alpha + + @staticmethod + def aggregate(src, dst_idx, dst_size): + """Aggregate the source tensor to the destination tensor.""" + out = torch.zeros( + dst_size, + *src.shape[1:], + dtype=src.dtype, + device=src.device, + ).index_add_(0, dst_idx, src) + return out + + +class BiDirectionalTriangleAttention(torch.nn.Module): + """ + Adapted from https://github.com/aqlaboratory/openfold + supports rectangular pair representation tensors + """ + + def __init__(self, c_in: int, c_hidden: int, no_heads: int, inf: float = 1e9): + """Initialize the Bi-Directional Triangle Attention layer.""" + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + + self.linear = torch.nn.Linear(c_in, self.no_heads, bias=False) + + self.mha_1 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + self.mha_2 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + self.layer_norm = torch.nn.LayerNorm(self.c_in) + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + x_pair: torch.Tensor, + mask: Optional[torch.Tensor] = None, + use_lma: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the Bi-Directional Triangle Attention layer.""" + if mask 
is None: + # [*, I, J, K] + mask = x_pair.new_ones( + x_pair.shape[:-1], + ) + + # [*, I, J, C_in] + x1 = self.layer_norm(x1) + # [*, I, K, C_in] + x2 = self.layer_norm(x2) + + # [*, I, 1, J, K] + mask_bias = (self.inf * (mask - 1))[..., :, None, :, :] + + # [*, I, H, J, K] + triangle_bias = permute_final_dims(self.linear(x_pair), [0, 3, 1, 2]) + + biases_J2I = [mask_bias, triangle_bias] + + x1_out = self.mha_1(q_x=x1, kv_x=x2, biases=biases_J2I, use_lma=use_lma) + x1 = x1 + x1_out + + # transpose the triangle bias for I->J attention. + mask_bias_T_ = mask_bias.transpose(-2, -1).contiguous() + triangle_bias_T_ = triangle_bias.transpose(-2, -1).contiguous() + biases_I2J = [mask_bias_T_, triangle_bias_T_] + x2_out = self.mha_2(q_x=x2, kv_x=x1, biases=biases_I2J, use_lma=use_lma) + x2 = x2 + x2_out + + return x1, x2 diff --git a/flowdock/models/components/noise.py b/flowdock/models/components/noise.py new file mode 100644 index 0000000..bc92212 --- /dev/null +++ b/flowdock/models/components/noise.py @@ -0,0 +1,443 @@ +import numpy as np +import rootutils +import torch +from beartype.typing import Any, Dict, List, Optional, Tuple, Union + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.models.components.transforms import LatentCoordinateConverter +from flowdock.utils import RankedLogger +from flowdock.utils.model_utils import segment_mean + +MODEL_BATCH = Dict[str, Any] + +log = RankedLogger(__name__, rank_zero_only=True) + + +class DiffusionSDE: + """Diffusion SDE class. 
+ + Adapted from: https://github.com/HannesStark/FlowSite + """ + + def __init__(self, sigma: torch.Tensor, tau_factor: float = 5.0): + """Initialize the Diffusion SDE class.""" + self.lamb = 1 / sigma**2 + self.tau_factor = tau_factor + + def var(self, t: torch.Tensor) -> torch.Tensor: + """Calculate the variance of the diffusion SDE.""" + return (1 - torch.exp(-self.lamb * t)) / self.lamb + + def max_t(self) -> float: + """Calculate the maximum time of the diffusion SDE.""" + return self.tau_factor / self.lamb + + def mu_factor(self, t: torch.Tensor) -> torch.Tensor: + """Calculate the mu factor of the diffusion SDE.""" + return torch.exp(-self.lamb * t / 2) + + +class HarmonicSDE: + """Harmonic SDE class. + + Adapted from: https://github.com/HannesStark/FlowSite + """ + + def __init__(self, J: Optional[torch.Tensor] = None, diagonalize: bool = True): + """Initialize the Harmonic SDE class.""" + self.l_index = 1 + self.use_cuda = False + if not diagonalize: + return + if J is not None: + self.D, self.P = np.linalg.eigh(J) + self.N = self.D.size + + @staticmethod + def diagonalize( + N, + ptr: torch.Tensor, + edges: Optional[List[Tuple[int, int]]] = None, + antiedges: Optional[List[Tuple[int, int]]] = None, + a=1, + b=0.3, + lamb: Optional[torch.Tensor] = None, + device: Optional[Union[str, torch.device]] = None, + ): + """Diagonalize using the Harmonic SDE.""" + device = device or ptr.device + J = torch.zeros((N, N), device=device) + if edges is None: + for i, j in zip(np.arange(N - 1), np.arange(1, N)): + J[i, i] += a + J[j, j] += a + J[i, j] = J[j, i] = -a + else: + for i, j in edges: + J[i, i] += a + J[j, j] += a + J[i, j] = J[j, i] = -a + if antiedges is not None: + for i, j in antiedges: + J[i, i] -= b + J[j, j] -= b + J[i, j] = J[j, i] = b + if edges is not None: + J += torch.diag(lamb) + + Ds, Ps = [], [] + for start, end in zip(ptr[:-1], ptr[1:]): + D, P = torch.linalg.eigh(J[start:end, start:end]) + D_ = D + if edges is None: + D_inv = 1 / D + D_inv[0] = 
0 + D_ = D_inv + Ds.append(D_) + Ps.append(P) + return torch.cat(Ds), torch.block_diag(*Ps) + + def eigens(self, t): + """Calculate the eigenvalues of `sigma_t` using the Harmonic SDE.""" + np_ = torch if self.use_cuda else np + D = 1 / self.D * (1 - np_.exp(-t * self.D)) + t = torch.tensor(t, device="cuda").float() if self.use_cuda else t + return np_.where(D != 0, D, t) + + def conditional(self, mask, x2): + """Calculate the conditional distribution using the Harmonic SDE.""" + J_11 = self.J[~mask][:, ~mask] + J_12 = self.J[~mask][:, mask] + h = -J_12 @ x2 + mu = np.linalg.inv(J_11) @ h + D, P = np.linalg.eigh(J_11) + z = np.random.randn(*mu.shape) + return (P / D**0.5) @ z + mu + + def A(self, t, invT=False): + """Calculate the matrix `A` using the Harmonic SDE.""" + D = self.eigens(t) + A = self.P * (D**0.5) + if not invT: + return A + AinvT = self.P / (D**0.5) + return A, AinvT + + def Sigma_inv(self, t): + """Calculate the inverse of the covariance matrix `Sigma` using the Harmonic SDE.""" + D = 1 / self.eigens(t) + return (self.P * D) @ self.P.T + + def Sigma(self, t): + """Calculate the covariance matrix `Sigma` using the Harmonic SDE.""" + D = self.eigens(t) + return (self.P * D) @ self.P.T + + @property + def J(self): + """Return the matrix `J`.""" + return (self.P * self.D) @ self.P.T + + def rmsd(self, t): + """Calculate the root mean square deviation using the Harmonic SDE.""" + l_index = self.l_index + D = 1 / self.D * (1 - np.exp(-t * self.D)) + return np.sqrt(3 * D[l_index:].mean()) + + def sample(self, t, x=None, score=False, k=None, center=True, adj=False): + """Sample from the Harmonic SDE.""" + l_index = self.l_index + np_ = torch if self.use_cuda else np + if x is None: + if self.use_cuda: + x = torch.zeros((self.N, 3), device="cuda").float() + else: + x = np.zeros((self.N, 3)) + if t == 0: + return x + z = ( + np.random.randn(self.N, 3) + if not self.use_cuda + else torch.randn(self.N, 3, device="cuda").float() + ) + D = self.eigens(t) + xx = 
self.P.T @ x  # NOTE(review): continues `xx = ` from the previous line -- projects x into the eigenbasis
        if center:
            # Pin the zero (translational) mode.
            z[0] = 0
            xx[0] = 0
        if k:
            # NOTE(review): `k == 0` is falsy and so disables truncation here
            # (unlike score_norm, which special-cases k == 0) -- confirm intended.
            z[k + l_index :] = 0
            xx[k + l_index :] = 0

        out = np_.exp(-t * self.D / 2)[:, None] * xx + np_.sqrt(D)[:, None] * z

        if score:
            # Reuses the `score` name for the computed score tensor.
            score = -(1 / np_.sqrt(D))[:, None] * z
            if adj:
                score = score + self.D[:, None] * out
            return self.P @ out, self.P @ score
        return self.P @ out

    def score_norm(self, t, k=None, adj=False):
        """Calculate the score norm using the Harmonic SDE."""
        if k == 0:
            return 0
        l_index = self.l_index
        np_ = torch if self.use_cuda else np
        # `k or self.N - 1`: default to all non-trivial modes when k is None/0.
        k = k or self.N - 1
        D = 1 / self.eigens(t)
        if adj:
            D = D * np_.exp(-self.D * t)
        return (D[l_index : k + l_index].sum() / self.N) ** 0.5

    def inject(self, t, modes):
        """Inject noise along the given modes using the Harmonic SDE."""
        z = (
            np.random.randn(self.N, 3)
            if not self.use_cuda
            else torch.randn(self.N, 3, device="cuda").float()
        )
        # Zero out noise everywhere except the selected boolean `modes`.
        z[~modes] = 0
        A = self.A(t, invT=False)
        return A @ z

    def score(self, x0, xt, t):
        """Calculate the score of the diffusion kernel using the Harmonic SDE."""
        Sigma_inv = self.Sigma_inv(t)
        mu_t = (self.P * np.exp(-t * self.D / 2)) @ (self.P.T @ x0)
        return Sigma_inv @ (mu_t - xt)

    def project(self, X, k, center=False):
        """Project onto the first `k` nonzero modes using the Harmonic SDE."""
        l_index = self.l_index
        D = self.P.T @ X
        D[k + l_index :] = 0
        if center:
            D[0] = 0
        return self.P @ D

    def unproject(self, X, mask, k, return_Pinv=False):
        """Find the vector along the first k nonzero modes whose mask is closest to `X`."""
        l_index = self.l_index
        # Least-squares fit in the truncated eigenbasis via the pseudo-inverse.
        PP = self.P[mask, : k + l_index]
        Pinv = np.linalg.pinv(PP)
        out = self.P[:, : k + l_index] @ Pinv @ X
        if return_Pinv:
            return out, Pinv
        return out

    def energy(self, X):
        """Calculate the energy using the Harmonic SDE."""
        l_index = self.l_index
        return (self.D[:, None] * (self.P.T @ X) ** 2).sum(-1)[l_index:] / 2

    @property
    def free_energy(self):
        """Calculate the free energy using the
Harmonic SDE.""" + l_index = self.l_index + return 3 * np.log(self.D[l_index:]).sum() / 2 + + def KL_H(self, t): + """Calculate the Kullback-Leibler divergence using the Harmonic SDE.""" + l_index = self.l_index + D = self.D[l_index:] + return -3 * 0.5 * (np.log(1 - np.exp(-D * t)) + np.exp(-D * t)).sum(0) + + +def sample_gaussian_prior( + x0: torch.Tensor, + latent_converter: LatentCoordinateConverter, + sigma: float, + x0_sigma: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Sample noise from a Gaussian prior distribution. + + :param x0: ground-truth tensor + :param latent_converter: The latent coordinate converter + :param sigma: standard deviation of the Gaussian noise + :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor + :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise + """ + prior = torch.randn_like(x0) + x_int_0 = x0 + prior * x0_sigma # add small Gaussian noise to the ground-truth tensor + ( + x1_ca_lat, + x1_cother_lat, + x1_lig_lat, + ) = torch.split( + prior * sigma, + [ + latent_converter._n_res_per_sample, + latent_converter._n_cother_per_sample, + latent_converter._n_ligha_per_sample, + ], + dim=1, + ) + x_int_1 = torch.cat( + [ + x1_ca_lat, + x1_cother_lat, + x1_lig_lat, + ], + dim=1, + ) + return x_int_0, x_int_1 + + +def sample_protein_harmonic_prior( + protein_ca_x0: torch.Tensor, + protein_cother_x0: torch.Tensor, + batch: MODEL_BATCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Sample protein noise from a harmonic prior distribution. + Adapted from: https://github.com/bjing2016/alphaflow + + Note that this function represents non-Ca atoms as Gaussian noise + centered around each harmonically-noised Ca atom. 
+ + :param protein_ca_x0: ground-truth protein Ca-atom tensor + :param protein_cother_x0: ground-truth protein other-atom tensor + :param batch: A batch dictionary + :return: tuple of harmonic protein Ca atom noise and Gaussian protein other atom noise + """ + indexer = batch["indexer"] + metadata = batch["metadata"] + protein_bid = indexer["gather_idx_a_structid"] + protein_num_nodes = protein_ca_x0.size(0) * protein_ca_x0.size(1) + ptr = torch.cumsum(torch.bincount(protein_bid), dim=0) + ptr = torch.cat((torch.tensor([0], device=protein_bid.device), ptr)) + try: + D_inv, P = HarmonicSDE.diagonalize( + protein_num_nodes, + ptr, + a=3 / (3.8**2), + ) + except Exception as e: + log.error( + f"Failed to call HarmonicSDE.diagonalize() for protein(s) {metadata['sample_ID_per_sample']} due to: {e}" + ) + raise e + noise = torch.randn((protein_num_nodes, 3), device=protein_ca_x0.device) + harmonic_ca_noise = P @ (torch.sqrt(D_inv)[:, None] * noise) + gaussian_cother_noise = ( + torch.randn_like(protein_cother_x0.flatten(0, 1)) + + harmonic_ca_noise[indexer["gather_idx_a_cotherid"]] + ) + return ( + harmonic_ca_noise.view(protein_ca_x0.size()).contiguous(), + gaussian_cother_noise.view(protein_cother_x0.size()).contiguous(), + ) + + +def sample_ligand_harmonic_prior( + lig_x0: torch.Tensor, protein_ca_x0: torch.Tensor, batch: MODEL_BATCH, sigma: float = 1.0 +) -> torch.Tensor: + """ + Sample ligand noise from a harmonic prior distribution. 
+ Adapted from: https://github.com/HannesStark/FlowSite + + :param lig_x0: ground-truth ligand tensor + :param protein_x0: ground-truth protein Ca-atom tensor + :param batch: A batch dictionary + :param sigma: standard deviation of the harmonic noise + :return: tensor of harmonic noise + """ + indexer = batch["indexer"] + metadata = batch["metadata"] + lig_num_nodes = lig_x0.size(0) * lig_x0.size(1) + num_molid_per_sample = max(metadata["num_molid_per_sample"]) + # NOTE: here, we distinguish each ligand chain in a complex for harmonic chain sampling + lig_bid = indexer["gather_idx_i_molid"] + protein_sigma = ( + segment_mean( + torch.square(protein_ca_x0).flatten(0, 1), + indexer["gather_idx_a_structid"], + metadata["num_structid"], + ).mean(dim=-1) + ** 0.5 + ).repeat_interleave(num_molid_per_sample) + sde = DiffusionSDE(protein_sigma * sigma) + edges = torch.stack( + ( + indexer["gather_idx_ij_i"], + indexer["gather_idx_ij_j"], + ) + ) + edges = edges[:, edges[0] < edges[1]] # de-duplicate edges + ptr = torch.cumsum(torch.bincount(lig_bid), dim=0) + ptr = torch.cat((torch.tensor([0], device=lig_bid.device), ptr)) + try: + D, P = HarmonicSDE.diagonalize( + lig_num_nodes, + ptr, + edges=edges.T, + lamb=sde.lamb[lig_bid], + ) + except Exception as e: + log.error( + f"Failed to call HarmonicSDE.diagonalize() for ligand(s) {metadata['sample_ID_per_sample']} due to: {e}" + ) + raise e + noise = torch.randn((lig_num_nodes, 3), device=lig_x0.device) + prior = P @ (noise / torch.sqrt(D)[:, None]) + return prior.view(lig_x0.size()).contiguous() + + +def sample_complex_harmonic_prior( + x0: torch.Tensor, + latent_converter: LatentCoordinateConverter, + batch: MODEL_BATCH, + x0_sigma: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Sample protein-ligand complex noise from a harmonic prior distribution. 
+ From: https://github.com/bjing2016/alphaflow + + :param x0: ground-truth tensor + :param latent_converter: The latent coordinate converter + :param batch: A batch dictionary + :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor + :return: tuple of ground-truth and predicted tensors with additive Gaussian and harmonic prior noise, respectively + """ + ca_lat, cother_lat, lig_lat = x0.split( + [ + latent_converter._n_res_per_sample, + latent_converter._n_cother_per_sample, + latent_converter._n_ligha_per_sample, + ], + dim=1, + ) + harmonic_ca_lat, gaussian_cother_lat = sample_protein_harmonic_prior( + ca_lat, + cother_lat, + batch, + ) + harmonic_lig_lat = sample_ligand_harmonic_prior(lig_lat, harmonic_ca_lat, batch) + x1 = torch.cat( + [ + # NOTE: the following normalization steps assume that `self.latent_model == "default"` + harmonic_ca_lat / latent_converter.ca_scale, + gaussian_cother_lat / latent_converter.other_scale, + harmonic_lig_lat / latent_converter.other_scale, + ], + dim=1, + ) + gaussian_prior = torch.randn_like(x0) + return x0 + gaussian_prior * x0_sigma, x1 + + +def sample_esmfold_prior( + x0: torch.Tensor, x1: torch.Tensor, sigma: float, x0_sigma: float = 1e-4 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Sample noise from an ESMFold prior distribution. 
+ + :param x0: ground-truth tensor + :param x1: predicted tensor + :param sigma: standard deviation of the ESMFold prior's additive Gaussian noise + :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor + :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise + """ + prior_noise = torch.randn_like(x0) + return x0 + prior_noise * x0_sigma, x1 + prior_noise * sigma diff --git a/flowdock/models/components/transforms.py b/flowdock/models/components/transforms.py new file mode 100644 index 0000000..843f7d9 --- /dev/null +++ b/flowdock/models/components/transforms.py @@ -0,0 +1,241 @@ +# Adapted from: https://github.com/zrqiao/NeuralPLexer + +import rootutils +import torch + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils.model_utils import segment_mean + + +class LatentCoordinateConverter: + """Transform the batched feature dict to latent coordinate arrays.""" + + def __init__(self, config, prot_atom37_namemap, lig_namemap): + """Initialize the converter.""" + super().__init__() + self.config = config + self.prot_namemap = prot_atom37_namemap + self.lig_namemap = lig_namemap + self.cached_noise = None + self._last_pred_ca_trace = None + + @staticmethod + def nested_get(dic, keys): + """Get the value in the nested dictionary.""" + for key in keys: + dic = dic[key] + return dic + + @staticmethod + def nested_set(dic, keys, value): + """Set the value in the nested dictionary.""" + for key in keys[:-1]: + dic = dic.setdefault(key, {}) + dic[keys[-1]] = value + + def to_latent(self, batch): + """Convert the batched feature dict to latent coordinates.""" + return None + + def assign_to_batch(self, batch, x_int): + """Assign the latent coordinates to the batched feature dict.""" + return None + + +class DefaultPLCoordinateConverter(LatentCoordinateConverter): + """Minimal conversion, using internal coords for sidechains and global coords for others.""" + + def 
__init__(self, config, prot_atom37_namemap, lig_namemap): + """Initialize the converter.""" + super().__init__(config, prot_atom37_namemap, lig_namemap) + # Scale parameters in Angstrom + self.ca_scale = config.global_max_sigma + self.other_scale = config.internal_max_sigma + + def to_latent(self, batch: dict): + """Convert the batched feature dict to latent coordinates.""" + indexer = batch["indexer"] + metadata = batch["metadata"] + self._batch_size = metadata["num_structid"] + atom37_mask = batch["features"]["res_atom_mask"].bool() + self._cother_mask = atom37_mask.clone() + self._cother_mask[:, 1] = False + atom37_coords = self.nested_get(batch, self.prot_namemap[0]) + try: + apo_available = True + apo_atom37_coords = self.nested_get( + batch, self.prot_namemap[0][:-1] + ("apo_" + self.prot_namemap[0][-1],) + ) + except KeyError: + apo_available = False + apo_atom37_coords = torch.zeros_like(atom37_coords) + ca_atom_centroid_coords = segment_mean( + # NOTE: in contrast to NeuralPLexer, we center all coordinates at the origin using the Ca atom centroids + atom37_coords[:, 1], + indexer["gather_idx_a_structid"], + self._batch_size, + ) + if apo_available: + apo_ca_atom_centroid_coords = segment_mean( + apo_atom37_coords[:, 1], + indexer["gather_idx_a_structid"], + self._batch_size, + ) + else: + apo_ca_atom_centroid_coords = torch.zeros_like(ca_atom_centroid_coords) + ca_coords_glob = ( + (atom37_coords[:, 1] - ca_atom_centroid_coords[indexer["gather_idx_a_structid"]]) + .contiguous() + .view(self._batch_size, -1, 3) + ) + if apo_available: + apo_ca_coords_glob = ( + ( + apo_atom37_coords[:, 1] + - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"]] + ) + .contiguous() + .view(self._batch_size, -1, 3) + ) + else: + apo_ca_coords_glob = torch.zeros_like(ca_coords_glob) + cother_coords_int = ( + (atom37_coords - ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None])[ + self._cother_mask + ] + .contiguous() + .view(self._batch_size, -1, 3) + ) + 
if apo_available: + apo_cother_coords_int = ( + ( + apo_atom37_coords + - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None] + )[self._cother_mask] + .contiguous() + .view(self._batch_size, -1, 3) + ) + else: + apo_cother_coords_int = torch.zeros_like(cother_coords_int) + self._n_res_per_sample = ca_coords_glob.shape[1] + self._n_cother_per_sample = cother_coords_int.shape[1] + if batch["misc"]["protein_only"]: + self._n_ligha_per_sample = 0 + x_int = torch.cat( + [ + ca_coords_glob / self.ca_scale, + apo_ca_coords_glob / self.ca_scale, + cother_coords_int / self.other_scale, + apo_cother_coords_int / self.other_scale, + ], + dim=1, + ) + return x_int + lig_ha_coords = self.nested_get(batch, self.lig_namemap[0]) + lig_ha_coords_int = ( + lig_ha_coords - ca_atom_centroid_coords[indexer["gather_idx_i_structid"]] + ) + lig_ha_coords_int = lig_ha_coords_int.contiguous().view(self._batch_size, -1, 3) + ca_atom_centroid_coords = ca_atom_centroid_coords.contiguous().view( + self._batch_size, -1, 3 + ) + apo_ca_atom_centroid_coords = apo_ca_atom_centroid_coords.contiguous().view( + self._batch_size, -1, 3 + ) + x_int = torch.cat( + [ + ca_coords_glob / self.ca_scale, + apo_ca_coords_glob / self.ca_scale, + cother_coords_int / self.other_scale, + apo_cother_coords_int / self.other_scale, + ca_atom_centroid_coords / self.ca_scale, + apo_ca_atom_centroid_coords / self.ca_scale, + lig_ha_coords_int / self.other_scale, + ], + dim=1, + ) + # NOTE: since we use the Ca atom centroids for centralization, we have only one molid per sample + self._n_molid_per_sample = ca_atom_centroid_coords.shape[1] + self._n_ligha_per_sample = lig_ha_coords_int.shape[1] + return x_int + + def assign_to_batch(self, batch: dict, x_lat: torch.Tensor): + """Assign the latent coordinates to the batched feature dict.""" + indexer = batch["indexer"] + new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3) + apo_new_atom37_coords = 
x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3) + if batch["misc"]["protein_only"]: + ca_lat, apo_ca_lat, cother_lat, apo_cother_lat = torch.split( + x_lat, + [ + self._n_res_per_sample, + self._n_res_per_sample, + self._n_cother_per_sample, + self._n_cother_per_sample, + ], + dim=1, + ) + else: + ( + ca_lat, + apo_ca_lat, + cother_lat, + apo_cother_lat, + ca_cent_lat, + _, + lig_lat, + ) = torch.split( + x_lat, + [ + self._n_res_per_sample, + self._n_res_per_sample, + self._n_cother_per_sample, + self._n_cother_per_sample, + self._n_molid_per_sample, + self._n_molid_per_sample, + self._n_ligha_per_sample, + ], + dim=1, + ) + new_ca_glob = (ca_lat * self.ca_scale).contiguous().flatten(0, 1) + apo_new_ca_glob = (apo_ca_lat * self.ca_scale).contiguous().flatten(0, 1) + new_atom37_coords[self._cother_mask] = ( + (cother_lat * self.other_scale).contiguous().flatten(0, 1) + ) + apo_new_atom37_coords[self._cother_mask] = ( + (apo_cother_lat * self.other_scale).contiguous().flatten(0, 1) + ) + new_atom37_coords = new_atom37_coords + apo_new_atom37_coords = apo_new_atom37_coords + new_atom37_coords[~self._cother_mask] = 0 + apo_new_atom37_coords[~self._cother_mask] = 0 + new_atom37_coords[:, 1] = new_ca_glob + apo_new_atom37_coords[:, 1] = apo_new_ca_glob + self.nested_set(batch, self.prot_namemap[1], new_atom37_coords) + self.nested_set( + batch, + self.prot_namemap[1][:-1] + ("apo_" + self.prot_namemap[1][-1],), + apo_new_atom37_coords, + ) + if batch["misc"]["protein_only"]: + self.nested_set(batch, self.lig_namemap[1], None) + self.empty_cache() + return batch + new_ligha_coords_int = (lig_lat * self.other_scale).contiguous().flatten(0, 1) + new_ligha_coords_cent = (ca_cent_lat * self.ca_scale).contiguous().flatten(0, 1) + new_ligha_coords = ( + new_ligha_coords_int + new_ligha_coords_cent[indexer["gather_idx_i_structid"]] + ) + self.nested_set(batch, self.lig_namemap[1], new_ligha_coords) + self.empty_cache() + return batch + + def empty_cache(self): 
+ """Empty the cached variables.""" + self._batch_size = None + self._cother_mask = None + self._n_res_per_sample = None + self._n_cother_per_sample = None + self._n_ligha_per_sample = None + self._n_molid_per_sample = None diff --git a/flowdock/models/flowdock_fm_module.py b/flowdock/models/flowdock_fm_module.py new file mode 100644 index 0000000..a50fdd9 --- /dev/null +++ b/flowdock/models/flowdock_fm_module.py @@ -0,0 +1,943 @@ +import os + +import esm +import numpy as np +import rootutils +import torch +from beartype.typing import Any, Dict, Literal, Optional, Union +from lightning import LightningModule +from omegaconf import DictConfig +from torchmetrics.functional.regression import ( + mean_absolute_error, + mean_squared_error, + pearson_corrcoef, + spearman_corrcoef, +) + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.models.components.losses import ( + eval_auxiliary_estimation_losses, + eval_structure_prediction_losses, +) +from flowdock.utils import RankedLogger +from flowdock.utils.data_utils import pdb_filepath_to_protein, prepare_batch +from flowdock.utils.model_utils import extract_esm_embeddings +from flowdock.utils.sampling_utils import multi_pose_sampling +from flowdock.utils.visualization_utils import ( + construct_prot_lig_pairs, + write_prot_lig_pairs_to_pdb_file, +) + +MODEL_BATCH = Dict[str, Any] +MODEL_STAGE = Literal["train", "val", "test", "predict"] +LOSS_MODES_LIST = [ + "structure_prediction", + "auxiliary_estimation", + "auxiliary_estimation_without_structure_prediction", +] +LOSS_MODES = Literal[ + "structure_prediction", + "auxiliary_estimation", + "auxiliary_estimation_without_structure_prediction", +] +AUX_ESTIMATION_STAGES = ["train", "val", "test"] + +log = RankedLogger(__name__, rank_zero_only=True) + + +class FlowDockFMLitModule(LightningModule): + """A `LightningModule` for geometric flow matching (FM) with FlowDock. 
+ + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + cfg: DictConfig, + **kwargs: Dict[str, Any], + ): + """Initialize a `FlowDockFMLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. + :param compile: Whether to compile the model before training. + :param cfg: The model configuration. + :param kwargs: Additional keyword arguments. + """ + super().__init__() + + # the model along with its hyperparameters + self.net = net(cfg) + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["net"]) + + # for validating input arguments + if self.hparams.cfg.task.loss_mode not in LOSS_MODES_LIST: + raise ValueError( + f"Invalid loss mode: {self.hparams.cfg.task.loss_mode}. Must be one of {LOSS_MODES}." 
+ ) + + # for inspecting the model's outputs during validation and testing + ( + self.training_step_outputs, + self.validation_step_outputs, + self.test_step_outputs, + self.predict_step_outputs, + ) = ( + [], + [], + [], + [], + ) + + def forward( + self, + batch: MODEL_BATCH, + iter_id: Union[int, str] = 0, + observed_block_contacts: Optional[torch.Tensor] = None, + contact_prediction: bool = True, + infer_geometry_prior: bool = False, + score: bool = False, + affinity: bool = True, + use_template: bool = False, + **kwargs: Dict[str, Any], + ) -> MODEL_BATCH: + """Perform a forward pass through the model. + + :param batch: A batch dictionary. + :param iter_id: The current iteration ID. + :param observed_block_contacts: Observed block contacts. + :param contact_prediction: Whether to predict contacts. + :param infer_geometry_prior: Whether to predict using a geometry prior. + :param score: Whether to predict a denoised complex structure. + :param affinity: Whether to predict ligand binding affinity. + :param use_template: Whether to use a template protein structure. + :param kwargs: Additional keyword arguments. + :return: Batch dictionary with outputs. + """ + return self.net( + batch, + iter_id=iter_id, + observed_block_contacts=observed_block_contacts, + contact_prediction=contact_prediction, + infer_geometry_prior=infer_geometry_prior, + score=score, + affinity=affinity, + use_template=use_template, + training=self.training, + **kwargs, + ) + + def model_step( + self, + batch: MODEL_BATCH, + batch_idx: int, + stage: MODEL_STAGE, + loss_mode: Optional[LOSS_MODES] = None, + ) -> MODEL_BATCH: + """Perform a single model step on a batch of data. + + :param batch: A batch dictionary. + :param batch_idx: The index of the current batch. + :param stage: The current model stage (i.e., `train`, `val`, `test`, or `predict`). + :param loss_mode: The loss mode to use for training. + :return: Batch dictionary with losses. 
+ """ + prepare_batch(batch) + predicting_aux_outputs = ( + self.hparams.cfg.confidence.enabled or self.hparams.cfg.affinity.enabled + ) + is_aux_loss_stage = stage in AUX_ESTIMATION_STAGES + is_aux_batch = batch_idx % self.hparams.cfg.task.aux_batch_freq == 0 + struct_pred_loss_mode_requested = ( + loss_mode is not None and loss_mode == "structure_prediction" + ) + should_eval_aux_loss = ( + predicting_aux_outputs + and is_aux_loss_stage + and is_aux_batch + and not struct_pred_loss_mode_requested + and ( + not self.hparams.cfg.task.freeze_confidence + or ( + not self.hparams.cfg.task.freeze_affinity + and batch["features"]["affinity"].any().item() + ) + ) + ) + eval_aux_loss_mode_requested = ( + predicting_aux_outputs + and loss_mode is not None + and "auxiliary_estimation" in loss_mode + ) + if should_eval_aux_loss or eval_aux_loss_mode_requested: + return eval_auxiliary_estimation_losses( + self, batch, stage, loss_mode, training=self.training + ) + loss_fn = eval_structure_prediction_losses + return loss_fn(self, batch, batch_idx, self.device, stage, t_1=1.0) + + def on_train_start(self): + """Lightning hook that is called when training begins.""" + pass + + def training_step(self, batch: MODEL_BATCH, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch dictionary. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + if self.hparams.cfg.task.overfitting_example_name is not None and not all( + name == self.hparams.cfg.task.overfitting_example_name + for name in batch["metadata"]["sample_ID_per_sample"] + ): + return None + + try: + batch = self.model_step(batch, batch_idx, "train") + except Exception as e: + log.error( + f"Failed to perform training step for batch index {batch_idx} due to: {e}. Skipping example." 
+ ) + return None + + if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]: + training_outputs = { + "affinity_logits": batch["outputs"]["affinity_logits"], + "affinity": batch["features"]["affinity"], + } + self.training_step_outputs.append(training_outputs) + + # return loss or backpropagation will fail + return batch["outputs"]["loss"] + + def on_train_epoch_end(self): + """Lightning hook that is called when a training epoch ends.""" + if self.hparams.cfg.affinity.enabled and any( + "affinity_logits" in output for output in self.training_step_outputs + ): + affinity_logits = torch.cat( + [ + output["affinity_logits"] + for output in self.training_step_outputs + if "affinity_logits" in output + ] + ) + affinity = torch.cat( + [ + output["affinity"] + for output in self.training_step_outputs + if "affinity_logits" in output + ] + ) + affinity_logits = affinity_logits[~affinity.isnan()] + affinity = affinity[~affinity.isnan()] + if affinity.numel() > 1: + # NOTE: there must be at least two affinity batches to properly score the affinity predictions + aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity)) + aff_mae = mean_absolute_error(affinity_logits, affinity) + aff_pearson = pearson_corrcoef(affinity_logits, affinity) + aff_spearman = spearman_corrcoef(affinity_logits, affinity) + self.log( + "train_affinity/RMSE", + aff_rmse.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=False, + ) + self.log( + "train_affinity/MAE", + aff_mae.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=False, + ) + self.log( + "train_affinity/Pearson", + aff_pearson.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=False, + ) + self.log( + "train_affinity/Spearman", + aff_spearman.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=False, + ) + self.training_step_outputs.clear() # free memory + + def on_validation_start(self): + """Lightning hook that is called when validation 
begins.""" + # create a directory to store model outputs from each validation epoch + os.makedirs( + os.path.join(self.trainer.default_root_dir, "validation_epoch_outputs"), exist_ok=True + ) + + def validation_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0): + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch dictionary. + :param batch_idx: The index of the current batch. + :param dataloader_idx: The index of the current dataloader. + """ + if self.hparams.cfg.task.overfitting_example_name is not None and not all( + name == self.hparams.cfg.task.overfitting_example_name + for name in batch["metadata"]["sample_ID_per_sample"] + ): + return None + + try: + prepare_batch(batch) + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler="VDODE", + sampler_eta=1.0, + num_steps=10, + start_time=1.0, + exact_prior=False, + return_all_states=True, + eval_input_protein=True, + ) + all_frames = sampling_stats["all_frames"] + del sampling_stats["all_frames"] + for metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + self.log( + f"val_sampling/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler="VDODE", + sampler_eta=1.0, + num_steps=10, + start_time=1.0, + exact_prior=False, + use_template=False, + ) + for metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + self.log( + f"val_sampling_notemplate/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler="VDODE", + sampler_eta=1.0, + num_steps=10, + start_time=1.0, + return_summary_stats=True, + exact_prior=True, + ) + for 
metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + self.log( + f"val_sampling_trueprior/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + batch = self.model_step(batch, batch_idx, "val") + except Exception as e: + log.error( + f"Failed to perform validation step for batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}. Skipping example." + ) + return None + + # store model outputs for inspection + validation_outputs = {} + if self.hparams.cfg.task.visualize_generated_samples: + validation_outputs = { + "name": batch["metadata"]["sample_ID_per_sample"], + "batch_size": batch["metadata"]["num_structid"], + "aatype": batch["features"]["res_type"].long().cpu().numpy(), + "res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(), + "protein_coordinates_list": [ + frame["receptor_padded"].cpu().numpy() for frame in all_frames + ], + "ligand_coordinates_list": [ + frame["ligands"].cpu().numpy() for frame in all_frames + ], + "ligand_mol": batch["metadata"]["mol_per_sample"], + "protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"].cpu().numpy(), + "ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"].cpu().numpy(), + "gt_protein_coordinates": batch["features"]["res_atom_positions"].cpu().numpy(), + "gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(), + "dataloader_idx": dataloader_idx, + } + if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]: + validation_outputs.update( + { + "affinity_logits": batch["outputs"]["affinity_logits"], + "affinity": batch["features"]["affinity"], + "dataloader_idx": dataloader_idx, + } + ) + if validation_outputs: + self.validation_step_outputs.append(validation_outputs) + + def on_validation_epoch_end(self): + "Lightning hook that is called when a validation epoch ends." 
+ if self.hparams.cfg.task.visualize_generated_samples: + for i, outputs in enumerate(self.validation_step_outputs): + for batch_index in range(outputs["batch_size"]): + prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index) + write_prot_lig_pairs_to_pdb_file( + prot_lig_pairs, + os.path.join( + self.trainer.default_root_dir, + "validation_epoch_outputs", + f"{outputs['name'][batch_index]}_validation_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb", + ), + ) + if self.hparams.cfg.affinity.enabled and any( + "affinity_logits" in output for output in self.validation_step_outputs + ): + affinity_logits = torch.cat( + [ + output["affinity_logits"] + for output in self.validation_step_outputs + if "affinity_logits" in output + ] + ) + affinity = torch.cat( + [ + output["affinity"] + for output in self.validation_step_outputs + if "affinity_logits" in output + ] + ) + affinity_logits = affinity_logits[~affinity.isnan()] + affinity = affinity[~affinity.isnan()] + if affinity.numel() > 1: + # NOTE: there must be at least two affinity batches to properly score the affinity predictions + aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity)) + aff_mae = mean_absolute_error(affinity_logits, affinity) + aff_pearson = pearson_corrcoef(affinity_logits, affinity) + aff_spearman = spearman_corrcoef(affinity_logits, affinity) + self.log( + "val_affinity/RMSE", + aff_rmse.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "val_affinity/MAE", + aff_mae.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "val_affinity/Pearson", + aff_pearson.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "val_affinity/Spearman", + aff_spearman.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.validation_step_outputs.clear() # free memory + + 
def on_test_start(self): + """Lightning hook that is called when testing begins.""" + # create a directory to store model outputs from each test epoch + os.makedirs( + os.path.join(self.trainer.default_root_dir, "test_epoch_outputs"), exist_ok=True + ) + + def test_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0): + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch dictionary. + :param batch_idx: The index of the current batch. + :param dataloader_idx: The index of the current dataloader. + """ + if self.hparams.cfg.task.overfitting_example_name is not None and not all( + name == self.hparams.cfg.task.overfitting_example_name + for name in batch["metadata"]["sample_ID_per_sample"] + ): + return None + + try: + prepare_batch(batch) + if self.hparams.cfg.task.eval_structure_prediction: + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler=self.hparams.cfg.task.sampler, + sampler_eta=self.hparams.cfg.task.sampler_eta, + num_steps=self.hparams.cfg.task.num_steps, + start_time=self.hparams.cfg.task.start_time, + exact_prior=False, + return_all_states=True, + eval_input_protein=True, + ) + all_frames = sampling_stats["all_frames"] + del sampling_stats["all_frames"] + for metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + self.log( + f"test_sampling/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler=self.hparams.cfg.task.sampler, + sampler_eta=self.hparams.cfg.task.sampler_eta, + num_steps=self.hparams.cfg.task.num_steps, + start_time=self.hparams.cfg.task.start_time, + exact_prior=False, + use_template=False, + ) + for metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + 
self.log( + f"test_sampling_notemplate/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + sampling_stats = self.net.sample_pl_complex_structures( + batch, + sampler=self.hparams.cfg.task.sampler, + sampler_eta=self.hparams.cfg.task.sampler_eta, + num_steps=self.hparams.cfg.task.num_steps, + start_time=self.hparams.cfg.task.start_time, + return_summary_stats=True, + exact_prior=True, + ) + for metric_name in sampling_stats.keys(): + log_stat = sampling_stats[metric_name].mean().detach() + batch_size = sampling_stats[metric_name].shape[0] + self.log( + f"test_sampling_trueprior/{metric_name}", + log_stat, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + batch = self.model_step( + batch, batch_idx, "test", loss_mode=self.hparams.cfg.task.loss_mode + ) + except Exception as e: + log.error( + f"Failed to perform test step for {batch['metadata']['sample_ID_per_sample']} with batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}." 
+ ) + raise e + + # store model outputs for inspection + test_outputs = {} + if ( + self.hparams.cfg.task.visualize_generated_samples + and self.hparams.cfg.task.eval_structure_prediction + ): + test_outputs.update( + { + "name": batch["metadata"]["sample_ID_per_sample"], + "batch_size": batch["metadata"]["num_structid"], + "aatype": batch["features"]["res_type"].long().cpu().numpy(), + "res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(), + "protein_coordinates_list": [ + frame["receptor_padded"].cpu().numpy() for frame in all_frames + ], + "ligand_coordinates_list": [ + frame["ligands"].cpu().numpy() for frame in all_frames + ], + "ligand_mol": batch["metadata"]["mol_per_sample"], + "protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"] + .cpu() + .numpy(), + "ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"] + .cpu() + .numpy(), + "gt_protein_coordinates": batch["features"]["res_atom_positions"] + .cpu() + .numpy(), + "gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(), + "dataloader_idx": dataloader_idx, + } + ) + if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]: + test_outputs.update( + { + "affinity_logits": batch["outputs"]["affinity_logits"], + "affinity": batch["features"]["affinity"], + "dataloader_idx": dataloader_idx, + } + ) + if test_outputs: + self.test_step_outputs.append(test_outputs) + + def on_test_epoch_end(self): + """Lightning hook that is called when a test epoch ends.""" + if ( + self.hparams.cfg.task.visualize_generated_samples + and self.hparams.cfg.task.eval_structure_prediction + ): + for i, outputs in enumerate(self.test_step_outputs): + for batch_index in range(outputs["batch_size"]): + prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index) + write_prot_lig_pairs_to_pdb_file( + prot_lig_pairs, + os.path.join( + self.trainer.default_root_dir, + "test_epoch_outputs", + 
f"{outputs['name'][batch_index]}_test_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb", + ), + ) + if self.hparams.cfg.affinity.enabled and any( + "affinity_logits" in output for output in self.test_step_outputs + ): + affinity_logits = torch.cat( + [ + output["affinity_logits"] + for output in self.test_step_outputs + if "affinity_logits" in output + ] + ) + affinity = torch.cat( + [ + output["affinity"] + for output in self.test_step_outputs + if "affinity_logits" in output + ] + ) + affinity_logits = affinity_logits[~affinity.isnan()] + affinity = affinity[~affinity.isnan()] + if affinity.numel() > 1: + # NOTE: there must be at least two affinity batches to properly score the affinity predictions + aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity)) + aff_mae = mean_absolute_error(affinity_logits, affinity) + aff_pearson = pearson_corrcoef(affinity_logits, affinity) + aff_spearman = spearman_corrcoef(affinity_logits, affinity) + self.log( + "test_affinity/RMSE", + aff_rmse.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "test_affinity/MAE", + aff_mae.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "test_affinity/Pearson", + aff_pearson.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.log( + "test_affinity/Spearman", + aff_spearman.detach(), + on_epoch=True, + batch_size=len(affinity), + sync_dist=True, + ) + self.test_step_outputs.clear() # free memory + + def on_predict_start(self): + """Lightning hook that is called when testing begins.""" + # create a directory to store model outputs from each predict epoch + os.makedirs( + os.path.join(self.trainer.default_root_dir, "predict_epoch_outputs"), exist_ok=True + ) + + log.info("Loading pretrained ESM model...") + esm_model, self.esm_alphabet = esm.pretrained.load_model_and_alphabet_hub( + 
self.hparams.cfg.model.cfg.protein_encoder.esm_version + ) + self.esm_model = esm_model.eval().float() + self.esm_batch_converter = self.esm_alphabet.get_batch_converter() + self.esm_model.cpu() + + skip_loading_esmfold_weights = ( + # skip loading ESMFold weights if the template protein structure for a single complex input is provided + self.hparams.cfg.task.csv_path is None + and self.hparams.cfg.task.input_template is not None + and os.path.exists(self.hparams.cfg.task.input_template) + ) + if not skip_loading_esmfold_weights: + log.info("Loading pretrained ESMFold model...") + esmfold_model = esm.pretrained.esmfold_v1() + self.esmfold_model = esmfold_model.eval().float() + self.esmfold_model.set_chunk_size(self.hparams.cfg.esmfold_chunk_size) + self.esmfold_model.cpu() + + def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0): + """Perform a single predict step on a batch of data from the predict set. + + :param batch: A batch dictionary. + :param batch_idx: The index of the current batch. + :param dataloader_idx: The index of the current dataloader. 
+ """ + rec_path = batch["rec_path"][0] + ligand_paths = list( + path[0] for path in batch["lig_paths"] + ) # unpack a list of (batched) single-element string tuples + sample_id = batch["sample_id"][0] if "sample_id" in batch else "sample" + input_template = batch["input_template"][0] if "input_template" in batch else None + + out_path = ( + os.path.join(self.hparams.cfg.out_path, sample_id) + if "sample_id" in batch + else self.hparams.cfg.out_path + ) + + # generate ESM embeddings for the protein + protein = pdb_filepath_to_protein(rec_path) + sequences = [ + "".join(np.array(list(chain_seq))[chain_mask]) + for (_, chain_seq, chain_mask) in protein.letter_sequences + ] + esm_embeddings = extract_esm_embeddings( + self.esm_model, + self.esm_alphabet, + self.esm_batch_converter, + sequences, + device="cpu", + esm_repr_layer=self.hparams.cfg.model.cfg.protein_encoder.esm_repr_layer, + ) + sequences_to_embeddings = { + f"{seq}:{i}": esm_embeddings[i].cpu().numpy() for i, seq in enumerate(sequences) + } + + # generate initial ESMFold-predicted structure for the protein if a template is not provided + apo_rec_path = None + if input_template and os.path.exists(input_template): + apo_protein = pdb_filepath_to_protein(input_template) + apo_chain_seq_masked = "".join( + "".join(np.array(list(chain_seq))[chain_mask]) + for (_, chain_seq, chain_mask) in apo_protein.letter_sequences + ) + chain_seq_masked = "".join( + "".join(np.array(list(chain_seq))[chain_mask]) + for (_, chain_seq, chain_mask) in protein.letter_sequences + ) + if apo_chain_seq_masked != chain_seq_masked: + log.error( + f"Provided template protein structure {input_template} does not match the input protein sequence within {rec_path}. Skipping example {sample_id} at batch index {batch_idx} of dataloader {dataloader_idx}." 
+ ) + return None + log.info(f"Starting from provided template protein structure: {input_template}") + apo_rec_path = input_template + if apo_rec_path is None and self.hparams.cfg.prior_type == "esmfold": + esmfold_sequence = ":".join(sequences) + apo_rec_path = rec_path.replace(".pdb", "_apo.pdb") + with torch.no_grad(): + esmfold_pdb_output = self.esmfold_model.infer_pdb(esmfold_sequence) + with open(apo_rec_path, "w") as f: + f.write(esmfold_pdb_output) + + _, _, _, _, _, all_frames, batch_all, b_factors, plddt_rankings = multi_pose_sampling( + rec_path, + ligand_paths, + self.hparams.cfg, + self, + out_path, + separate_pdb=self.hparams.cfg.separate_pdb, + apo_receptor_path=apo_rec_path, + sample_id=sample_id, + protein=protein, + sequences_to_embeddings=sequences_to_embeddings, + return_all_states=self.hparams.cfg.task.visualize_generated_samples, + auxiliary_estimation_only=self.hparams.cfg.task.auxiliary_estimation_only, + ) + # store model outputs for inspection + if self.hparams.cfg.task.visualize_generated_samples: + predict_outputs = { + "name": batch_all["metadata"]["sample_ID_per_sample"], + "batch_size": batch_all["metadata"]["num_structid"], + "aatype": batch_all["features"]["res_type"].long().cpu().numpy(), + "res_atom_mask": batch_all["features"]["res_atom_mask"].cpu().numpy(), + "protein_coordinates_list": [ + frame["receptor_padded"].cpu().numpy() for frame in all_frames + ], + "ligand_coordinates_list": [ + frame["ligands"].cpu().numpy() for frame in all_frames + ], + "ligand_mol": batch_all["metadata"]["mol_per_sample"], + "protein_batch_indexer": batch_all["indexer"]["gather_idx_a_structid"] + .cpu() + .numpy(), + "ligand_batch_indexer": batch_all["indexer"]["gather_idx_i_structid"] + .cpu() + .numpy(), + "b_factors": b_factors, + "plddt_rankings": plddt_rankings, + } + self.predict_step_outputs.append(predict_outputs) + + def on_predict_epoch_end(self): + """Lightning hook that is called when a predict epoch ends.""" + if 
self.hparams.cfg.task.visualize_generated_samples: + for i, outputs in enumerate(self.predict_step_outputs): + for batch_index in range(outputs["batch_size"]): + prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index) + ranking = ( + outputs["plddt_rankings"][batch_index] + if "plddt_rankings" in outputs + else None + ) + write_prot_lig_pairs_to_pdb_file( + prot_lig_pairs, + os.path.join( + self.hparams.cfg.out_path, + outputs["name"][batch_index], + "predict_epoch_outputs", + f"{outputs['name'][batch_index]}{f'_rank{ranking + 1}' if ranking is not None else ''}_predict_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}.pdb", + ), + ) + self.predict_step_outputs.clear() # free memory + + def on_after_backward(self): + """Skip updates in case of unstable gradients. + + Reference: https://github.com/Lightning-AI/lightning/issues/4956 + """ + valid_gradients = True + for _, param in self.named_parameters(): + if param.grad is not None: + valid_gradients = not ( + torch.isnan(param.grad).any() or torch.isinf(param.grad).any() + ) + if not valid_gradients: + break + if not valid_gradients: + log.warning( + "Detected `inf` or `nan` values in gradients. Not updating model parameters." + ) + self.zero_grad() + + def optimizer_step( + self, + epoch, + batch_idx, + optimizer, + optimizer_closure, + ): + """Override the optimizer step to dynamically update the learning rate. + + :param epoch: The current epoch. + :param batch_idx: The index of the current batch. + :param optimizer: The optimizer to use for training. + :param optimizer_closure: The optimizer closure. 
+ """ + # update params + optimizer = optimizer.optimizer + optimizer.step(closure=optimizer_closure) + + # warm up learning rate + if self.trainer.global_step < 1000: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / 1000.0) + for pg in optimizer.param_groups: + # NOTE: `self.hparams.optimizer.keywords["lr"]` refers to the optimizer's initial learning rate + pg["lr"] = lr_scale * self.hparams.optimizer.keywords["lr"] + + def setup(self, stage: str): + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + if self.hparams.compile and stage == "fit": + self.net = torch.compile(self.net) + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 
+ """ + try: + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + except TypeError: + # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params` + optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +if __name__ == "__main__": + _ = FlowDockFMLitModule(None, None, None, None) diff --git a/flowdock/sample.py b/flowdock/sample.py new file mode 100644 index 0000000..a8f3705 --- /dev/null +++ b/flowdock/sample.py @@ -0,0 +1,284 @@ +import os + +import hydra +import lightning as L +import lovely_tensors as lt +import pandas as pd +import rootutils +import torch +from beartype.typing import Any, Dict, List, Tuple +from lightning import LightningModule, Trainer +from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies.strategy import Strategy +from omegaconf import DictConfig, open_dict +from torch.utils.data import DataLoader + +lt.monkey_patch() + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. 
`from flowdock import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable +from flowdock.utils import ( + RankedLogger, + extras, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) +from flowdock.utils.data_utils import ( + create_full_pdb_with_zero_coordinates, + create_temp_ligand_frag_files, +) + +log = RankedLogger(__name__, rank_zero_only=True) + +AVAILABLE_SAMPLING_TASKS = ["batched_structure_sampling"] + + +class SamplingDataset(torch.utils.data.Dataset): + """Dataset for sampling.""" + + def __init__(self, cfg: DictConfig): + """Initializes the SamplingDataset.""" + if cfg.sampling_task == "batched_structure_sampling": + if cfg.csv_path is not None: + # handle variable CSV inputs + df_rows = [] + self.df = pd.read_csv(cfg.csv_path) + for _, row in self.df.iterrows(): + sample_id = row.id + input_receptor = row.input_receptor + input_ligand = row.input_ligand + input_template = row.input_template + assert input_receptor is not None, "Receptor path is required for sampling." + if input_ligand is not None: + if input_ligand.endswith(".sdf"): + ligand_paths = create_temp_ligand_frag_files(input_ligand) + else: + ligand_paths = list(input_ligand.split("|")) + else: + ligand_paths = None # handle `null` ligand input + if not input_receptor.endswith(".pdb"): + log.warning( + "Assuming the provided receptor input is a protein sequence. 
Creating a dummy PDB file." + ) + create_full_pdb_with_zero_coordinates( + input_receptor, os.path.join(cfg.out_path, f"input_{sample_id}.pdb") + ) + input_receptor = os.path.join(cfg.out_path, f"input_{sample_id}.pdb") + df_row = { + "sample_id": sample_id, + "rec_path": input_receptor, + "lig_paths": ligand_paths, + } + if input_template is not None: + df_row["input_template"] = input_template + df_rows.append(df_row) + self.df = pd.DataFrame(df_rows) + else: + sample_id = cfg.sample_id + input_receptor = cfg.input_receptor + input_ligand = cfg.input_ligand + if input_ligand is not None: + if input_ligand.endswith(".sdf"): + ligand_paths = create_temp_ligand_frag_files(input_ligand) + else: + ligand_paths = list(input_ligand.split("|")) + else: + ligand_paths = None # handle `null` ligand input + if not input_receptor.endswith(".pdb"): + log.warning( + "Assuming the provided receptor input is a protein sequence. Creating a dummy PDB file." + ) + create_full_pdb_with_zero_coordinates( + input_receptor, os.path.join(cfg.out_path, "input.pdb") + ) + input_receptor = os.path.join(cfg.out_path, "input.pdb") + self.df = pd.DataFrame( + [ + { + "sample_id": sample_id, + "rec_path": input_receptor, + "lig_paths": ligand_paths, + } + ] + ) + if cfg.input_template is not None: + self.df["input_template"] = cfg.input_template + else: + raise NotImplementedError(f"Sampling task {cfg.sampling_task} is not implemented.") + + def __len__(self): + """Returns the length of the dataset.""" + return len(self.df) + + def __getitem__(self, idx: int) -> Tuple[str, str]: + """Returns the input receptor and input ligand.""" + return self.df.iloc[idx].to_dict() + + +@task_wrapper +def sample(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Samples using given checkpoint on a datamodule predictset. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. 
+ + :param cfg: DictConfig configuration composed by Hydra. + :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. + """ + assert cfg.ckpt_path, "Please provide a checkpoint path with which to sample!" + assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!" + assert ( + cfg.sampling_task in AVAILABLE_SAMPLING_TASKS + ), f"Sampling task {cfg.sampling_task} is not one of the following available tasks: {AVAILABLE_SAMPLING_TASKS}." + assert (cfg.input_receptor is not None and cfg.input_ligand is not None) or ( + cfg.csv_path is not None and os.path.exists(cfg.csv_path) + ), "Please provide either an input receptor and ligand or a CSV file with receptor and ligand sequences/filepaths." + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info( + f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}." + ) + torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision) + + # Establish model input arguments + with open_dict(cfg): + # NOTE: Structure trajectories will not be visualized when performing auxiliary estimation only + cfg.model.cfg.prior_type = cfg.prior_type + cfg.model.cfg.task.detect_covalent = cfg.detect_covalent + cfg.model.cfg.task.use_template = cfg.use_template + cfg.model.cfg.task.csv_path = cfg.csv_path + cfg.model.cfg.task.input_receptor = cfg.input_receptor + cfg.model.cfg.task.input_ligand = cfg.input_ligand + cfg.model.cfg.task.input_template = cfg.input_template + cfg.model.cfg.task.visualize_generated_samples = ( + cfg.visualize_sample_trajectories and not cfg.auxiliary_estimation_only + ) + cfg.model.cfg.task.auxiliary_estimation_only = cfg.auxiliary_estimation_only + if cfg.latent_model is not None: + with open_dict(cfg): + cfg.model.cfg.latent_model = cfg.latent_model + with open_dict(cfg): + if cfg.start_time == "auto": + 
cfg.start_time = 1.0 + else: + cfg.start_time = float(cfg.start_time) + + log.info("Converting sampling inputs into a ") + dataloaders: List[DataLoader] = [ + DataLoader( + SamplingDataset(cfg), + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=False, + ) + ] + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + model.hparams.cfg.update(cfg) # update model config with the sampling config + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + plugins = None + if "_target_" in cfg.environment: + log.info(f"Instantiating environment <{cfg.environment._target_}>") + plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment) + + strategy = getattr(cfg.trainer, "strategy", None) + if "_target_" in cfg.strategy: + log.info(f"Instantiating strategy <{cfg.strategy._target_}>") + strategy: Strategy = hydra.utils.instantiate(cfg.strategy) + if ( + "mixed_precision" in strategy.__dict__ + and getattr(strategy, "mixed_precision", None) is not None + ): + strategy.mixed_precision.param_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype) + if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None + else None + ) + strategy.mixed_precision.reduce_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype) + if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None + else None + ) + strategy.mixed_precision.buffer_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype) + if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None + else None + ) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = ( + hydra.utils.instantiate( + cfg.trainer, + logger=logger, + plugins=plugins, + strategy=strategy, + ) + if strategy is not None + else hydra.utils.instantiate( + cfg.trainer, + 
logger=logger, + plugins=plugins, + ) + ) + + object_dict = { + "cfg": cfg, + "model": model, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) + + metric_dict = trainer.callback_metrics + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="sample.yaml") +def main(cfg: DictConfig) -> None: + """Main entry point for sampling. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + sample(cfg) + + +if __name__ == "__main__": + register_custom_omegaconf_resolvers() + main() diff --git a/flowdock/train.py b/flowdock/train.py new file mode 100644 index 0000000..a790073 --- /dev/null +++ b/flowdock/train.py @@ -0,0 +1,196 @@ +import os + +import hydra +import lightning as L +import lovely_tensors as lt +import rootutils +import torch +from beartype.typing import Any, Dict, List, Optional, Tuple +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies.strategy import Strategy +from omegaconf import DictConfig + +lt.monkey_patch() + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. 
`from flowdock import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable +from flowdock.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. + """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info( + f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}." 
+ ) + torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="fit") + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + plugins = None + if "_target_" in cfg.environment: + log.info(f"Instantiating environment <{cfg.environment._target_}>") + plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment) + + strategy = getattr(cfg.trainer, "strategy", None) + if "_target_" in cfg.strategy: + log.info(f"Instantiating strategy <{cfg.strategy._target_}>") + strategy: Strategy = hydra.utils.instantiate(cfg.strategy) + if ( + "mixed_precision" in strategy.__dict__ + and getattr(strategy, "mixed_precision", None) is not None + ): + strategy.mixed_precision.param_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype) + if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None + else None + ) + strategy.mixed_precision.reduce_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype) + if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None + else None + ) + strategy.mixed_precision.buffer_dtype = ( + resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype) + if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None + else None + ) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = ( + hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + plugins=plugins, + strategy=strategy, + ) + if strategy is not None + else 
hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + plugins=plugins, + ) + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + ckpt_path = None + if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")): + ckpt_path = cfg.get("ckpt_path") + elif cfg.get("ckpt_path"): + log.warning( + "`ckpt_path` was given, but the path does not exist. Training with new model weights." + ) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + register_custom_omegaconf_resolvers() + main() diff --git a/flowdock/utils/__init__.py b/flowdock/utils/__init__.py new file mode 100644 index 0000000..9886aeb --- /dev/null +++ b/flowdock/utils/__init__.py @@ -0,0 +1,5 @@ +from flowdock.utils.instantiators import instantiate_callbacks, instantiate_loggers +from flowdock.utils.logging_utils import log_hyperparameters +from flowdock.utils.pylogger import RankedLogger +from flowdock.utils.rich_utils import enforce_tags, print_config_tree +from flowdock.utils.utils import extras, get_metric_value, task_wrapper diff --git a/flowdock/utils/data_utils.py b/flowdock/utils/data_utils.py new file mode 100644 index 0000000..7229f41 --- /dev/null +++ b/flowdock/utils/data_utils.py @@ -0,0 +1,1846 @@ +import copy +import dataclasses +import io +import os +import pickle # nosec +import random +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import rootutils +import torch +from beartype import beartype +from beartype.typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from Bio.PDB import PDBParser, Polypeptide +from Bio.PDB.Atom import Atom +from Bio.PDB.Chain import Chain +from Bio.PDB.Model import Model +from Bio.PDB.PDBIO import PDBIO +from Bio.PDB.Residue import Residue +from Bio.PDB.Structure import Structure +from openfold.np.protein import Protein as OFProtein +from openfold.np.protein import to_pdb as of_to_pdb +from openfold.utils import tensor_utils +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Geometry import Point3D + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from 
flowdock.data.components import residue_constants +from flowdock.data.components.mol_features import ( + attach_pair_idx_and_encodings, + collate_numpy_samples, + process_mol_file, +) +from flowdock.data.components.process_mols import read_molecule +from flowdock.data.components.residue_constants import restype_1to3 as af_restype_1to3 +from flowdock.data.components.residue_constants import restypes as af_restypes +from flowdock.models.components.transforms import LatentCoordinateConverter +from flowdock.utils import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + +MODEL_BATCH = Dict[str, Any] + +MOAD_UNIT_CONVERSION_DICT = { + np.nan: np.nan, # NaN to NaN + "uM": 1e-6, # micromolar to M + "nM": 1e-9, # nanomolar to M + "mM": 1e-3, # millimolar to M + "pM": 1e-12, # picomolar to M + "M": 1, # Molar to M + "M^-1": 1, # Reciprocal Molar to M + "fM": 1e-15, # femtomolar to M +} + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. 
+ +# From: https://github.com/uw-ipd/RoseTTAFold2 +NUM_TO_AA = [ + "ALA", + "ARG", + "ASN", + "ASP", + "CYS", + "GLN", + "GLU", + "GLY", + "HIS", + "ILE", + "LEU", + "LYS", + "MET", + "PHE", + "PRO", + "SER", + "THR", + "TRP", + "TYR", + "VAL", + "UNK", + "MAS", +] +AA_TO_NUM = {x: i for i, x in enumerate(NUM_TO_AA)} +AA_TO_LONG = [ + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "3HB ", + None, + None, + None, + None, + None, + None, + None, + None, + ), # ala + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD ", + " NE ", + " CZ ", + " NH1", + " NH2", + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + "1HD ", + "2HD ", + " HE ", + "1HH1", + "2HH1", + "1HH2", + "2HH2", + ), # arg + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " OD1", + " ND2", + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HD2", + "2HD2", + None, + None, + None, + None, + None, + None, + None, + ), # asn + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " OD1", + " OD2", + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + ), # asp + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " SG ", + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HG ", + None, + None, + None, + None, + None, + None, + None, + None, + ), # cys + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD ", + " OE1", + " NE2", + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + "1HE2", + "2HE2", + None, + None, + None, + None, + None, + ), # gln + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD ", + " OE1", + " OE2", + None, + None, + None, + None, + 
None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + None, + None, + None, + None, + None, + None, + None, + ), # glu + ( + " N ", + " CA ", + " C ", + " O ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + "1HA ", + "2HA ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ), # gly + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " ND1", + " CD2", + " CE1", + " NE2", + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HD2", + " HE1", + " HE2", + None, + None, + None, + None, + None, + None, + ), # his + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG1", + " CG2", + " CD1", + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + " HB ", + "1HG2", + "2HG2", + "3HG2", + "1HG1", + "2HG1", + "1HD1", + "2HD1", + "3HD1", + None, + None, + ), # ile + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD1", + " CD2", + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HG ", + "1HD1", + "2HD1", + "3HD1", + "1HD2", + "2HD2", + "3HD2", + None, + None, + ), # leu + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD ", + " CE ", + " NZ ", + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + "1HD ", + "2HD ", + "1HE ", + "2HE ", + "1HZ ", + "2HZ ", + "3HZ ", + ), # lys + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " SD ", + " CE ", + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + "1HE ", + "2HE ", + "3HE ", + None, + None, + None, + None, + ), # met + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD1", + " CD2", + " CE1", + " CE2", + " CZ ", + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HD1", + " HD2", + " HE1", + " HE2", + " HZ ", + None, + None, + None, + None, + ), # phe + ( + " N ", + " CA ", + " C ", + " 
O ", + " CB ", + " CG ", + " CD ", + None, + None, + None, + None, + None, + None, + None, + " HA ", + "1HB ", + "2HB ", + "1HG ", + "2HG ", + "1HD ", + "2HD ", + None, + None, + None, + None, + None, + None, + ), # pro + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " OG ", + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + " HG ", + " HA ", + "1HB ", + "2HB ", + None, + None, + None, + None, + None, + None, + None, + None, + ), # ser + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " OG1", + " CG2", + None, + None, + None, + None, + None, + None, + None, + " H ", + " HG1", + " HA ", + " HB ", + "1HG2", + "2HG2", + "3HG2", + None, + None, + None, + None, + None, + None, + ), # thr + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD1", + " CD2", + " NE1", + " CE2", + " CE3", + " CZ2", + " CZ3", + " CH2", + " H ", + " HA ", + "1HB ", + "2HB ", + " HD1", + " HE1", + " HZ2", + " HH2", + " HZ3", + " HE3", + None, + None, + None, + ), # trp + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " CD1", + " CD2", + " CE1", + " CE2", + " CZ ", + " OH ", + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HD1", + " HE1", + " HE2", + " HD2", + " HH ", + None, + None, + None, + None, + ), # tyr + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG1", + " CG2", + None, + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + " HB ", + "1HG1", + "2HG1", + "3HG1", + "1HG2", + "2HG2", + "3HG2", + None, + None, + None, + None, + ), # val + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "3HB ", + None, + None, + None, + None, + None, + None, + None, + None, + ), # unk + ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + None, + None, + None, + None, + None, + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + "3HB ", + None, + None, + None, + None, + None, + None, + 
@dataclasses.dataclass()
class FDProtein:
    """Protein structure representation.

    Fields follow the AlphaFold/OpenFold ``atom37`` convention: each residue is
    described by up to 37 canonical atom slots, with a per-slot mask marking
    which atoms were actually observed in the structure.
    """

    # Per-chain sequence data as 3-tuples of (chain id, gapless one-letter
    # sequence, boolean mask of residues observed in the structure). Gap
    # positions hold "" in the sequence and False in the mask so indices stay
    # aligned (presumably so sequences can be fed to protein language models --
    # TODO confirm against `pdb_filepath_to_protein`'s callers).
    letter_sequences: List[Tuple[str, str, np.ndarray]]

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types; per the atom37 ordering used elsewhere in
    # this file, the first slots are N, CA, C, CB, O.
    atom_positions: np.ndarray  # [num_res, atom_type_num, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X' (unknown).
    aatype: np.ndarray  # [num_res]

    # Added
    # Integer element type for each atom slot (0 where the slot is empty).
    atomtypes: np.ndarray  # [num_res, element_type_num]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, atom_type_num]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain in the protein that this residue
    # belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, atom_type_num]

    def __post_init__(self):
        """Reject structures with more chains than the PDB format can label."""
        if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} chains "
                "because these cannot be written to PDB format."
            )
@beartype
def pdb_filepath_to_protein(
    pdb_filepath: str,
    model_id: int = 0,
    atom_occupancy_min_threshold: float = 0.5,  # FIX: was annotated `int` with a 0.5 default
    filter_out_hetero_residues: bool = True,
    allow_insertion_code: bool = True,
    accept_only_valid_backbone_residues: bool = True,
    chain_id: Optional[Union[List[str], str]] = None,
    bounding_box: Optional[np.ndarray] = None,
    res_start: Optional[int] = None,
    res_end: Optional[int] = None,
) -> FDProtein:
    """Takes a PDB filepath and constructs a FDProtein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored. All water residues will be ignored.
    All hetero residues will be ignored if `filter_out_hetero_residues` is `True`.
    All residues without valid positions for their N, Ca, C, and O atoms will be ignored.

    Adapted from: https://github.com/aqlaboratory/openfold and https://github.com/zrqiao/NeuralPLexer

    :param pdb_filepath: The filepath to the PDB file to parse.
    :param model_id: The model number to parse.
    :param atom_occupancy_min_threshold: The minimum occupancy threshold for atoms.
    :param filter_out_hetero_residues: If True, then hetero residues will be ignored.
    :param allow_insertion_code: If True, residues with insertion codes are parsed.
    :param accept_only_valid_backbone_residues: If True, only residues with valid N, Ca, C, and O atoms are parsed.
    :param chain_id: If `chain_id` is specified (e.g., `[A, B]`), then only those chains
        are parsed. Otherwise, all chains are parsed.
    :param bounding_box: If provided, only chains with backbone intersecting with
        the box are parsed.
    :param res_start: If provided, only residues with index >= res_start are parsed.
    :param res_end: If provided, only residues with index <= res_end are parsed.
    :return: A new `FDProtein` parsed from the PDB contents.
    :raise ValueError: If `model_id` is out of range, residue numbering is not
        monotonically increasing within a chain, a disallowed insertion code is
        found, or a chain's residue coverage is below 75%.
    """
    assert pdb_filepath.endswith(".pdb"), f"Invalid file extension: {pdb_filepath}"
    assert os.path.exists(pdb_filepath), f"File not found: {pdb_filepath}"

    with open(pdb_filepath) as pdb_fh:
        pdb_str = pdb_fh.read()

    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if not (0 <= model_id < len(models)):
        raise ValueError(f"Model ID {model_id} is out of range")
    model = models[model_id]
    if isinstance(chain_id, str):
        chain_id = [chain_id]

    seq = []
    atom_positions = []
    aatype = []
    atomtypes = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.get_id() not in chain_id:
            log.info(f"In `pdb_filepath_to_protein()`, skipping chain {chain.get_id()}")
            continue
        if bounding_box is not None:
            # Keep the chain only if at least one Ca atom falls inside the box.
            ca_pos = np.array([res["CA"].get_coord() for res in chain if res.has_id("CA")])
            ca_in_box = (ca_pos > bounding_box[0]) & (ca_pos < bounding_box[1])
            if not np.any(np.all(ca_in_box, axis=1), axis=0):
                log.info(
                    f"In `pdb_filepath_to_protein()`, skipping chain {chain.get_id()} as it is not in the bounding box"
                )
                continue
        for res_idx, res in enumerate(chain):
            if res_start is not None and res_idx < res_start:
                continue
            if res_end is not None and res_idx > res_end:
                continue
            if res.get_resname() == "HOH" or (
                filter_out_hetero_residues and len(res.get_id()[0]) > 1
            ):
                log.info(
                    f"In `pdb_filepath_to_protein()`, skipping residue {res.get_id()} as it is a water residue or a hetero residue."
                )
                continue
            # strict bounding
            if bounding_box is not None:
                if not res.has_id("CA"):
                    continue
                ca_pos = res["CA"].get_coord()
                ca_in_box = (ca_pos > bounding_box[0]) & (ca_pos < bounding_box[1])
                if not np.all(ca_in_box):
                    continue
            if res.id[2] != " ":
                if allow_insertion_code:
                    log.warning(
                        f"PDB contains an insertion code at chain {chain.id} and residue "
                        f"index {res.id[1]} and `allow_insertion_code` is set to True. "
                        "Please ensure the residue indices are consecutive before performing downstream analysis."
                    )
                else:
                    raise ValueError(
                        f"PDB contains an insertion code at chain {chain.id} and residue "
                        f"index {res.id[1]}. Such samples are not supported by default."
                    )
            # NOTE: like the AlphaFold parser, we parse all non-standard residues
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            eletypes = np.zeros((residue_constants.atom_type_num,))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                # (potentially) remove atoms that are too flexible
                if atom.occupancy and atom.occupancy < atom_occupancy_min_threshold:
                    # NOTE: this is not a standard AlphaFold filter;
                    # it was originally used in the NeuralPLexer parser
                    continue
                eletypes[residue_constants.atom_order[atom.name]] = residue_constants.element_id[
                    atom.element
                ]
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
            if accept_only_valid_backbone_residues and (np.sum(mask[:3]) < 3 or mask[4] < 1):
                # as requested, skip if the backbone atoms are not resolved (NOTE: atom37 ordering
                # is N, Ca, C, CB, O, ... -> we only check N, Ca, C, and O)
                log.warning(
                    f"In `pdb_filepath_to_protein()`, skipping residue {res.id[1]} in chain {chain.id} as not enough backbone atoms are reported."
                )
                continue
            seq.append(res_shortname)
            aatype.append(restype_idx)
            atomtypes.append(eletypes)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # NOTE: chain IDs are usually characters, so we will map these to integers.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    # evaluate the gapless protein sequence for each chain
    seqs = []
    last_chain_id = -1
    last_chain_seq = None
    last_chain_mask = None
    last_res_id = -999
    for site_idx in range(len(seq)):
        if chain_ids[site_idx] != last_chain_id:
            # flush the finished chain before starting the next one
            if last_chain_seq is not None:
                seqs.append((last_chain_id, "".join(last_chain_seq), np.array(last_chain_mask)))
            last_chain_id = chain_ids[site_idx]
            last_chain_seq = []
            last_chain_mask = []
            last_res_id = -999  # sentinel: no residue seen yet in this chain
        if residue_index[site_idx] <= last_res_id:
            raise ValueError(
                f"PDB residue index is not monotonous at chain {chain_ids[site_idx]} and residue "
                f"index {residue_index[site_idx]}. The sample is discarded."
            )
        elif last_res_id == -999:
            gap_size = 0
        else:
            gap_size = residue_index[site_idx] - last_res_id - 1
        # pad unresolved residues so the sequence stays index-aligned
        for _ in range(gap_size):
            last_chain_seq.append("")
            last_chain_mask.append(False)
        last_chain_seq.append(seq[site_idx])
        last_chain_mask.append(True)
        # BUGFIX: the original never updated `last_res_id`, so `gap_size` was
        # always 0 and the monotonicity check and gap padding were dead code.
        last_res_id = residue_index[site_idx]
    if last_chain_seq is not None:  # guard against structures with zero parsed residues
        seqs.append((last_chain_id, "".join(last_chain_seq), np.array(last_chain_mask)))

    for chain_seq in seqs:
        if np.mean(chain_seq[2]) < 0.75:
            # BUGFIX: the original message interpolated the stale loop variable
            # `chain.id` and was missing a space before "is below".
            raise ValueError(
                f"The PDB structure residue coverage for chain {chain_seq[0]} "
                f"is below 75%. The sample is discarded."
            )

    return FDProtein(
        letter_sequences=seqs,
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        atomtypes=np.array(atomtypes),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )
def process_protein(
    af_protein: "FDProtein",
    bounding_box: Optional[np.ndarray] = None,
    no_indexer: bool = True,
    sample_name: str = "",
    sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
    plddt: Optional[Iterable[float]] = None,
    chain_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Process protein data into FlowDock metadata/indexer/features dictionaries.

    :param af_protein: FDProtein object.
    :param bounding_box: Bounding box to filter atoms. Currently unsupported; any
        non-``None`` value raises ``NotImplementedError``.
    :param no_indexer: If True, the residue-graph indexer is not added to the output.
    :param sample_name: Name of the sample, prefixed onto each chain ID in the
        returned sequence data.
    :param sequences_to_embeddings: Optional dictionary mapping sequences to embeddings.
    :param plddt: Optional pLDDT values (percent scale; rescaled to [0, 1]).
    :param chain_id: Optional chain ID for parsing of LM embeddings.
    :return: Processed protein data.
    :raise ValueError: If a chain sequence is missing from `sequences_to_embeddings`.
    """
    lm_embeddings = None if sequences_to_embeddings is None else []
    if sequences_to_embeddings is not None:
        if chain_id is not None and len(af_protein.letter_sequences) == 1:
            # Single-chain case with an explicit chain ID: look up by "<seq>:<chain_id>".
            chain_seq = af_protein.letter_sequences[0][1]
            chain_mask = af_protein.letter_sequences[0][2]
            chain_seq_masked = "".join(np.array(list(chain_seq))[chain_mask])
            lm_embeddings.append(sequences_to_embeddings[chain_seq_masked + f":{chain_id}"])
        else:
            for i, (_, chain_seq, chain_mask) in enumerate(af_protein.letter_sequences):
                chain_seq_masked = "".join(np.array(list(chain_seq))[chain_mask])
                # NOTE: some callers appear to key embeddings by chain index,
                # others by "<seq>:<index>" -- both lookups are supported.
                if i in sequences_to_embeddings:
                    lm_embeddings.append(sequences_to_embeddings[i])
                elif chain_seq_masked + f":{i}" in sequences_to_embeddings:
                    lm_embeddings.append(sequences_to_embeddings[chain_seq_masked + f":{i}"])
                else:
                    raise ValueError(
                        f"Sequence {chain_seq_masked}:{i} not found in the provided embeddings."
                    )
        lm_embeddings = np.concatenate(lm_embeddings, axis=0)
        assert len(lm_embeddings) == len(
            af_protein.aatype
        ), f"LM sequence length must match OpenFold-parsed sequence length: {len(lm_embeddings)} != {len(af_protein.aatype)}"
    if bounding_box is not None:
        # BUGFIX: the original tested `if bounding_box:`, which raises an
        # ambiguous-truth-value ValueError for multi-element arrays. Bounding-box
        # cropping remains unimplemented, so fail loudly with the intended error
        # (the original cropping scaffold after the raise was unreachable and has
        # been dropped).
        raise NotImplementedError
    chain_seqs = [
        (sample_name + seq_data[0], seq_data[1]) for seq_data in af_protein.letter_sequences
    ]
    chain_masks = [seq_data[2] for seq_data in af_protein.letter_sequences]
    features = {
        "res_atom_positions": af_protein.atom_positions,
        "res_type": np.int_(af_protein.aatype),
        "res_atom_types": np.int_(af_protein.atomtypes),
        "res_atom_mask": np.bool_(af_protein.atom_mask),
        "res_chain_id": np.int_(af_protein.chain_index),
        "residue_index": np.int_(af_protein.residue_index),
        "sequence_res_mask": np.bool_(np.concatenate(chain_masks)),
    }
    if lm_embeddings is not None:
        features.update({"lm_embeddings": lm_embeddings})
    if plddt is not None:
        # BUGFIX: `if plddt:` errors on ndarray inputs; preserve the original
        # "skip when empty" semantics while accepting any iterable.
        plddt_arr = np.asarray(list(plddt), dtype=np.float64)
        if plddt_arr.size:
            features.update({"pLDDT": plddt_arr / 100})
    n_res = len(af_protein.atom_positions)
    metadata = {
        "num_structid": 1,
        "num_a": n_res,
        "num_b": n_res,
        "num_chainid": max(af_protein.chain_index) + 1,
    }
    if no_indexer:
        # Minimal indexer: per-residue chain IDs plus an all-zero structure ID.
        return {
            "metadata": metadata,
            "indexer": {
                "gather_idx_a_chainid": features["res_chain_id"],
                "gather_idx_a_structid": np.zeros((n_res,), dtype=np.int_),
            },
            "features": features,
            "misc": {"sequence_data": chain_seqs},
        }
    return {
        "metadata": metadata,
        "indexer": get_protein_indexer(features),
        "features": features,
        "misc": {"sequence_data": chain_seqs},
    }
def merge_protein_and_ligands(
    lig_samples: List[Dict[str, Any]],
    rec_sample: Dict[str, Any],
    n_lig_patches: int,
    label: Optional[str] = None,
    random_lig_placement: bool = False,
    subsample_frames: bool = False,
) -> Dict[str, Any]:
    """Merge a receptor sample and its ligand samples into a single model input.

    The merged dictionary keeps the ``metadata``/``features``/``indexer``/``misc``
    sections of both inputs, with ligand entries collated first and receptor
    entries layered on top.

    :param lig_samples: List of ligand samples.
    :param rec_sample: Receptor sample.
    :param n_lig_patches: Number of ligand patches (frame budget shared across ligands).
    :param label: Optional label stored under ``merged["labels"]``.
    :param random_lig_placement: If True, randomly place ligands into the receptor box.
    :param subsample_frames: If True, subsample frames.
    :return: Merged protein and ligands.
    """
    # Assign frame sampling rate to each ligand: each ligand receives a share of
    # `n_lig_patches` proportional to sqrt of its `num_ijk` count (presumably the
    # number of ligand frames -- TODO confirm against `attach_pair_idx_and_encodings`).
    num_ligands = len(lig_samples)
    if num_ligands > 0:
        num_frames_sqrt = np.sqrt(
            np.array([lig_sample["metadata"]["num_ijk"] for lig_sample in lig_samples])
        )
        if (n_lig_patches > sum(num_frames_sqrt)) and subsample_frames:
            # Randomly shrink the frame budget (augmentation); non-cryptographic
            # randomness is intentional here.
            n_lig_patches = random.randint(int(sum(num_frames_sqrt)), n_lig_patches)  # nosec
        max_n_frames_arr = num_frames_sqrt * (n_lig_patches / sum(num_frames_sqrt))
        max_n_frames_arr = max_n_frames_arr.astype(np.int_)
        lig_samples = [
            attach_pair_idx_and_encodings(lig_sample, max_n_frames=max_n_frames_arr[lig_idx])
            for lig_idx, lig_sample in enumerate(lig_samples)
        ]

    if random_lig_placement:
        # Data augmentation, randomly placing into box
        rec_coords = rec_sample["features"]["res_atom_positions"][
            rec_sample["features"]["res_atom_mask"]
        ]
        box_lbound, box_ubound = (
            np.amin(rec_coords, axis=0),
            np.amax(rec_coords, axis=0),
        )
        for sid, lig_sample in enumerate(lig_samples):
            lig_coords = lig_sample["features"]["sdf_coordinates"]
            lig_center = np.mean(lig_coords, axis=0)
            is_clash = True
            padding = 0
            # Rejection sampling: retry with a progressively padded box until the
            # ligand sits more than 4 A from every receptor atom. NOTE(review):
            # there is no iteration cap, so this relies on the growing padding
            # eventually opening clash-free space.
            while is_clash:
                padding += 1.0
                new_center = np.random.uniform(low=box_lbound - padding, high=box_ubound + padding)
                new_lig_coords = lig_coords + (new_center - lig_center)[None, :]
                intermol_distmat = np.linalg.norm(
                    new_lig_coords[None, :] - rec_coords[:, None, :],
                    axis=2,
                )
                if np.amin(intermol_distmat) > 4.0:
                    is_clash = False
            lig_samples[sid]["features"]["augmented_coordinates"] = new_lig_coords
            del lig_samples[sid]["features"]["sdf_coordinates"]
    lig_sample_merged = collate_numpy_samples(lig_samples)
    # Receptor entries intentionally take precedence over collated ligand entries
    # on key collisions (receptor dict is unpacked last).
    merged = {
        "metadata": {**lig_sample_merged["metadata"], **rec_sample["metadata"]},
        "features": {**lig_sample_merged["features"], **rec_sample["features"]},
        "indexer": {**lig_sample_merged["indexer"], **rec_sample["indexer"]},
        "misc": {**lig_sample_merged["misc"], **rec_sample["misc"]},
    }
    merged["metadata"]["num_structid"] = 1
    if "num_molid" in merged["metadata"]:
        # Map every ligand atom (i) and frame (ijk) to structure 0 -- the merged
        # sample always represents a single structure.
        merged["indexer"]["gather_idx_i_structid"] = np.zeros(
            lig_sample_merged["metadata"]["num_i"], dtype=np.int_
        )
        merged["indexer"]["gather_idx_ijk_structid"] = np.zeros(
            lig_sample_merged["metadata"]["num_ijk"], dtype=np.int_
        )
    assert np.sum(merged["features"]["res_atom_mask"]) > 0
    if label is not None:
        merged["labels"] = np.array([label])
    return merged
@beartype
def convert_protein_pts_to_pdb(
    processed_pt_filenames: List[str],
    processed_pdb_filename: str,
) -> None:
    """Convert protein chain structures in `.pt` format to a single `.pdb` file.

    Each input file contributes one chain (its ID taken from the second
    underscore-separated token of the filename stem) holding `seq`, `xyz`,
    `mask`, `bfac`, and `occ` tensors.

    :param processed_pt_filenames: Filepaths to `.pt` files containing the protein structure.
    :param processed_pdb_filename: Filepath to the output `.pdb` file.
    """
    structure = Structure("protein_structure")
    model = Model(0)
    structure.add(model)

    for processed_pt_filename in processed_pt_filenames:
        chain_id = Path(processed_pt_filename).stem.split("_")[1]
        chain = Chain(chain_id)
        model.add(chain)

        protein_chain_data = torch.load(processed_pt_filename)  # nosec
        for residue_id in range(len(protein_chain_data["seq"])):
            aa = Polypeptide.one_to_three(protein_chain_data["seq"][residue_id])
            # First 14 atom-name slots of the residue's canonical atom layout.
            aa_atoms = AA_TO_LONG[AA_TO_NUM[aa]][:14]

            # PDB residue numbering is 1-based.
            residue = Residue((" ", residue_id + 1, " "), aa, "")
            chain.add(residue)

            for atom_name, coord, mask_atom, bfac, occ in zip(
                aa_atoms,
                protein_chain_data["xyz"][residue_id],
                protein_chain_data["mask"][residue_id],
                protein_chain_data["bfac"][residue_id],
                protein_chain_data["occ"][residue_id],
            ):
                if atom_name is None or mask_atom.item() == 0:  # skip masked atoms
                    continue
                atom = Atom(
                    atom_name.strip(), coord.numpy(), bfac.item(), occ.item(), " ", atom_name, 0
                )
                residue.add(atom)

    # BUGFIX: the local was named `io`, shadowing the module-level `import io`
    # that this file relies on elsewhere (e.g. `io.StringIO` in
    # `pdb_filepath_to_protein`).
    pdb_io = PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(processed_pdb_filename)
@beartype
def get_mol_with_new_conformer_coords(mol: Chem.Mol, new_coords: np.ndarray) -> Chem.Mol:
    """Return a copy of ``mol`` whose single conformer carries ``new_coords``.

    :param mol: RDKit Mol object.
    :param new_coords: Numpy array of shape (num_atoms, 3) with new 3D coordinates.
    :return: A new RDKit Mol object with the updated conformer coordinates.
    """
    num_atoms = mol.GetNumAtoms()
    assert new_coords.shape == (
        num_atoms,
        3,
    ), f"`new_coords` must have shape `({num_atoms}, 3), but it has shape `{new_coords.shape}`"

    # Deep-copy the input so the caller's molecule stays untouched, then drop
    # any conformers the copy carried over.
    updated_mol = Chem.Mol(copy.deepcopy(mol))
    updated_mol.RemoveAllConformers()

    # Populate a fresh conformer one atom at a time from the coordinate rows.
    fresh_conformer = Chem.Conformer(num_atoms)
    coords = new_coords.astype(np.double)
    for atom_idx in range(num_atoms):
        px, py, pz = coords[atom_idx]
        fresh_conformer.SetAtomPosition(atom_idx, Point3D(px, py, pz))

    updated_mol.AddConformer(fresh_conformer, assignId=True)
    return updated_mol
@beartype
def atom37_to_atom14(
    aatype: torch.Tensor, all_atom_pos: torch.Tensor, all_atom_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert OpenFold's `atom37` positions to `atom14` positions.

    Uses the residue-type-indexed lookup table `RESTYPE_ATOM14_TO_ATOM37` to
    gather, per residue, the 14 atom slots relevant to that residue type out of
    the 37 canonical slots.

    :param aatype: Amino acid type tensor of shape `(..., num_residues)`.
    :param all_atom_pos: All-atom (`atom37`) positions tensor of shape `(..., num_residues, 37, 3)`.
    :param all_atom_mask: All-atom (`atom37`) mask tensor of shape `(..., num_residues, 37)`.
    :return: Tuple of `atom14` positions and `atom14` mask of shapes `(..., num_residues, 14, 3)` and `(..., num_residues, 14)`, respectively.
    """
    # Per-residue mapping from atom14 slot -> atom37 slot.
    residx_atom14_to_atom37 = get_rc_tensor(
        residue_constants.RESTYPE_ATOM14_TO_ATOM37, aatype  # (..., num_residues)
    )  # (..., num_residues, 14)
    no_batch_dims = len(aatype.shape) - 1
    atom14_mask = tensor_utils.batched_gather(
        all_atom_mask,  # (..., num_residues, 37)
        residx_atom14_to_atom37,  # (..., num_residues, 14)
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    ).to(
        all_atom_pos.dtype
    )  # (..., num_residues, 14)
    # create a mask for known groundtruth positions: intersect the observed-atom
    # mask with the slots that exist at all for each residue type
    atom14_mask *= get_rc_tensor(residue_constants.RESTYPE_ATOM14_MASK, aatype)
    # gather the groundtruth positions
    atom14_positions = tensor_utils.batched_gather(
        all_atom_pos,  # (..., num_residues, 37, 3)
        residx_atom14_to_atom37,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    # zero out positions for atoms that are masked out
    atom14_positions = atom14_mask[..., None] * atom14_positions
    return (atom14_positions, atom14_mask)  # (..., num_residues, 14, 3) # (..., num_residues, 14)
@beartype
def load_hetero_data_graph_as_pdb_file_pair(
    input_filepaths: Tuple[str, List[str]], temp_dir_path: Path
) -> Dict[str, Dict[str, Union[str, int, List[int]]]]:
    """Load a pair of protein PDB files (and their metadata also including ligand statistics) from
    a single preprocessed `HeteroData` graph file and given RDKit ligands, respectively.

    :param input_filepaths: Tuple of filepaths to the preprocessed `HeteroData` graph and the RDKit ligands.
    :param temp_dir_path: Path to temporary directory for storing the PDB files.
    :return: A dictionary containing the filepaths to the PDB files and their lengths as well as ligand statistics.
    """
    hetero_graph_filepath = input_filepaths[0]
    rdkit_ligand_filepaths = input_filepaths[1]
    # SECURITY NOTE: `pickle.load` can execute arbitrary code on malicious
    # inputs; only ever point this at trusted, locally preprocessed files.
    with open(hetero_graph_filepath, "rb") as f:
        hetero_graph = pickle.load(f)  # nosec
    rdkit_ligand_mol_frags = []
    for rdkit_ligand_filepath in rdkit_ligand_filepaths:
        with open(rdkit_ligand_filepath, "rb") as f:
            rdkit_ligand = pickle.load(f)  # nosec
        # Split multi-fragment ligands into individual molecules (no sanitization).
        rdkit_ligand_mol_frags.extend(
            Chem.GetMolFrags(rdkit_ligand, asMols=True, sanitizeFrags=False)
        )
    apo_graph = hetero_graph["apo_receptor"]
    holo_graph = hetero_graph["receptor"]
    # Stack apo and holo states into a batch of two. Sequence, atom mask, and a
    # fresh 1-based residue numbering are shared (taken from the holo graph);
    # only the coordinates differ. The constant 100.0 pLDDT is a placeholder
    # written into the PDB B-factor column.
    output = {
        "aatype": torch.stack([holo_graph.aatype for _ in range(2)]),
        "all_atom_positions": torch.stack(
            [
                apo_graph.all_atom_positions,
                holo_graph.all_atom_positions,
            ]
        ),
        "all_atom_mask": torch.stack([holo_graph.all_atom_mask for _ in range(2)]),
        "residue_index": torch.stack(
            [(torch.arange(len(holo_graph.aatype)) + 1) for _ in range(2)]
        ),
        "plddt": torch.stack(
            [(100.0 * torch.ones_like(holo_graph.all_atom_mask)) for _ in range(2)]
        ),
    }
    pdbs = output_to_pdb(output)
    apo_pdb_string, holo_pdb_string = pdbs[0], pdbs[1]
    apo_output_file = temp_dir_path / "apo_protein.pdb"
    holo_output_file = temp_dir_path / "holo_protein.pdb"
    apo_output_file.write_text(apo_pdb_string)
    holo_output_file.write_text(holo_pdb_string)
    return {
        "apo_protein": {
            "filepath": str(apo_output_file),
            "length": len(apo_graph.aatype),
        },
        "holo_protein": {
            "filepath": str(holo_output_file),
            "length": len(holo_graph.aatype),
        },
        "ligand": {
            "num_atoms_per_mol_frag": [mol.GetNumAtoms() for mol in rdkit_ligand_mol_frags],
        },
    }
def erase_holo_coordinates(
    batch: MODEL_BATCH, x: torch.Tensor, latent_converter: LatentCoordinateConverter
) -> torch.Tensor:
    """Erase the holo protein and ligand coordinates in the input tensor, leaving the apo
    coordinate untouched.

    :param batch: A batch dictionary.
    :param x: Input tensor.
    :param latent_converter: Latent converter.
    :return: Holo-erased tensor.
    """
    conv = latent_converter
    if batch["misc"]["protein_only"]:
        # Layout: [holo Ca, apo Ca, holo Cother, apo Cother] — even chunks are holo.
        chunk_sizes = [
            conv._n_res_per_sample,
            conv._n_res_per_sample,
            conv._n_cother_per_sample,
            conv._n_cother_per_sample,
        ]
        kept_indices = {1, 3}
    else:
        # Layout additionally carries holo/apo centroid chunks and the ligand chunk.
        chunk_sizes = [
            conv._n_res_per_sample,
            conv._n_res_per_sample,
            conv._n_cother_per_sample,
            conv._n_cother_per_sample,
            conv._n_molid_per_sample,
            conv._n_molid_per_sample,
            conv._n_ligha_per_sample,
        ]
        kept_indices = {1, 3, 5}
    chunks = torch.split(x, chunk_sizes, dim=1)
    # Keep apo chunks as-is; replace holo (and ligand) chunks with zeros of the same shape.
    erased = [
        chunk if idx in kept_indices else torch.zeros_like(chunk)
        for idx, chunk in enumerate(chunks)
    ]
    return torch.cat(erased, dim=1)
def centralize_complex_graph(complex_graph: Dict[str, Any]) -> Dict[str, Any]:
    """Centralize the protein and ligand coordinates in the complex graph.

    Note that the holo protein and ligand coordinates are centralized using the holo protein Ca
    atoms' centroid coordinates, whereas the apo protein coordinates are instead centralized using
    the apo Ca atoms' centroid coordinates. Afterwards, both versions of the protein coordinates
    are aligned at the origin.

    :param complex_graph: A complex graph dictionary.
    :return: Centralized complex graph dictionary.
    """
    feats = complex_graph["features"]
    # Holo centering: subtract the holo Ca (atom index 1) centroid from both the
    # protein atoms and the ligand coordinates.
    holo_ca_centroid = feats["res_atom_positions"][:, 1].mean(dim=0, keepdim=True)
    feats["res_atom_positions"] -= holo_ca_centroid[:, None, :]
    feats["sdf_coordinates"] -= holo_ca_centroid
    # Apo centering (if present): use the apo structure's own Ca centroid so both
    # protein versions end up centered at the origin.
    if "apo_res_atom_positions" in feats:
        apo_ca_centroid = feats["apo_res_atom_positions"][:, 1].mean(dim=0, keepdim=True)
        feats["apo_res_atom_positions"] -= apo_ca_centroid[:, None, :]
    return complex_graph
+ """ + binding_affinity_scores_dict = {} + with open(data_filepath) as file: + for line in file: + columns = line.strip().split() + if len(columns) in {8, 9}: + pdb_code = columns[0] + pK_value = float(columns[3]) + # NOTE: we have to handle for multi-ligands here + if pdb_code in binding_affinity_scores_dict: + assert ( + pK_value == binding_affinity_scores_dict[pdb_code][default_ligand_ccd_id] + ), "PDBBind complexes should only have a single ligand." + else: + binding_affinity_scores_dict[pdb_code] = {default_ligand_ccd_id: pK_value} + return binding_affinity_scores_dict + + +@beartype +def parse_moad_binding_affinity_data_file(data_filepath: str) -> Dict[str, Dict[str, float]]: + """Extract binding affinities from the Binding MOAD dataset's metadata. + + :param data_filepath: Path to the Binding MOAD dataset's metadata file. + :return: A dictionary mapping PDB codes to ligand CCD IDs and their corresponding binding + affinities. + """ + # read in CSV file carefully and manually install column names + df = pd.read_csv(data_filepath, header=None, skiprows=1) + df.columns = [ + "protein_class", + "protein_family", + "protein_id", + "ligand_name", + "ligand_validity", + "affinity_measure", + "=", + "affinity_value", + "affinity_unit", + "smiles_string", + "misc", + ] + # split up `ligand_name` column into its individual parts + df[["ligand_ccd_id", "ligand_ccd_id_index", "ligand_het_code"]] = df["ligand_name"].str.split( + pat=":", n=2, expand=True + ) + df.drop(columns=["ligand_name"], inplace=True) + # assign the corresponding PDB ID to each ligand (row) entry + df["protein_id"].ffill(inplace=True) + # filter for only `valid` ligands (rows) with `Ki` or `Kd` affinity measures + df = df[df["ligand_validity"] == "valid"] + df = df[df["affinity_measure"].isin(["Ki", "Kd"])] + # standardize affinity values in molar units + df["affinity_molar"] = df.apply( + lambda row: convert_to_molar(row["affinity_value"], row["affinity_unit"]), axis=1 + ) + # normalize affinity 
def min_max_normalize_array(array: np.ndarray) -> np.ndarray:
    """Min-max normalize an array to the range `[0, 1]`.

    :param array: Array to min-max normalize.
    :return: Min-max normalized array. For a constant array (`max == min`), an all-zeros
        array is returned instead of dividing by zero (which previously produced NaN/inf).
    """
    min_val = np.min(array)
    max_val = np.max(array)
    value_range = max_val - min_val
    # Guard the degenerate constant-array case, where the original expression
    # divided by zero and emitted NaN (0/0) with a runtime warning.
    if value_range == 0:
        return np.zeros_like(array, dtype=float)
    return (array - min_val) / value_range
+ """ + os.makedirs(os.path.dirname(filename), exist_ok=True) + # Backbone atoms for all amino acids + backbone_atoms = ["N", "CA", "C", "O"] + + # Simplified representation of side chain atoms for each amino acid + side_chain_atoms = { + "A": ["CB"], + "R": ["CB", "CG", "CD", "NE", "CZ"], + "N": ["CB", "CG", "OD1"], + "D": ["CB", "CG", "OD1"], + "C": ["CB", "SG"], + "E": ["CB", "CG", "CD"], + "Q": ["CB", "CG", "CD", "OE1"], + "G": [], + "H": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "I": ["CB", "CG1", "CG2", "CD1"], + "L": ["CB", "CG", "CD1", "CD2"], + "K": ["CB", "CG", "CD", "CE", "NZ"], + "M": ["CB", "CG", "SD", "CE"], + "F": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"], + "P": ["CB", "CG", "CD"], + "S": ["CB", "OG"], + "T": ["CB", "OG1", "CG2"], + "W": ["CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], + "Y": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"], + "V": ["CB", "CG1", "CG2"], + } + + with open(filename, "w") as pdb_file: + atom_index = 1 + chain_id = "A" # Start with chain 'A' + + for chain in sequence.split("|"): + residue_index = 1 + for residue in chain: + # Add backbone atoms + for atom in backbone_atoms: + pdb_file.write( + f"ATOM {atom_index:5d} {atom:<3s} {af_restype_1to3.get(residue, 'UNK')} {chain_id}{residue_index:4d} " + f" 0.000 0.000 0.000 1.00 0.00 C\n" + ) + atom_index += 1 + + # Add side chain atoms + for atom in side_chain_atoms.get(residue, []): # type: ignore + pdb_file.write( + f"ATOM {atom_index:5d} {atom:<3s} {af_restype_1to3.get(residue, 'UNK')} {chain_id}{residue_index:4d} " + f" 0.000 0.000 0.000 1.00 0.00 C\n" + ) + atom_index += 1 + + residue_index += 1 + # Increment chain ID for next chain + chain_id = chr(ord(chain_id) + 1) + + +@beartype +def parse_inference_inputs_from_dir( + input_data_dir: Union[str, Path], pdb_ids: Optional[Set[Any]] = None +) -> List[Tuple[str, str]]: + """Parse a data directory containing subdirectories of protein-ligand complexes and return + corresponding SMILES strings 
@beartype
def parse_inference_inputs_from_dir(
    input_data_dir: Union[str, Path], pdb_ids: Optional[Set[Any]] = None
) -> List[Tuple[str, str]]:
    """Parse a data directory containing subdirectories of protein-ligand complexes and return
    corresponding SMILES strings and PDB IDs.

    :param input_data_dir: Path to the input data directory.
    :param pdb_ids: Optional set of IDs by which to filter processing.
    :return: A list of tuples each containing a SMILES string and a PDB ID.
    :raises ValueError: If no ligand file can be parsed for a complex subdirectory, or if
        SMILES generation fails.
    """
    smiles_and_pdb_id_list = []
    for pdb_name in os.listdir(input_data_dir):
        if any(substr in pdb_name.lower() for substr in ["sequence", "structure"]):
            # e.g., skip ESMFold sequence files and structure directories
            continue
        if pdb_ids is not None and pdb_name not in pdb_ids:
            # e.g., skip PoseBusters Benchmark PDBs that contain crystal contacts
            # reference: https://github.com/maabuu/posebusters/issues/26
            continue
        pdb_dir = os.path.join(input_data_dir, pdb_name)
        if os.path.isdir(pdb_dir):
            mol = None
            pdb_id = os.path.split(pdb_dir)[-1]
            # NOTE: we first try to parse `.mol2` and if necessary `.sdf` files, since e.g., PDBBind 2020's `.sdf` files do not contain chirality tags
            if os.path.exists(os.path.join(pdb_dir, f"{pdb_id}_ligand.mol2")):
                mol = read_molecule(
                    os.path.join(pdb_dir, f"{pdb_id}_ligand.mol2"), remove_hs=True, sanitize=True
                )
            if mol is None and os.path.exists(os.path.join(pdb_dir, f"{pdb_id}_ligand.sdf")):
                mol = read_molecule(
                    os.path.join(pdb_dir, f"{pdb_id}_ligand.sdf"), remove_hs=True, sanitize=True
                )
                if mol is not None:
                    # BUGFIX: only assign chiral tags when the SDF parse succeeded; calling
                    # RDKit with `None` raises here and would skip the `.pdb` fallback below.
                    Chem.rdmolops.AssignAtomChiralTagsFromStructure(mol)
            # NOTE: Binding MOAD/DockGen uses `.pdb` files to store its ligands
            if mol is None and os.path.exists(os.path.join(pdb_dir, f"{pdb_id}_ligand.pdb")):
                mol = read_molecule(
                    os.path.join(pdb_dir, f"{pdb_id}_ligand.pdb"), remove_hs=True, sanitize=True
                )
                if mol is None:
                    # Retry without sanitization for ligands RDKit cannot sanitize cleanly.
                    mol = read_molecule(
                        os.path.join(pdb_dir, f"{pdb_id}_ligand.pdb"),
                        remove_hs=True,
                        sanitize=False,
                    )
            if mol is None:
                raise ValueError(f"No ligand file found for PDB ID {pdb_id}")
            mol_smiles = Chem.MolToSmiles(mol)
            if mol_smiles is None:
                raise ValueError(f"Failed to generate SMILES string for PDB ID {pdb_id}")
            smiles_and_pdb_id_list.append((mol_smiles, pdb_id))
    return smiles_and_pdb_id_list
class RigidTransform:
    """A container for a rigid-body transform holding translations `t` and rotation
    matrices `R`."""

    def __init__(self, t: torch.Tensor, R: Optional[torch.Tensor] = None):
        """Store the translation; when `R` is omitted, allocate an all-zero rotation
        buffer of shape `(*t.shape, 3)` on the same device/dtype as `t`."""
        self.t = t
        self.R = t.new_zeros(*t.shape, 3) if R is None else R

    def __getitem__(self, key):
        """Index both the translation and rotation components."""
        return RigidTransform(self.t[key], self.R[key])

    def unsqueeze(self, dim):
        """Insert a singleton dimension into both components."""
        return RigidTransform(self.t.unsqueeze(dim), self.R.unsqueeze(dim))

    def squeeze(self, dim):
        """Remove a singleton dimension from both components."""
        return RigidTransform(self.t.squeeze(dim), self.R.squeeze(dim))

    def concatenate(self, other, dim=0):
        """Concatenate two transforms component-wise along `dim`."""
        t_cat = torch.cat([self.t, other.t], dim=dim)
        R_cat = torch.cat([self.R, other.R], dim=dim)
        return RigidTransform(t_cat, R_cat)
def get_frame_matrix(
    ri: torch.Tensor, rj: torch.Tensor, rk: torch.Tensor, eps: float = 1e-4, strict: bool = False
):
    """Build a local frame centered at `rj` from three points using the regularized
    Gram-Schmidt algorithm.

    Note that this implementation allows for shearing.
    """
    u = ri - rj
    v = rk - rj
    if strict:
        # Exact normalization (no eps regularization).
        basis1 = u / u.norm(dim=-1, keepdim=True)
        residual = v - basis1.mul(basis1.mul(v).sum(-1, keepdim=True))
        basis2 = residual / residual.norm(dim=-1, keepdim=True)
    else:
        # Regularized normalization: eps keeps near-degenerate inputs finite.
        basis1 = u / u.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
        residual = v - basis1.mul(basis1.mul(v).sum(-1, keepdim=True))
        basis2 = residual / residual.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
    basis3 = torch.cross(basis1, basis2, dim=-1)
    # Rows - lab frame, columns - internal frame
    rotation = torch.stack([basis1, basis2, basis3], dim=-1)
    return RigidTransform(rj, torch.nan_to_num(rotation, 0.0))
def clamp_tensor(value: torch.Tensor, min: float = 1e-6, max: float = 1 - 1e-6) -> torch.Tensor:
    """Set the upper and lower bounds of a tensor via clamping.

    :param value: The tensor to clamp.
    :param min: The minimum value to clamp to. Default is `1e-6`.
    :param max: The maximum value to clamp to. Default is `1 - 1e-6`.
    :return: The clamped tensor.
    """
    # NOTE: `min`/`max` shadow the builtins to preserve the public keyword interface.
    return torch.clamp(value, min=min, max=max)
@beartype
def main(
    start_time: float, num_steps: int, sampler: Literal["ODE", "VDODE"] = "VDODE", eta: float = 1.0
):
    """Inspect different ODE samplers by printing the left hand side (LHS) and right hand side
    (RHS) of their time ratio schedules. Note that the LHS and RHS are clamped to the range
    `[1e-6, 1 - 1e-6]` by default.

    :param start_time: The start time of the ODE sampler.
    :param num_steps: The number of steps to take.
    :param sampler: The ODE sampler to use.
    :param eta: The variance diminishing factor.
    """
    assert 0 < start_time <= 1.0, "The argument `start_time` must be in the range (0, 1]."
    # Walk the schedule pairwise: `t` is the current time, `s` the next (smaller) time.
    schedule = torch.linspace(start_time, 0, num_steps + 1)
    for t, s in zip(schedule[:-1], schedule[1:]):
        if sampler == "ODE":
            # Baseline ODE
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - t) * x0_hat: {clamp_tensor((1 - t)):.3f}; RHS -> t * xt: {clamp_tensor(t):.3f}"
            )
        elif sampler == "VDODE":
            # Variance Diminishing ODE (VD-ODE)
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - ((s / t) * eta)) * x0_hat: {clamp_tensor(1 - ((s / t) * eta)):.3f}; RHS -> ((s / t) * eta) * xt: {clamp_tensor((s / t) * eta):.3f}"
            )


if __name__ == "__main__":
    # Command-line entry point mirroring `main`'s signature.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--start_time", type=float, default=1.0)
    argparser.add_argument("--num_steps", type=int, default=40)
    argparser.add_argument("--sampler", type=str, choices=["ODE", "VDODE"], default="VDODE")
    argparser.add_argument("--eta", type=float, default=1.0)
    args = argparser.parse_args()
    main(args.start_time, args.num_steps, sampler=args.sampler, eta=args.eta)
Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/flowdock/utils/logging_utils.py b/flowdock/utils/logging_utils.py new file mode 100644 index 0000000..92dfdc3 --- /dev/null +++ b/flowdock/utils/logging_utils.py @@ -0,0 +1,59 @@ +import rootutils +from beartype.typing import Any, Dict +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. 
+ - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/flowdock/utils/metric_utils.py b/flowdock/utils/metric_utils.py new file mode 100644 index 0000000..c275c52 --- /dev/null +++ b/flowdock/utils/metric_utils.py @@ -0,0 +1,88 @@ +import subprocess # nosec + +import torch +from beartype import beartype +from beartype.typing import Any, Dict, List, Optional, Tuple + +MODEL_BATCH = Dict[str, Any] + + +@beartype +def calculate_usalign_metrics( + pred_pdb_filepath: str, + reference_pdb_filepath: str, + usalign_exec_path: str, + flags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Calculates US-align structural metrics between predicted and reference macromolecular + structures. + + :param pred_pdb_filepath: Filepath to predicted macromolecular structure in PDB format. + :param reference_pdb_filepath: Filepath to reference macromolecular structure in PDB format. + :param usalign_exec_path: Path to US-align executable. 
@beartype
def calculate_usalign_metrics(
    pred_pdb_filepath: str,
    reference_pdb_filepath: str,
    usalign_exec_path: str,
    flags: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Calculates US-align structural metrics between predicted and reference macromolecular
    structures.

    :param pred_pdb_filepath: Filepath to predicted macromolecular structure in PDB format.
    :param reference_pdb_filepath: Filepath to reference macromolecular structure in PDB format.
    :param usalign_exec_path: Path to US-align executable.
    :param flags: Command-line flags to pass to US-align, optional.
    :return: Dictionary containing macromolecular US-align structural metrics and metadata.
    """
    # run US-align with subprocess and capture output
    # NOTE: list-form argv (no shell) avoids shell injection via the file paths.
    cmd = [usalign_exec_path, pred_pdb_filepath, reference_pdb_filepath]
    if flags is not None:
        cmd += flags
    output = subprocess.check_output(cmd, text=True, stderr=subprocess.PIPE)  # nosec

    # parse US-align output to extract structural metrics
    metrics = {}
    for line in output.splitlines():
        line = line.strip()
        if line.startswith("Name of Structure_1:"):
            metrics["Name of Structure_1"] = line.split(": ", 1)[1]
        elif line.startswith("Name of Structure_2:"):
            metrics["Name of Structure_2"] = line.split(": ", 1)[1]
        elif line.startswith("Length of Structure_1:"):
            metrics["Length of Structure_1"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Length of Structure_2:"):
            metrics["Length of Structure_2"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Aligned length="):
            # NOTE(review): the split indices (1, 2, 4) assume the line looks like
            # "Aligned length= N, RMSD= X, Seq_ID=n_identical/n_aligned= Y" —
            # confirm against the exact US-align version's output format.
            aligned_length = line.split("=")[1].split(",")[0]
            rmsd = line.split("=")[2].split(",")[0]
            seq_id = line.split("=")[4]
            metrics["Aligned length"] = int(aligned_length.strip())
            metrics["RMSD"] = float(rmsd.strip())
            metrics["Seq_ID"] = float(seq_id.strip())
        elif line.startswith("TM-score="):
            # US-align reports two TM-scores, one normalized by each structure's length.
            if "normalized by length of Structure_1" in line:
                metrics["TM-score_1"] = float(line.split("=")[1].split()[0])
            elif "normalized by length of Structure_2" in line:
                metrics["TM-score_2"] = float(line.split("=")[1].split()[0])

    return metrics
def compute_per_atom_lddt(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes per-atom local distance difference test (lDDT) between predicted and target
    coordinates.

    :param batch: Dictionary containing metadata and target coordinates.
    :param pred_coords: Predicted atomic coordinates.
    :param target_coords: Target atomic coordinates.
    :return: Tuple of lDDT and lDDT list.
    """
    pred_coords = pred_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    target_coords = target_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    target_dist = (target_coords[:, :, None] - target_coords[:, None, :]).norm(dim=-1)
    pred_dist = (pred_coords[:, :, None] - pred_coords[:, None, :]).norm(dim=-1)
    # Only pairs within the 15 Angstrom inclusion radius of the target are scored.
    conserved_mask = target_dist < 15.0
    # Hoisted loop invariant: the absolute distance error per pair.
    distdiff = (pred_dist - target_dist).abs()
    lddt_list = []
    thresholds = [0, 0.5, 1, 2, 4, 6, 8, 12, 1e9]
    for threshold_idx in range(8):
        # BUGFIX: the lower bound is inclusive (`>=`) so that a perfect prediction
        # (`distdiff == 0`) falls into the first bin; with the previous strict `>` it
        # was excluded from every bin and a perfect prediction scored an lDDT of 0.
        bin_fraction = (distdiff >= thresholds[threshold_idx]) & (
            distdiff < thresholds[threshold_idx + 1]
        )
        lddt_list.append(
            bin_fraction.mul(conserved_mask).long().sum(dim=2) / conserved_mask.long().sum(dim=2)
        )
    lddt_list = torch.stack(lddt_list, dim=-1)
    # Cumulative sums over the first four bins correspond to the standard lDDT
    # 0.5/1/2/4 Angstrom thresholds; their mean is the per-atom lDDT.
    lddt = torch.cumsum(lddt_list[:, :, :4], dim=-1).mean(dim=-1)
    return lddt, lddt_list
class SumPooling(torch.nn.Module):
    """Sum pooling layer."""

    def __init__(self, learnable: bool, hidden_dim: int = 1):
        """Initialize the SumPooling layer.

        :param learnable: When `True`, apply a learned linear map to the pooled features;
            otherwise pass them through unchanged.
        :param hidden_dim: Feature dimensionality of the optional linear map.
        """
        super().__init__()
        if learnable:
            self.pooled_transform = torch.nn.Linear(hidden_dim, hidden_dim)
        else:
            self.pooled_transform = torch.nn.Identity()

    def forward(self, x, dst_idx, dst_size):
        """Sum-aggregate `x` into `dst_size` segments given by `dst_idx`, then transform."""
        pooled = segment_sum(x, dst_idx, dst_size)
        return self.pooled_transform(pooled)
def init_weights(m):
    """Initialize weights with Kaiming uniform."""
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight)


def _segment_accumulate(src, dst_idx, dst_size):
    """Scatter-add the leading dimension of `src` into `dst_size` destination rows."""
    acc = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    )
    return acc.index_add_(0, dst_idx, src)


def segment_sum(src, dst_idx, dst_size):
    """Computes the sum of each segment in a tensor."""
    return _segment_accumulate(src, dst_idx, dst_size)


def segment_mean(src, dst_idx, dst_size):
    """Computes the mean value of each segment in a tensor."""
    totals = _segment_accumulate(src, dst_idx, dst_size)
    # Count the contributions per destination; the epsilon keeps empty segments finite.
    counts = _segment_accumulate(torch.ones_like(src), dst_idx, dst_size) + 1e-8
    return totals / counts


def segment_argmin(scores, dst_idx, dst_size, randomize: bool = False) -> torch.Tensor:
    """Samples the index of the minimum value in each segment."""
    if randomize:
        # Gumbel perturbation so the minimum is sampled stochastically.
        gumbel = -torch.log(-torch.log(torch.rand_like(scores)))
        scores = scores + gumbel
    _, sampled_idx = scatter_min(scores, dst_idx, dim=0, dim_size=dst_size)
    return sampled_idx


def segment_logsumexp(src, dst_idx, dst_size, extra_dims=None):
    """Computes the logsumexp of each segment in a tensor."""
    # Subtract the per-segment maximum for numerical stability before exponentiating.
    seg_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
    if extra_dims is not None:
        seg_max = torch.amax(seg_max, dim=extra_dims, keepdim=True)
    shifted = src - seg_max[dst_idx]
    summed = _segment_accumulate(torch.exp(shifted), dst_idx, dst_size)
    if extra_dims is not None:
        summed = torch.sum(summed, dim=extra_dims)
    return torch.log(summed + 1e-8) + seg_max.view(*summed.shape)
+def segment_softmax(src, dst_idx, dst_size, extra_dims=None, floor_value=None): + """Computes the softmax of each segment in a tensor.""" + src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size) + if extra_dims is not None: + src_max = torch.amax(src_max, dim=extra_dims, keepdim=True) + src = src - src_max[dst_idx] + exp1 = torch.exp(src) + exp0 = torch.zeros( + dst_size, + *src.shape[1:], + dtype=src.dtype, + device=src.device, + ).index_add_(0, dst_idx, exp1) + if extra_dims is not None: + exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True) + exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx) + exp = exp1.div(exp0 + 1e-8) + if floor_value is not None: + exp = exp.clamp(min=floor_value) + exp0 = torch.zeros( + dst_size, + *src.shape[1:], + dtype=src.dtype, + device=src.device, + ).index_add_(0, dst_idx, exp) + if extra_dims is not None: + exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True) + exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx) + exp = exp.div(exp0 + 1e-8) + return exp + + +def batched_sample_onehot(logits, dim=0, max_only=False): + """Implements the Gumbel-Max trick to sample from a one-hot distribution.""" + if max_only: + sampled_idx = torch.argmax(logits, dim=dim, keepdim=True) + else: + noise = torch.rand_like(logits) + sampled_idx = torch.argmax(logits - torch.log(-torch.log(noise)), dim=dim, keepdim=True) + out_onehot = torch.zeros_like(logits, dtype=torch.bool) + out_onehot.scatter_(dim=dim, index=sampled_idx, value=1) + return out_onehot + + +def topk_edge_mask_from_logits(scores, k, randomize=False): + """Samples the top-k edges from a set of logits.""" + assert len(scores.shape) == 3, "Scores should have shape [B, N, N]" + if randomize: + noise = torch.rand_like(scores) + scores = scores - torch.log(-torch.log(noise)) + node_degree = min(k, scores.shape[2]) + _, topk_idx = torch.topk(scores, node_degree, dim=-1, largest=True) + edge_mask = scores.new_zeros(scores.shape, dtype=torch.bool) + edge_mask = 
edge_mask.scatter_(dim=2, index=topk_idx, value=1).bool() + return edge_mask + + +def sample_inplace_to_torch(sample): + """Convert NumPy sample to PyTorch tensors.""" + if sample is None: + return None + sample["features"] = {k: torch.FloatTensor(v) for k, v in sample["features"].items()} + sample["indexer"] = {k: torch.LongTensor(v) for k, v in sample["indexer"].items()} + if "labels" in sample.keys(): + sample["labels"] = {k: torch.FloatTensor(v) for k, v in sample["labels"].items()} + return sample + + +def inplace_to_device(sample, device): + """Move sample to device.""" + sample["features"] = {k: v.to(device) for k, v in sample["features"].items()} + sample["indexer"] = {k: v.to(device) for k, v in sample["indexer"].items()} + if "labels" in sample.keys(): + sample["labels"] = sample["labels"].to(device) + return sample + + +def inplace_to_torch(sample): + """Convert NumPy sample to PyTorch tensors.""" + if sample is None: + return None + sample["features"] = {k: torch.FloatTensor(v) for k, v in sample["features"].items()} + sample["indexer"] = {k: torch.LongTensor(v) for k, v in sample["indexer"].items()} + if "labels" in sample.keys(): + sample["labels"] = {k: torch.FloatTensor(v) for k, v in sample["labels"].items()} + return sample + + +def distance_to_gaussian_contact_logits( + x: torch.Tensor, contact_scale: float, cutoff: Optional[float] = None +) -> torch.Tensor: + """Convert distance to Gaussian contact logits. + + :param x: Distance tensor. + :param contact_scale: The contact scale. + :param cutoff: The distance cutoff. + :return: Gaussian contact logits. + """ + if cutoff is None: + cutoff = contact_scale * 2 + return torch.log(torch.clamp(1 - (x / cutoff), min=1e-9)) + + +def distogram_to_gaussian_contact_logits( + dgram: torch.Tensor, dist_bins: torch.Tensor, contact_scale: float +) -> torch.Tensor: + """Convert a distance histogram (distogram) matrix to a Gaussian contact map. + + :param dgram: A distogram matrix. 
+ :return: A Gaussian contact map. + """ + return torch.logsumexp( + dgram + distance_to_gaussian_contact_logits(dist_bins, contact_scale), + dim=-1, + ) + + +def eval_true_contact_maps( + batch: MODEL_BATCH, contact_scale: float, **kwargs: Dict[str, Any] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Evaluate true contact maps. + + :param batch: A batch dictionary. + :param contact_scale: The contact scale. + :param kwargs: Additional keyword arguments. + :return: True contact maps. + """ + indexer = batch["indexer"] + batch_size = batch["metadata"]["num_structid"] + with torch.no_grad(): + # Residue centroids + res_cent_coords = ( + batch["features"]["res_atom_positions"] + .mul(batch["features"]["res_atom_mask"].bool()[:, :, None]) + .sum(dim=1) + .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9) + ) + res_lig_dist = ( + res_cent_coords.view(batch_size, -1, 3)[:, :, None] + - batch["features"]["sdf_coordinates"][indexer["gather_idx_U_u"]].view( + batch_size, -1, 3 + )[:, None, :] + ).norm(dim=-1) + res_lig_contact_logit = distance_to_gaussian_contact_logits( + res_lig_dist, contact_scale, **kwargs + ) + return res_lig_dist, res_lig_contact_logit.flatten() + + +def sample_reslig_contact_matrix( + batch: MODEL_BATCH, res_lig_logits: torch.Tensor, last: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Sample residue-ligand contact matrix. + + :param batch: A batch dictionary. + :param res_lig_logits: Residue-ligand contact logits. + :param last: The last contact matrix. + :return: Sampled residue-ligand contact matrix. 
+ """ + metadata = batch["metadata"] + batch_size = metadata["num_structid"] + max(metadata["num_molid_per_sample"]) + n_a_per_sample = max(metadata["num_a_per_sample"]) + n_I_per_sample = max(metadata["num_I_per_sample"]) + res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample) + # Sampling from unoccupied lattice sites + if last is None: + last = torch.zeros_like(res_lig_logits, dtype=torch.bool) + # Column-graph-wise masking for already sampled ligands + # sampled_ligand_mask = torch.amax(last, dim=1, keepdim=True) + sampled_frame_mask = torch.sum(last, dim=1, keepdim=True).contiguous() + masked_logits = res_lig_logits - sampled_frame_mask * 1e9 + sampled_block_onehot = batched_sample_onehot(masked_logits.flatten(1, 2), dim=1).view( + batch_size, n_a_per_sample, n_I_per_sample + ) + new_block_contact_mat = last + sampled_block_onehot + # Remove non-contact patches + valid_logit_mask = res_lig_logits > -16.0 + new_block_contact_mat = (new_block_contact_mat * valid_logit_mask).bool() + return new_block_contact_mat + + +def merge_res_lig_logits_to_graph( + batch: MODEL_BATCH, + res_lig_logits: torch.Tensor, + single_protein_batch: bool, +) -> torch.Tensor: + """Patch merging [B, N_res, N_atm] -> [B, N_res, N_graph]. + + :param batch: A batch dictionary. + :param res_lig_logits: Residue-ligand contact logits. + :param single_protein_batch: Whether to use single protein batch. + :return: Merged residue-ligand logits. + """ + assert single_protein_batch, "Only single protein batch is supported." 
+ metadata = batch["metadata"] + indexer = batch["indexer"] + batch_size = metadata["num_structid"] + max(metadata["num_molid_per_sample"]) + n_mol_per_sample = max(metadata["num_molid_per_sample"]) + n_a_per_sample = max(metadata["num_a_per_sample"]) + n_I_per_sample = max(metadata["num_I_per_sample"]) + res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample) + graph_wise_logits = segment_logsumexp( + res_lig_logits.permute(2, 0, 1), + indexer["gather_idx_I_molid"][:n_I_per_sample], + n_mol_per_sample, + ).permute(1, 2, 0) + return graph_wise_logits + + +def sample_res_rowmask_from_contacts( + batch: MODEL_BATCH, + res_ligatm_logits: torch.Tensor, + single_protein_batch: bool, +) -> torch.Tensor: + """Sample residue row mask from contacts. + + :param batch: A batch dictionary. + :param res_ligatm_logits: Residue-ligand atom contact logits. + :return: Sampled residue row mask. + """ + metadata = batch["metadata"] + max(metadata["num_molid_per_sample"]) + lig_wise_logits = ( + merge_res_lig_logits_to_graph(batch, res_ligatm_logits, single_protein_batch) + .permute(0, 2, 1) + .contiguous() + ) + sampled_res_onehot_mask = batched_sample_onehot(lig_wise_logits.flatten(0, 1), dim=1) + return sampled_res_onehot_mask + + +def extract_esm_embeddings( + esm_model: torch.nn.Module, + esm_alphabet: torch.nn.Module, + esm_batch_converter: torch.nn.Module, + sequences: List[str], + device: Union[str, torch.device], + esm_repr_layer: int = 33, +) -> List[torch.Tensor]: + """Extract embeddings from ESM model. + + :param esm_model: ESM model. + :param esm_alphabet: ESM alphabet. + :param esm_batch_converter: ESM batch converter. + :param sequences: A list of sequences. + :param device: Device to use. + :param esm_repr_layer: ESM representation layer index from which to extract embeddings. + :return: A corresponding list of embeddings. 
+ """ + # Disable dropout for deterministic results + esm_model.eval() + + # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4) + data = [(str(i), seq) for i, seq in enumerate(sequences)] + _, _, batch_tokens = esm_batch_converter(data) + batch_tokens = batch_tokens.to(device) + batch_lens = (batch_tokens != esm_alphabet.padding_idx).sum(1) + + # Extract per-residue representations (on CPU) + with torch.no_grad(): + results = esm_model(batch_tokens, repr_layers=[esm_repr_layer], return_contacts=True) + token_representations = results["representations"][esm_repr_layer] + + # Generate per-residue representations + # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1. + sequence_representations = [] + for i, tokens_len in enumerate(batch_lens): + sequence_representations.append(token_representations[i, 1 : tokens_len - 1]) + + return sequence_representations diff --git a/flowdock/utils/pylogger.py b/flowdock/utils/pylogger.py new file mode 100644 index 0000000..51f249c --- /dev/null +++ b/flowdock/utils/pylogger.py @@ -0,0 +1,51 @@ +import logging + +from beartype.typing import Mapping, Optional +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
+ """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/flowdock/utils/rich_utils.py b/flowdock/utils/rich_utils.py new file mode 100644 index 0000000..03c8c80 --- /dev/null +++ b/flowdock/utils/rich_utils.py @@ -0,0 +1,102 @@ +from pathlib import Path + +import rich +import rich.syntax +import rich.tree +import rootutils +from beartype.typing import Sequence +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + 
+@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. 
+ :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/flowdock/utils/sampling_utils.py b/flowdock/utils/sampling_utils.py new file mode 100644 index 0000000..5f0311c --- /dev/null +++ b/flowdock/utils/sampling_utils.py @@ -0,0 +1,427 @@ +import os + +import numpy as np +import pandas as pd +import rootutils +import torch +from beartype.typing import Any, Dict, List, Optional, Tuple +from lightning import LightningModule +from omegaconf import DictConfig +from rdkit import Chem +from rdkit.Chem import AllChem + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.data.components.mol_features import ( + collate_numpy_samples, + process_mol_file, +) +from flowdock.utils import RankedLogger +from flowdock.utils.data_utils import ( + FDProtein, + merge_protein_and_ligands, + pdb_filepath_to_protein, + prepare_batch, + process_protein, +) +from flowdock.utils.model_utils import inplace_to_device, inplace_to_torch, segment_mean +from flowdock.utils.visualization_utils import ( + write_conformer_sdf, + write_pdb_models, + write_pdb_single, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +def featurize_protein_and_ligands( + rec_path: str, + lig_paths: List[str], + n_lig_patches: int, + apo_rec_path: Optional[str] = None, + chain_id: Optional[str] = None, + protein: Optional[FDProtein] = None, + sequences_to_embeddings: 
Optional[Dict[str, np.ndarray]] = None, + enforce_sanitization: bool = False, + discard_sdf_coords: bool = False, + **kwargs: Dict[str, Any], +): + """Featurize a protein-ligand complex. + + :param rec_path: Path to the receptor file. + :param lig_paths: List of paths to the ligand files. + :param n_lig_patches: Number of ligand patches. + :param apo_rec_path: Path to the apo receptor file. + :param chain_id: Chain ID of the receptor. + :param protein: Optional protein object. + :param sequences_to_embeddings: Mapping of sequences to embeddings. + :param enforce_sanitization: Whether to enforce sanitization. + :param discard_sdf_coords: Whether to discard SDF coordinates. + :param kwargs: Additional keyword arguments. + :return: Featurized protein-ligand complex. + """ + assert rec_path is not None + if lig_paths is None: + lig_paths = [] + if isinstance(lig_paths, str): + lig_paths = [lig_paths] + out_mol = None + lig_samples = [] + for lig_path in lig_paths: + try: + lig_sample, mol_ref = process_mol_file( + lig_path, + sanitize=True, + return_mol=True, + discard_coords=discard_sdf_coords, + ) + except Exception as e: + if enforce_sanitization: + raise + log.warning( + f"RDKit sanitization failed for ligand {lig_path} due to: {e}. Loading raw attributes." 
+ ) + lig_sample, mol_ref = process_mol_file( + lig_path, + sanitize=False, + return_mol=True, + discard_coords=discard_sdf_coords, + ) + lig_samples.append(lig_sample) + if out_mol is None: + out_mol = mol_ref + else: + out_mol = AllChem.CombineMols(out_mol, mol_ref) + protein = protein if protein is not None else pdb_filepath_to_protein(rec_path) + rec_sample = process_protein( + protein, + chain_id=chain_id, + sequences_to_embeddings=None if apo_rec_path is not None else sequences_to_embeddings, + **kwargs, + ) + if apo_rec_path is not None: + apo_protein = pdb_filepath_to_protein(apo_rec_path) + apo_rec_sample = process_protein( + apo_protein, + chain_id=chain_id, + sequences_to_embeddings=sequences_to_embeddings, + **kwargs, + ) + for key in rec_sample.keys(): + for subkey, value in apo_rec_sample[key].items(): + rec_sample[key]["apo_" + subkey] = value + merged_sample = merge_protein_and_ligands( + lig_samples, + rec_sample, + n_lig_patches=n_lig_patches, + label=None, + ) + return merged_sample, out_mol + + +def multi_pose_sampling( + receptor_path: str, + ligand_path: str, + cfg: DictConfig, + lit_module: LightningModule, + out_path: str, + save_pdb: bool = True, + separate_pdb: bool = True, + chain_id: Optional[str] = None, + apo_receptor_path: Optional[str] = None, + sample_id: Optional[str] = None, + protein: Optional[FDProtein] = None, + sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None, + confidence: bool = True, + affinity: bool = True, + return_all_states: bool = False, + auxiliary_estimation_only: bool = False, + **kwargs: Dict[str, Any], +) -> Tuple[ + Optional[Chem.Mol], + Optional[List[float]], + Optional[List[float]], + Optional[List[float]], + Optional[List[Any]], + Optional[Any], + Optional[np.ndarray], + Optional[np.ndarray], +]: + """Sample multiple poses of a protein-ligand complex. + + :param receptor_path: Path to the receptor file. + :param ligand_path: Path to the ligand file. + :param cfg: Config dictionary. 
+ :param lit_module: LightningModule instance. + :param out_path: Path to save the output files. + :param save_pdb: Whether to save PDB files. + :param separate_pdb: Whether to save separate PDB files for each pose. + :param chain_id: Chain ID of the receptor. + :param apo_receptor_path: Path to the optional apo receptor file. + :param sample_id: Optional sample ID. + :param protein: Optional protein object. + :param sequences_to_embeddings: Mapping of sequences to embeddings. + :param confidence: Whether to estimate confidence scores. + :param affinity: Whether to estimate affinity scores. + :param return_all_states: Whether to return all states. + :param auxiliary_estimation_only: Whether to only estimate auxiliary outputs (e.g., confidence, + affinity) for the input (generated) samples (potentially derived from external sources). + :param kwargs: Additional keyword arguments. + :return: Reference molecule, protein plDDTs, ligand plDDTs, ligand fragment plDDTs, estimated + binding affinities, structure trajectories, input batch, B-factors, and structure rankings. 
+ """ + if return_all_states and auxiliary_estimation_only: + # NOTE: If auxiliary estimation is solely enabled, structure trajectory sampling will be disabled + return_all_states = False + struct_res_all, lig_res_all = [], [] + plddt_all, plddt_lig_all, plddt_ligs_all, res_plddt_all = [], [], [], [] + affinity_all, ligs_affinity_all = [], [] + frames_all = [] + chunk_size = cfg.chunk_size + for _ in range(cfg.n_samples // chunk_size): + # Resample anchor node frames + np_sample, mol = featurize_protein_and_ligands( + receptor_path, + ligand_path, + n_lig_patches=lit_module.hparams.cfg.mol_encoder.n_patches, + apo_rec_path=apo_receptor_path, + chain_id=chain_id, + protein=protein, + sequences_to_embeddings=sequences_to_embeddings, + discard_sdf_coords=cfg.discard_sdf_coords and not auxiliary_estimation_only, + **kwargs, + ) + np_sample_batched = collate_numpy_samples([np_sample for _ in range(chunk_size)]) + sample = inplace_to_device(inplace_to_torch(np_sample_batched), device=lit_module.device) + prepare_batch(sample) + if auxiliary_estimation_only: + # Predict auxiliary quantities using the provided input protein and ligand structures + if "num_molid" in sample["metadata"].keys() and sample["metadata"]["num_molid"] > 0: + sample["misc"]["protein_only"] = False + else: + sample["misc"]["protein_only"] = True + output_struct = { + "receptor": sample["features"]["res_atom_positions"].flatten(0, 1), + "receptor_padded": sample["features"]["res_atom_positions"], + "ligands": sample["features"]["sdf_coordinates"], + } + else: + output_struct = lit_module.net.sample_pl_complex_structures( + sample, + sampler=cfg.sampler, + sampler_eta=cfg.sampler_eta, + num_steps=cfg.num_steps, + return_all_states=return_all_states, + start_time=cfg.start_time, + exact_prior=cfg.exact_prior, + ) + frames_all.append(output_struct.get("all_frames", None)) + if mol is not None: + ref_mol = AllChem.Mol(mol) + out_x1 = np.split(output_struct["ligands"].cpu().numpy(), cfg.chunk_size) + 
out_x2 = np.split(output_struct["receptor_padded"].cpu().numpy(), cfg.chunk_size) + if confidence and affinity: + assert ( + lit_module.net.confidence_cfg.enabled + ), "Confidence estimation must be enabled in the model configuration." + assert ( + lit_module.net.affinity_cfg.enabled + ), "Affinity estimation must be enabled in the model configuration." + plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation( + sample, + output_struct, + return_avg_stats=True, + training=False, + ) + aff = sample["outputs"]["affinity_logits"] + elif confidence: + assert ( + lit_module.net.confidence_cfg.enabled + ), "Confidence estimation must be enabled in the model configuration." + plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation( + sample, + output_struct, + return_avg_stats=True, + training=False, + ) + elif affinity: + assert ( + lit_module.net.affinity_cfg.enabled + ), "Affinity estimation must be enabled in the model configuration." + lit_module.net.run_auxiliary_estimation( + sample, output_struct, return_avg_stats=True, training=False + ) + plddt, plddt_lig = None, None + aff = sample["outputs"]["affinity_logits"].cpu() + + mol_idx_i_structid = segment_mean( + sample["indexer"]["gather_idx_i_structid"], + sample["indexer"]["gather_idx_i_molid"], + sample["metadata"]["num_molid"], + ).long() + for struct_idx in range(cfg.chunk_size): + struct_res = { + "features": { + "asym_id": np_sample["features"]["res_chain_id"], + "residue_index": np.arange(len(np_sample["features"]["res_type"])) + 1, + "aatype": np_sample["features"]["res_type"], + }, + "structure_module": { + "final_atom_positions": out_x2[struct_idx], + "final_atom_mask": sample["features"]["res_atom_mask"].bool().cpu().numpy(), + }, + } + struct_res_all.append(struct_res) + if mol is not None: + lig_res_all.append(out_x1[struct_idx]) + if confidence: + plddt_all.append(plddt[struct_idx].item()) + res_plddt_all.append( + sample["outputs"]["plddt"][ + struct_idx, : 
sample["metadata"]["num_a_per_sample"][0] + ] + .cpu() + .numpy() + ) + if plddt_lig is None: + plddt_lig_all.append(None) + else: + plddt_lig_all.append(plddt_lig[struct_idx].item()) + if plddt_ligs is None: + plddt_ligs_all.append(None) + else: + plddt_ligs_all.append(plddt_ligs[mol_idx_i_structid == struct_idx].tolist()) + if affinity: + # collect the average affinity across all ligands in each complex + ligs_aff = aff[mol_idx_i_structid == struct_idx] + affinity_all.append(ligs_aff.mean().item()) + ligs_affinity_all.append(ligs_aff.tolist()) + if confidence and cfg.rank_outputs_by_confidence: + plddt_lig_predicted = all(plddt_lig_all) + if cfg.plddt_ranking_type == "protein": + struct_plddts = np.array(plddt_all) # rank outputs using average protein plDDT + elif cfg.plddt_ranking_type == "ligand": + struct_plddts = np.array( + plddt_lig_all if plddt_lig_predicted else plddt_all + ) # rank outputs using average ligand plDDT if available + if not plddt_lig_predicted: + log.warning( + "Ligand plDDT not available for all samples, using protein plDDT instead" + ) + elif cfg.plddt_ranking_type == "protein_ligand": + struct_plddts = np.array( + plddt_all + plddt_lig_all if plddt_lig_predicted else plddt_all + ) # rank outputs using the sum of the average protein and ligand plDDTs if ligand plDDT is available + if not plddt_lig_predicted: + log.warning( + "Ligand plDDT not available for all samples, using protein plDDT instead" + ) + struct_plddt_rankings = np.argsort( + -struct_plddts + ).argsort() # ensure that higher plDDTs have a higher rank (e.g., `rank1`) + receptor_plddt = np.array(res_plddt_all) if confidence else None + b_factors = ( + np.repeat( + receptor_plddt[..., None], + struct_res_all[0]["structure_module"]["final_atom_mask"].shape[-1], + axis=-1, + ) + if confidence + else None + ) + if save_pdb: + if separate_pdb: + for struct_id, struct_res in enumerate(struct_res_all): + if confidence and cfg.rank_outputs_by_confidence: + write_pdb_single( + 
struct_res, + out_path=os.path.join( + out_path, + f"prot_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb", + ), + b_factors=b_factors[struct_id] if confidence else None, + ) + else: + write_pdb_single( + struct_res, + out_path=os.path.join( + out_path, + f"prot_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb", + ), + b_factors=b_factors[struct_id] if confidence else None, + ) + write_pdb_models( + struct_res_all, out_path=os.path.join(out_path, "prot_all.pdb"), b_factors=b_factors + ) + if mol is not None: + write_conformer_sdf(ref_mol, None, out_path=os.path.join(out_path, "lig_ref.sdf")) + lig_res_all = np.array(lig_res_all) + write_conformer_sdf(mol, lig_res_all, out_path=os.path.join(out_path, "lig_all.sdf")) + for struct_id in range(len(lig_res_all)): + if confidence and cfg.rank_outputs_by_confidence: + write_conformer_sdf( + mol, + lig_res_all[struct_id : struct_id + 1], + out_path=os.path.join( + out_path, + f"lig_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf", + ), + ) + else: + write_conformer_sdf( + mol, + lig_res_all[struct_id : struct_id + 1], + out_path=os.path.join( + out_path, + f"lig_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf", + ), + ) + if confidence: + aux_estimation_all_df = pd.DataFrame( + { + "sample_id": [sample_id] * len(struct_res_all), + "rank": struct_plddt_rankings + 1 if cfg.rank_outputs_by_confidence else None, + "plddt_ligs": plddt_ligs_all, + "affinity_ligs": ligs_affinity_all, + } + ) + aux_estimation_all_df.to_csv( + os.path.join(out_path, f"{sample_id if sample_id is not None else 'sample'}_auxiliary_estimation.csv"), index=False + ) + else: + ref_mol = None + if not confidence: + plddt_all, plddt_lig_all, plddt_ligs_all = None, None, None + if not affinity: + 
affinity_all = None + if return_all_states: + if mol is not None: + np_sample["metadata"]["sample_ID"] = sample_id if sample_id is not None else "sample" + np_sample["metadata"]["mol"] = ref_mol + batch_all = inplace_to_torch( + collate_numpy_samples([np_sample for _ in range(cfg.n_samples)]) + ) + merge_frames_all = frames_all[0] + for frames in frames_all[1:]: + for frame_index, frame in enumerate(frames): + for key in frame.keys(): + merge_frames_all[frame_index][key] = torch.cat( + [merge_frames_all[frame_index][key], frame[key]], dim=0 + ) + frames_all = merge_frames_all + else: + frames_all = None + batch_all = None + if not (confidence and cfg.rank_outputs_by_confidence): + struct_plddt_rankings = None + return ( + ref_mol, + plddt_all, + plddt_lig_all, + plddt_ligs_all, + affinity_all, + frames_all, + batch_all, + b_factors, + struct_plddt_rankings, + ) diff --git a/flowdock/utils/utils.py b/flowdock/utils/utils.py new file mode 100644 index 0000000..ba2f396 --- /dev/null +++ b/flowdock/utils/utils.py @@ -0,0 +1,153 @@ +import warnings +from importlib.util import find_spec + +import rootutils +from beartype.typing import Any, Callable, Dict, List, Optional, Tuple +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! 
") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. 
+ """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def read_strings_from_txt(path: str) -> List[str]: + """Reads strings from a text file and returns them as a list. + + :param path: Path to the text file. + :return: List of strings. 
+ """ + with open(path) as file: + # NOTE: every line will be one element of the returned list + lines = file.readlines() + return [line.rstrip() for line in lines] + + +def fasta_to_dict(filename: str) -> Dict[str, str]: + """Converts a FASTA file to a dictionary where the keys are the sequence IDs and the values are + the sequences. + + :param filename: Path to the FASTA file. + :return: Dictionary with sequence IDs as keys and sequences as values. + """ + fasta_dict = {} + with open(filename) as file: + for line in file: + line = line.rstrip() # remove trailing whitespace + if line.startswith(">"): # identifier line + seq_id = line[1:] # remove the '>' character + fasta_dict[seq_id] = "" + else: # sequence line + fasta_dict[seq_id] += line + return fasta_dict diff --git a/flowdock/utils/visualization_utils.py b/flowdock/utils/visualization_utils.py new file mode 100644 index 0000000..9b424c1 --- /dev/null +++ b/flowdock/utils/visualization_utils.py @@ -0,0 +1,365 @@ +import os + +import numpy as np +import rootutils +from beartype import beartype +from beartype.typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from openfold.np.protein import Protein as OFProtein +from rdkit import Chem +from rdkit.Geometry.rdGeometry import Point3D + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from flowdock.data.components import residue_constants +from flowdock.utils.data_utils import ( + PDB_CHAIN_IDS, + PDB_MAX_CHAINS, + FDProtein, + create_full_prot, + get_mol_with_new_conformer_coords, +) + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. +PROT_LIG_PAIRS = List[Tuple[OFProtein, Tuple[Chem.Mol, ...]]] + + +@beartype +def _chain_end( + atom_index: Union[int, np.int64], + end_resname: str, + chain_name: str, + residue_index: Union[int, np.int64], +) -> str: + """Returns a PDB `TER` record for the end of a chain. 
@beartype
def to_pdb(prot: Union[OFProtein, FDProtein], model=1, add_end=True, add_endmdl=True) -> str:
    """Converts a `Protein` instance to a PDB string.

    Adapted from: https://github.com/jasonkyuyim/se3_diffusion

    :param prot: The protein to convert to PDB.
    :param model: The model number to use.
    :param add_end: Whether to add an `END` record.
    :param add_endmdl: Whether to add an `ENDMDL` record.
    :return: PDB string.
    """
    # index 20 ("X") maps out-of-range residue types to "UNK" via `res_1to3`
    restypes = residue_constants.restypes + ["X"]
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(int)
    chain_index = prot.chain_index.astype(int)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # construct a mapping from chain integer indices to chain ID strings
    chain_ids = {}
    for i in np.unique(chain_index):  # NOTE: `np.unique` gives sorted output
        if i >= PDB_MAX_CHAINS:
            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
        chain_ids[i] = PDB_CHAIN_IDS[i]

    # open the (possibly multi-model) PDB with a MODEL record
    pdb_lines.append(f"MODEL     {model}")
    atom_index = 1
    last_chain_index = chain_index[0]
    # add all atom sites
    for i in range(aatype.shape[0]):
        # close the previous chain if in a multichain PDB
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(restypes, aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # NOTE: atom index increases at the `TER` symbol

        res_name_3 = res_1to3(restypes, aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            # atoms masked out (e.g. absent side-chain atoms) are not written
            if mask < 0.5:
                continue

            record_type = "ATOM"
            # four-character atom names start at column 13, shorter ones at column 14
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00  # fixed occupancy; these are predicted structures
            element = atom_name[0]  # NOTE: `Protein` supports only C, N, O, S, this works
            charge = ""
            # NOTE: PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # close the final chain
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(restypes, aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        )
    )
    if add_endmdl:
        pdb_lines.append("ENDMDL")
    if add_end:
        pdb_lines.append("END")

    # pad all lines to 80 characters
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return "\n".join(pdb_lines) + "\n"  # add terminating newline
+ """ + protein_batch_indexer = outputs["protein_batch_indexer"] + ligand_batch_indexer = outputs["ligand_batch_indexer"] + + protein_all_atom_mask = outputs["res_atom_mask"][protein_batch_indexer == batch_index] + protein_all_atom_coordinates_mask = np.broadcast_to( + np.expand_dims(protein_all_atom_mask, -1), (protein_all_atom_mask.shape[0], 37, 3) + ) + protein_aatype = outputs["aatype"][protein_batch_indexer == batch_index] + + # assemble predicted structures + prot_lig_pairs = [] + for protein_coordinates, ligand_coordinates in zip( + outputs["protein_coordinates_list"], outputs["ligand_coordinates_list"] + ): + protein_all_atom_coordinates = ( + protein_coordinates[protein_batch_indexer == batch_index] + * protein_all_atom_coordinates_mask + ) + protein = create_full_prot( + protein_all_atom_coordinates, + protein_all_atom_mask, + protein_aatype, + b_factors=outputs["b_factors"][batch_index] if "b_factors" in outputs else None, + ) + ligand = get_mol_with_new_conformer_coords( + outputs["ligand_mol"][batch_index], + ligand_coordinates[ligand_batch_indexer == batch_index], + ) + ligands = tuple(Chem.GetMolFrags(ligand, asMols=True, sanitizeFrags=False)) + prot_lig_pairs.append((protein, ligands)) + + # assemble ground-truth structures + if "gt_protein_coordinates" in outputs and "gt_ligand_coordinates" in outputs: + protein_gt_all_atom_coordinates = ( + outputs["gt_protein_coordinates"][protein_batch_indexer == batch_index] + * protein_all_atom_coordinates_mask + ) + gt_protein = create_full_prot( + protein_gt_all_atom_coordinates, + protein_all_atom_mask, + protein_aatype, + ) + gt_ligand = get_mol_with_new_conformer_coords( + outputs["ligand_mol"][batch_index], + outputs["gt_ligand_coordinates"][ligand_batch_indexer == batch_index], + ) + gt_ligands = tuple(Chem.GetMolFrags(gt_ligand, asMols=True, sanitizeFrags=False)) + prot_lig_pairs.append((gt_protein, gt_ligands)) + + return prot_lig_pairs + + +@beartype +def 
@beartype
def write_prot_lig_pairs_to_pdb_file(prot_lig_pairs: PROT_LIG_PAIRS, output_filepath: str):
    """Write a list of protein-ligand pairs to a PDB file, one MODEL per pair.

    :param prot_lig_pairs: List of protein-ligand object pairs, where each ligand may consist of
        multiple ligand chains.
    :param output_filepath: Output file path.
    """
    os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
    with open(output_filepath, "w") as out_file:
        for model_id, (prot, lig_mols) in enumerate(prot_lig_pairs, start=1):
            # protein first, without terminal records so the ligands share its model
            out_file.write(to_pdb(prot, model=model_id, add_end=False, add_endmdl=False))
            for lig_mol in lig_mols:
                # swap RDKit's `END` for `TER` to enable proper ligand chain separation
                out_file.write(Chem.MolToPDBBlock(lig_mol).replace("END\n", "TER\n"))
            out_file.write("END\n")
            out_file.write("ENDMDL\n")  # add `ENDMDL` line to separate models
+ """ + fold_output = result["structure_module"] + + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if "asym_id" in features: + chain_index = _maybe_remove_leading_dim(features["asym_id"]) + else: + chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"])) + + if b_factors is None: + b_factors = np.zeros_like(fold_output["final_atom_mask"]) + + return FDProtein( + letter_sequences=None, + aatype=_maybe_remove_leading_dim(features["aatype"]), + atom_positions=fold_output["final_atom_positions"], + atom_mask=fold_output["final_atom_mask"], + residue_index=_maybe_remove_leading_dim(features["residue_index"]), + chain_index=chain_index, + b_factors=b_factors, + atomtypes=None, + ) + + +def write_pdb_single( + result: ModelOutput, + out_path: str = os.path.join("test_results", "debug.pdb"), + model: int = 1, + b_factors: Optional[np.ndarray] = None, +): + """Write a single model to a PDB file. + + :param result: Model results batch. + :param out_path: Output path. + :param model: Model ID. + :param b_factors: Optional B-factors. + """ + os.makedirs(os.path.dirname(out_path), exist_ok=True) + protein = from_prediction(result["features"], result, b_factors=b_factors) + out_string = to_pdb(protein, model=model) + with open(out_path, "w") as of: + of.write(out_string) + + +def write_pdb_models( + results, + out_path: str = os.path.join("test_results", "debug.pdb"), + b_factors: Optional[np.ndarray] = None, +): + """Write multiple models to a PDB file. + + :param results: Model results. + :param out_path: Output path. + :param b_factors: Optional B-factors. 
+ """ + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "w") as of: + for mid, result in enumerate(results): + protein = from_prediction( + result["features"], + result, + b_factors=b_factors[mid] if b_factors is not None else None, + ) + out_string = to_pdb(protein, model=mid + 1) + of.write(out_string) + of.write("END") + + +def write_conformer_sdf( + mol: Chem.Mol, + confs: Optional[np.array] = None, + out_path: str = os.path.join("test_results", "debug.sdf"), +): + """Write a molecule with conformers to an SDF file. + + :param mol: RDKit molecule. + :param confs: Conformers. + :param out_path: Output path. + """ + os.makedirs(os.path.dirname(out_path), exist_ok=True) + if confs is None: + w = Chem.SDWriter(out_path) + w.write(mol) + w.close() + return 0 + mol.RemoveAllConformers() + for i in range(len(confs)): + conf = Chem.Conformer(mol.GetNumAtoms()) + for j in range(mol.GetNumAtoms()): + x, y, z = confs[i, j].tolist() + conf.SetAtomPosition(j, Point3D(x, y, z)) + mol.AddConformer(conf, assignId=True) + + w = Chem.SDWriter(out_path) + try: + for cid in range(len(confs)): + w.write(mol, confId=cid) + except Exception as e: + w.SetKekulize(False) + for cid in range(len(confs)): + w.write(mol, confId=cid) + w.close() + return 0 diff --git a/main.nf b/main.nf new file mode 100644 index 0000000..d95c74a --- /dev/null +++ b/main.nf @@ -0,0 +1,74 @@ +#!/usr/bin/env nextflow + +nextflow.enable.dsl=2 + +params.input_receptor = 'YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' +params.input_ligand = 'c1cc2c(cc1O)CCCC2' +params.input_template = '' +params.sample_id = 'flowdock_sample' +params.outdir = 's3://omic/eureka/flowdock/output/' +params.n_samples = 5 + +process FLOWDOCK { + container 'harbor.cluster.omic.ai/omic/flowdock:latest' + containerOptions '--rm --gpus 
all' + publishDir params.outdir, mode: 'copy' + stageInMode 'copy' + + input: + val receptor + val ligand + + output: + path "${params.sample_id}/*.pdb", optional: true + path "${params.sample_id}/*.sdf", optional: true + path "${params.sample_id}/**/*.pdb", optional: true + path "run.log" + + script: + """ + set +u + source /opt/conda/etc/profile.d/conda.sh + conda activate FlowDock + + mkdir -p ${params.sample_id} + + python /software/flowdock/flowdock/sample.py \\ + ckpt_path=/software/flowdock/checkpoints/esmfold_prior_paper_weights-EMA.ckpt \\ + model.cfg.prior_type=esmfold \\ + sampling_task=batched_structure_sampling \\ + input_receptor='${receptor}' \\ + input_ligand='"${ligand}"' \\ + ${params.input_template ? "input_template=${params.input_template}" : "input_template=null"} \\ + sample_id='${params.sample_id}' \\ + out_path=\$PWD/${params.sample_id}/ \\ + paths.output_dir=\$PWD/${params.sample_id}/ \\ + hydra.run.dir=\$PWD/${params.sample_id}/ \\ + n_samples=${params.n_samples} \\ + chunk_size=5 \\ + num_steps=40 \\ + sampler=VDODE \\ + sampler_eta=1.0 \\ + start_time='1.0' \\ + use_template=true \\ + separate_pdb=true \\ + visualize_sample_trajectories=false \\ + auxiliary_estimation_only=false \\ + esmfold_chunk_size=null \\ + trainer=gpu 2>&1 | tee run.log + + # Copy any outputs from FlowDock's log directory if they exist there + if ls /software/flowdock/logs/sample/runs/*/rank*.pdb 2>/dev/null; then + cp /software/flowdock/logs/sample/runs/*/rank*.pdb ${params.sample_id}/ 2>/dev/null || true + cp /software/flowdock/logs/sample/runs/*/rank*.sdf ${params.sample_id}/ 2>/dev/null || true + fi + + # List what was generated + echo "=== Generated files ===" >> run.log + find ${params.sample_id}/ -type f >> run.log 2>&1 || true + """ +} + +workflow { + FLOWDOCK(Channel.of(params.input_receptor), Channel.of(params.input_ligand)) +} diff --git a/nextflow.config b/nextflow.config new file mode 100644 index 0000000..5141a64 --- /dev/null +++ b/nextflow.config @@ 
-0,0 +1,34 @@ +manifest { + name = 'FlowDock-Nextflow' + author = 'BioinfoMachineLearning' + homePage = 'https://github.com/BioinfoMachineLearning/FlowDock' + description = 'Nextflow pipeline for FlowDock - Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction' + mainScript = 'main.nf' + version = '1.0.0' +} + +params { + input_receptor = 'YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' + input_ligand = 'c1cc2c(cc1O)CCCC2' + input_template = '' + sample_id = 'flowdock_sample' + outdir = 's3://omic/eureka/flowdock/output/' + n_samples = 5 +} + +docker { + enabled = true + runOptions = '--gpus all' +} + +process { + cpus = 4 + memory = '32 GB' +} + +executor { + $local { + cpus = 8 + memory = '64 GB' + } +} diff --git a/params.json b/params.json new file mode 100644 index 0000000..1d9ea64 --- /dev/null +++ b/params.json @@ -0,0 +1,101 @@ +{ + "params": { + "input_receptor": { + "type": "string", + "description": "Protein sequence(s) in single-letter amino acid code. 
Multiple chains separated by '|'", + "default": "YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV", + "required": true, + "pipeline_io": "input", + "var_name": "params.input_receptor", + "examples": [ + "YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV", + "MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP" + ], + "pattern": "^[ACDEFGHIKLMNPQRSTVWY|]+$", + "enum": [], + "validation": {}, + "notes": "For multi-chain proteins, separate sequences with '|' character." + }, + "input_ligand": { + "type": "string", + "description": "Ligand SMILES string for docking", + "default": "c1cc2c(cc1O)CCCC2", + "required": true, + "pipeline_io": "input", + "var_name": "params.input_ligand", + "examples": [ + "c1cc2c(cc1O)CCCC2", + "CC(=O)NC1C(O)OC(CO)C(OC2OC(CO)C(O)C(O)C2NC(C)=O)C1O" + ], + "pattern": ".*", + "enum": [], + "validation": {}, + "notes": "SMILES notation for the ligand molecule." + }, + "input_template": { + "type": "file", + "description": "Optional template PDB structure for protein. If empty, ESMFold will predict the structure", + "default": "", + "required": false, + "pipeline_io": "input", + "var_name": "params.input_template", + "examples": [ + "/path/to/template.pdb", + "" + ], + "pattern": ".*\\.pdb$", + "enum": [], + "validation": {}, + "notes": "Providing a template can improve docking quality. Leave empty to use ESMFold-predicted structure." 
+ }, + "sample_id": { + "type": "string", + "description": "Unique identifier for the prediction sample", + "default": "flowdock_sample", + "required": false, + "pipeline_io": "parameter", + "var_name": "params.sample_id", + "examples": [ + "6i67", + "T1152", + "my_protein_ligand" + ], + "pattern": "^[a-zA-Z0-9_-]+$", + "enum": [], + "validation": {}, + "notes": "Used for naming output directories and files." + }, + "outdir": { + "type": "folder", + "description": "Directory for FlowDock output results", + "default": "s3://omic/eureka/flowdock/output/", + "required": true, + "pipeline_io": "output", + "var_name": "params.outdir", + "examples": [ + "./flowdock_output", + "/data/predictions/flowdock" + ], + "pattern": ".*", + "enum": [], + "validation": {}, + "notes": "Directory where PDB structures and logs will be stored." + }, + "n_samples": { + "type": "integer", + "description": "Number of prediction samples to generate", + "default": 5, + "required": false, + "pipeline_io": "parameter", + "var_name": "params.n_samples", + "examples": [1, 5, 10, 40], + "pattern": "^[1-9][0-9]*$", + "enum": [], + "validation": { + "min": 1, + "max": 100 + }, + "notes": "Higher values provide more conformational diversity but increase computation time." 
+ } + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..300ebf0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[tool.pytest.ini_options] +addopts = [ + "--color=yes", + "--durations=0", + "--strict-markers", + "--doctest-modules", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", +] +log_cli = "True" +markers = [ + "slow: slow tests", +] +minversion = "6.0" +testpaths = "tests/" + +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "raise NotImplementedError()", + "if __name__ == .__main__.:", +] diff --git a/scripts/esmfold_prior_plinder_finetuning.sh b/scripts/esmfold_prior_plinder_finetuning.sh new file mode 100644 index 0000000..172344e --- /dev/null +++ b/scripts/esmfold_prior_plinder_finetuning.sh @@ -0,0 +1,64 @@ +#!/bin/bash -l +######################### Batch Headers ######################### +#SBATCH --partition chengji-lab-gpu # NOTE: use reserved partition `chengji-lab-gpu` to use reserved A100 or H100 GPUs +#SBATCH --account chengji-lab # NOTE: this must be specified to use the reserved partition above +#SBATCH --nodes=1 # NOTE: this needs to match Lightning's `Trainer(num_nodes=...)` +#SBATCH --gres gpu:1 # request A100/H100 GPU resource(s) +#SBATCH --ntasks-per-node=1 # NOTE: this needs to be `1` on SLURM clusters when using Lightning's `ddp_spawn` strategy`; otherwise, set to match Lightning's quantity of `Trainer(devices=...)` +#SBATCH --mem=59G # NOTE: use `--mem=0` to request all memory "available" on the assigned node +#SBATCH -t 2-00:00:00 # time limit for the job (up to 2 days: `2-00:00:00`) +#SBATCH -J esmfold_prior_plinder_finetuning # job name +#SBATCH --output=R-%x.%j.out # output log file +#SBATCH --error=R-%x.%j.err # error log file + +module purge +module load cuda/11.8.0_gcc_9.5.0 + +# determine location of the project directory +use_private_project_dir=false # NOTE: customize as needed +if [ "$use_private_project_dir" = 
true ]; then + project_dir="/home/acmwhb/data/Repositories/Lab_Repositories/FlowDock" +else + project_dir="/cluster/pixstor/chengji-lab/acmwhb/Repositories/Lab_Repositories/FlowDock" +fi + +# shellcheck source=/dev/null +source /cluster/pixstor/chengji-lab/acmwhb/miniforge3/etc/profile.d/conda.sh +conda activate "$project_dir"/FlowDock/ + +# Reference Conda system libraries +export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" + +echo "Calling flowdock/train.py!" +cd "$project_dir" || exit +srun python3 flowdock/train.py \ + callbacks.last_model_checkpoint.filename=null \ + callbacks.last_model_checkpoint.every_n_train_steps=200 \ + callbacks.last_model_checkpoint.every_n_epochs=null \ + ckpt_path="$(realpath 'logs/train/runs/2025-03-17_17-39-39/checkpoints/169-562000.ckpt')" \ + data=plinder \ + experiment='flowdock_fm' \ + environment=slurm \ + logger=wandb \ + logger.wandb.entity='bml-lab' \ + logger.wandb.group='FlowDock-FM' \ + +logger.wandb.name='2025-03-17_17:00:00-ESMFold-Prior-PLINDER-Finetuning' \ + +logger.wandb.id='1x2k5a79' \ + model.cfg.prior_type=esmfold \ + model.cfg.task.freeze_score_head=false \ + model.cfg.task.freeze_affinity=true \ + paths.output_dir="$(realpath 'logs/train/runs/2025-03-17_17-39-39')" \ + strategy=ddp \ + trainer=ddp \ + +trainer.accumulate_grad_batches=4 \ + trainer.devices=1 \ + trainer.num_nodes=1 +echo "Finished calling flowdock/train.py!" 
+ +# NOTE: the following commands must be used to resume training from a checkpoint +# ckpt_path="$(realpath 'logs/train/runs/2025-03-17_17-39-39/checkpoints/169-562000.ckpt')" \ +# paths.output_dir="$(realpath 'logs/train/runs/2025-03-17_17-39-39')" \ + +# NOTE: the following commands may be used to speed up training +# model.compile=false \ +# +trainer.precision=bf16-mixed diff --git a/scripts/esmfold_prior_tiered_training.sh b/scripts/esmfold_prior_tiered_training.sh new file mode 100644 index 0000000..2593faf --- /dev/null +++ b/scripts/esmfold_prior_tiered_training.sh @@ -0,0 +1,64 @@ +#!/bin/bash -l +######################### Batch Headers ######################### +#SBATCH --partition chengji-lab-gpu # NOTE: use reserved partition `chengji-lab-gpu` to use reserved A100 or H100 GPUs +#SBATCH --account chengji-lab # NOTE: this must be specified to use the reserved partition above +#SBATCH --nodes=1 # NOTE: this needs to match Lightning's `Trainer(num_nodes=...)` +#SBATCH --gres gpu:H100:4 # request H100 GPU resource(s) +#SBATCH --ntasks-per-node=4 # NOTE: this needs to be `1` on SLURM clusters when using Lightning's `ddp_spawn` strategy`; otherwise, set to match Lightning's quantity of `Trainer(devices=...)` +#SBATCH --mem=0 # NOTE: use `--mem=0` to request all memory "available" on the assigned node +#SBATCH -t 7-00:00:00 # time limit for the job (up to 7 days: `7-00:00:00`) +#SBATCH -J esmfold_prior_tiered_training # job name +#SBATCH --output=R-%x.%j.out # output log file +#SBATCH --error=R-%x.%j.err # error log file + +random_seconds=$(( (RANDOM % 100) + 1 )) +echo "Sleeping for $random_seconds seconds before starting run" +sleep "$random_seconds" + +module purge +module load cuda/11.8.0_gcc_9.5.0 + +# determine location of the project directory +use_private_project_dir=false # NOTE: customize as needed +if [ "$use_private_project_dir" = true ]; then + project_dir="/home/acmwhb/data/Repositories/Lab_Repositories/FlowDock" +else + 
project_dir="/cluster/pixstor/chengji-lab/acmwhb/Repositories/Lab_Repositories/FlowDock" +fi + +# shellcheck source=/dev/null +source /cluster/pixstor/chengji-lab/acmwhb/miniforge3/etc/profile.d/conda.sh +conda activate "$project_dir"/FlowDock/ + +# Reference Conda system libraries +export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" + +# NOTE: for tiered training, start by setting `model.cfg.task.freeze_score_head=false` and `model.cfg.task.freeze_affinity=true`, +# and once the model's score head has been trained to convergence, resume training with `model.cfg.task.freeze_score_head=true` and `model.cfg.task.freeze_affinity=false` + +echo "Calling flowdock/train.py!" +cd "$project_dir" || exit +srun python3 flowdock/train.py \ + experiment='flowdock_fm' \ + environment=slurm \ + logger=wandb \ + logger.wandb.entity='bml-lab' \ + logger.wandb.group='FlowDock-FM' \ + +logger.wandb.name='2024-12-06_18:00:00-ESMFold-Prior-Tiered-Training' \ + +logger.wandb.id='z1u52tvj' \ + model.cfg.prior_type=esmfold \ + model.cfg.task.freeze_score_head=false \ + model.cfg.task.freeze_affinity=true \ + strategy=ddp \ + trainer=ddp \ + trainer.devices=4 \ + trainer.num_nodes=1 +echo "Finished calling flowdock/train.py!" 
+ +# NOTE: the following commands must be used to resume training from a checkpoint +# ckpt_path="$(realpath 'logs/train/runs/2024-05-17_13-45-06/checkpoints/last.ckpt')" \ +# paths.output_dir="$(realpath 'logs/train/runs/2024-05-17_13-45-06')" \ + +# NOTE: the following commands may be used to speed up training +# model.compile=false \ +# +trainer.precision=bf16-mixed diff --git a/scripts/esmfold_prior_training.sh b/scripts/esmfold_prior_training.sh new file mode 100644 index 0000000..6719383 --- /dev/null +++ b/scripts/esmfold_prior_training.sh @@ -0,0 +1,61 @@ +#!/bin/bash -l +######################### Batch Headers ######################### +#SBATCH --partition chengji-lab-gpu # NOTE: use reserved partition `chengji-lab-gpu` to use reserved A100 or H100 GPUs +#SBATCH --account chengji-lab # NOTE: this must be specified to use the reserved partition above +#SBATCH --nodes=1 # NOTE: this needs to match Lightning's `Trainer(num_nodes=...)` +#SBATCH --gres gpu:H100:4 # request H100 GPU resource(s) +#SBATCH --ntasks-per-node=4 # NOTE: this needs to be `1` on SLURM clusters when using Lightning's `ddp_spawn` strategy`; otherwise, set to match Lightning's quantity of `Trainer(devices=...)` +#SBATCH --mem=0 # NOTE: use `--mem=0` to request all memory "available" on the assigned node +#SBATCH -t 7-00:00:00 # time limit for the job (up to 7 days: `7-00:00:00`) +#SBATCH -J esmfold_prior_training # job name +#SBATCH --output=R-%x.%j.out # output log file +#SBATCH --error=R-%x.%j.err # error log file + +random_seconds=$(( (RANDOM % 100) + 1 )) +echo "Sleeping for $random_seconds seconds before starting run" +sleep "$random_seconds" + +module purge +module load cuda/11.8.0_gcc_9.5.0 + +# determine location of the project directory +use_private_project_dir=false # NOTE: customize as needed +if [ "$use_private_project_dir" = true ]; then + project_dir="/home/acmwhb/data/Repositories/Lab_Repositories/FlowDock" +else + 
project_dir="/cluster/pixstor/chengji-lab/acmwhb/Repositories/Lab_Repositories/FlowDock" +fi + +# shellcheck source=/dev/null +source /cluster/pixstor/chengji-lab/acmwhb/miniforge3/etc/profile.d/conda.sh +conda activate "$project_dir"/FlowDock/ + +# Reference Conda system libraries +export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" + +echo "Calling flowdock/train.py!" +cd "$project_dir" || exit +srun python3 flowdock/train.py \ + experiment='flowdock_fm' \ + environment=slurm \ + logger=wandb \ + logger.wandb.entity='bml-lab' \ + logger.wandb.group='FlowDock-FM' \ + +logger.wandb.name='2024-12-06_18:00:00-ESMFold-Prior-Training' \ + +logger.wandb.id='z0u52tvj' \ + model.cfg.prior_type=esmfold \ + model.cfg.task.freeze_score_head=false \ + model.cfg.task.freeze_affinity=false \ + strategy=ddp \ + trainer=ddp \ + trainer.devices=4 \ + trainer.num_nodes=1 +echo "Finished calling flowdock/train.py!" + +# NOTE: the following commands must be used to resume training from a checkpoint +# ckpt_path="$(realpath 'logs/train/runs/2024-05-17_13-45-06/checkpoints/last.ckpt')" \ +# paths.output_dir="$(realpath 'logs/train/runs/2024-05-17_13-45-06')" \ + +# NOTE: the following commands may be used to speed up training +# model.compile=false \ +# +trainer.precision=bf16-mixed diff --git a/scripts/harmonic_prior_training.sh b/scripts/harmonic_prior_training.sh new file mode 100644 index 0000000..f7e05c2 --- /dev/null +++ b/scripts/harmonic_prior_training.sh @@ -0,0 +1,61 @@ +#!/bin/bash -l +######################### Batch Headers ######################### +#SBATCH --partition chengji-lab-gpu # NOTE: use reserved partition `chengji-lab-gpu` to use reserved A100 or H100 GPUs +#SBATCH --account chengji-lab # NOTE: this must be specified to use the reserved partition above +#SBATCH --nodes=1 # NOTE: this needs to match Lightning's `Trainer(num_nodes=...)` +#SBATCH --gres gpu:H100:4 # request H100 GPU resource(s) +#SBATCH --ntasks-per-node=4 # NOTE: this needs to be `1` on 
SLURM clusters when using Lightning's `ddp_spawn` strategy`; otherwise, set to match Lightning's quantity of `Trainer(devices=...)` +#SBATCH --mem=0 # NOTE: use `--mem=0` to request all memory "available" on the assigned node +#SBATCH -t 7-00:00:00 # time limit for the job (up to 7 days: `7-00:00:00`) +#SBATCH -J harmonic_prior_training # job name +#SBATCH --output=R-%x.%j.out # output log file +#SBATCH --error=R-%x.%j.err # error log file + +random_seconds=$(( (RANDOM % 100) + 1 )) +echo "Sleeping for $random_seconds seconds before starting run" +sleep "$random_seconds" + +module purge +module load cuda/11.8.0_gcc_9.5.0 + +# determine location of the project directory +use_private_project_dir=false # NOTE: customize as needed +if [ "$use_private_project_dir" = true ]; then + project_dir="/home/acmwhb/data/Repositories/Lab_Repositories/FlowDock" +else + project_dir="/cluster/pixstor/chengji-lab/acmwhb/Repositories/Lab_Repositories/FlowDock" +fi + +# shellcheck source=/dev/null +source /cluster/pixstor/chengji-lab/acmwhb/miniforge3/etc/profile.d/conda.sh +conda activate "$project_dir"/FlowDock/ + +# Reference Conda system libraries +export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" + +echo "Calling flowdock/train.py!" +cd "$project_dir" || exit +srun python3 flowdock/train.py \ + experiment='flowdock_fm' \ + environment=slurm \ + logger=wandb \ + logger.wandb.entity='bml-lab' \ + logger.wandb.group='FlowDock-FM' \ + +logger.wandb.name='2024-12-06_18:00:00-Harmonic-Prior-Training' \ + +logger.wandb.id='z2u52tvj' \ + model.cfg.prior_type=harmonic \ + model.cfg.task.freeze_score_head=false \ + model.cfg.task.freeze_affinity=false \ + strategy=ddp \ + trainer=ddp \ + trainer.devices=4 \ + trainer.num_nodes=1 +echo "Finished calling flowdock/train.py!" 
+ +# NOTE: the following commands must be used to resume training from a checkpoint +# ckpt_path="$(realpath 'logs/train/runs/2024-05-17_13-45-06/checkpoints/last.ckpt')" \ +# paths.output_dir="$(realpath 'logs/train/runs/2024-05-17_13-45-06')" \ + +# NOTE: the following commands may be used to speed up training +# model.compile=false \ +# +trainer.precision=bf16-mixed diff --git a/scripts/plinder_download.sh b/scripts/plinder_download.sh new file mode 100644 index 0000000..b117b73 --- /dev/null +++ b/scripts/plinder_download.sh @@ -0,0 +1,38 @@ +#!/bin/bash -l +######################### Batch Headers ######################### +#SBATCH --partition general # NOTE: use reserved partition `chengji-lab-gpu` to use reserved A100 or H100 GPUs +#SBATCH --account chengji-lab # NOTE: this must be specified to use the reserved partition above +#SBATCH --nodes=1 # NOTE: this needs to match Lightning's `Trainer(num_nodes=...)` +#SBATCH --ntasks-per-node=1 # NOTE: this needs to be `1` on SLURM clusters when using Lightning's `ddp_spawn` strategy`; otherwise, set to match Lightning's quantity of `Trainer(devices=...)` +#SBATCH --mem=59G # NOTE: use `--mem=0` to request all memory "available" on the assigned node +#SBATCH -t 0-02:00:00 # time limit for the job (up to 2 days: `2-00:00:00`) +#SBATCH -J plinder_download # job name +#SBATCH --output=R-%x.%j.out # output log file +#SBATCH --error=R-%x.%j.err # error log file + +module purge +module load cuda/11.8.0_gcc_9.5.0 + +# determine location of the project directory +use_private_project_dir=false # NOTE: customize as needed +if [ "$use_private_project_dir" = true ]; then + project_dir="/home/acmwhb/data/Repositories/Lab_Repositories/FlowDock" +else + project_dir="/cluster/pixstor/chengji-lab/acmwhb/Repositories/Lab_Repositories/FlowDock" +fi + +# shellcheck source=/dev/null +source /cluster/pixstor/chengji-lab/acmwhb/miniforge3/etc/profile.d/conda.sh +conda activate "$project_dir"/FlowDock/ + +# Reference Conda system libraries 
+export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" + +# determine location of PLINDER dataset +export PLINDER_MOUNT="$project_dir/data/PLINDER" # NOTE: customize as needed +mkdir -p "$PLINDER_MOUNT" # create the directory if it doesn't exist + +echo "Downloading PLINDER to $PLINDER_MOUNT!" +cd "$project_dir" || exit +plinder_download -y +echo "Finished downloading PLINDER to $PLINDER_MOUNT!" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..914c900 --- /dev/null +++ b/setup.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +setup( + name="FlowDock", + version="0.0.2", + description="Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction", + author="", + author_email="", + url="https://github.com/BioinfoMachineLearning/FlowDock", + install_requires=["lightning", "hydra-core"], + packages=find_packages(), + # use this to customize global commands available in the terminal after installing the package + entry_points={ + "console_scripts": [ + "train_command = flowdock.train:main", + "eval_command = flowdock.eval:main", + ] + }, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py new file mode 100644 index 0000000..0afdba8 --- /dev/null +++ b/tests/helpers/package_available.py @@ -0,0 +1,32 @@ +import platform + +import pkg_resources +from lightning.fabric.accelerators import TPUAccelerator + + +def _package_available(package_name: str) -> bool: + """Check if a package is available in your environment. + + :param package_name: The name of the package to be checked. + + :return: `True` if the package is available. `False` otherwise. 
+ """ + try: + return pkg_resources.require(package_name) is not None + except pkg_resources.DistributionNotFound: + return False + + +_TPU_AVAILABLE = TPUAccelerator.is_available() + +_IS_WINDOWS = platform.system() == "Windows" + +_SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") + +_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") + +_WANDB_AVAILABLE = _package_available("wandb") +_NEPTUNE_AVAILABLE = _package_available("neptune") +_COMET_AVAILABLE = _package_available("comet_ml") +_MLFLOW_AVAILABLE = _package_available("mlflow") diff --git a/tests/helpers/run_if.py b/tests/helpers/run_if.py new file mode 100644 index 0000000..9545950 --- /dev/null +++ b/tests/helpers/run_if.py @@ -0,0 +1,142 @@ +"""Adapted from: + +https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py +""" + +import sys + +import pytest +import torch +from beartype.typing import Any, Dict, Optional +from packaging.version import Version +from pkg_resources import get_distribution +from pytest import MarkDecorator + +from tests.helpers.package_available import ( + _COMET_AVAILABLE, + _DEEPSPEED_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _IS_WINDOWS, + _MLFLOW_AVAILABLE, + _NEPTUNE_AVAILABLE, + _SH_AVAILABLE, + _TPU_AVAILABLE, + _WANDB_AVAILABLE, +) + + +class RunIf: + """RunIf wrapper for conditional skipping of tests. + + Fully compatible with `@pytest.mark`. 
+ + Example: + + ```python + @RunIf(min_torch="1.8") + @pytest.mark.parametrize("arg1", [1.0, 2.0]) + def test_wrapper(arg1): + assert arg1 > 0 + ``` + """ + + def __new__( + cls, + min_gpus: int = 0, + min_torch: Optional[str] = None, + max_torch: Optional[str] = None, + min_python: Optional[str] = None, + skip_windows: bool = False, + sh: bool = False, + tpu: bool = False, + fairscale: bool = False, + deepspeed: bool = False, + wandb: bool = False, + neptune: bool = False, + comet: bool = False, + mlflow: bool = False, + **kwargs: Dict[Any, Any], + ) -> MarkDecorator: + """Creates a new `@RunIf` `MarkDecorator` decorator. + + :param min_gpus: Min number of GPUs required to run test. + :param min_torch: Minimum pytorch version to run test. + :param max_torch: Maximum pytorch version to run test. + :param min_python: Minimum python version required to run test. + :param skip_windows: Skip test for Windows platform. + :param tpu: If TPU is available. + :param sh: If `sh` module is required to run the test. + :param fairscale: If `fairscale` module is required to run the test. + :param deepspeed: If `deepspeed` module is required to run the test. + :param wandb: If `wandb` module is required to run the test. + :param neptune: If `neptune` module is required to run the test. + :param comet: If `comet` module is required to run the test. + :param mlflow: If `mlflow` module is required to run the test. + :param kwargs: Native `pytest.mark.skipif` keyword arguments. 
+ """ + conditions = [] + reasons = [] + + if min_gpus: + conditions.append(torch.cuda.device_count() < min_gpus) + reasons.append(f"GPUs>={min_gpus}") + + if min_torch: + torch_version = get_distribution("torch").version + conditions.append(Version(torch_version) < Version(min_torch)) + reasons.append(f"torch>={min_torch}") + + if max_torch: + torch_version = get_distribution("torch").version + conditions.append(Version(torch_version) >= Version(max_torch)) + reasons.append(f"torch<{max_torch}") + + if min_python: + py_version = ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + conditions.append(Version(py_version) < Version(min_python)) + reasons.append(f"python>={min_python}") + + if skip_windows: + conditions.append(_IS_WINDOWS) + reasons.append("does not run on Windows") + + if tpu: + conditions.append(not _TPU_AVAILABLE) + reasons.append("TPU") + + if sh: + conditions.append(not _SH_AVAILABLE) + reasons.append("sh") + + if fairscale: + conditions.append(not _FAIRSCALE_AVAILABLE) + reasons.append("fairscale") + + if deepspeed: + conditions.append(not _DEEPSPEED_AVAILABLE) + reasons.append("deepspeed") + + if wandb: + conditions.append(not _WANDB_AVAILABLE) + reasons.append("wandb") + + if neptune: + conditions.append(not _NEPTUNE_AVAILABLE) + reasons.append("neptune") + + if comet: + conditions.append(not _COMET_AVAILABLE) + reasons.append("comet") + + if mlflow: + conditions.append(not _MLFLOW_AVAILABLE) + reasons.append("mlflow") + + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] + return pytest.mark.skipif( + condition=any(conditions), + reason=f"Requires: [{' + '.join(reasons)}]", + **kwargs, + ) diff --git a/tests/helpers/run_sh_command.py b/tests/helpers/run_sh_command.py new file mode 100644 index 0000000..075a984 --- /dev/null +++ b/tests/helpers/run_sh_command.py @@ -0,0 +1,21 @@ +import pytest +from beartype.typing import List + +from tests.helpers.package_available import _SH_AVAILABLE + +if 
 _SH_AVAILABLE: + import sh + + +def run_sh_command(command: List[str]) -> None: + """Default method for executing shell commands with `pytest` and `sh` package. + + :param command: A list of shell commands as strings. + """ + msg = None + try: + sh.python(command) + except sh.ErrorReturnCode as e: + msg = e.stderr.decode() + if msg: + pytest.fail(msg=msg) diff --git a/tests/test_configs.py b/tests/test_configs.py new file mode 100644 index 0000000..d7041dc --- /dev/null +++ b/tests/test_configs.py +import hydra +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig + + +def test_train_config(cfg_train: DictConfig) -> None: + """Tests the training configuration provided by the `cfg_train` pytest fixture. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + assert cfg_train + assert cfg_train.data + assert cfg_train.model + assert cfg_train.trainer + + HydraConfig().set_config(cfg_train) + + hydra.utils.instantiate(cfg_train.data) + hydra.utils.instantiate(cfg_train.model) + hydra.utils.instantiate(cfg_train.trainer) + + +def test_eval_config(cfg_eval: DictConfig) -> None: + """Tests the evaluation configuration provided by the `cfg_eval` pytest fixture. + + :param cfg_eval: A DictConfig containing a valid evaluation configuration. + """ + assert cfg_eval + assert cfg_eval.data + assert cfg_eval.model + assert cfg_eval.trainer + + HydraConfig().set_config(cfg_eval) + + hydra.utils.instantiate(cfg_eval.data) + hydra.utils.instantiate(cfg_eval.model) + hydra.utils.instantiate(cfg_eval.trainer)