Initial commit: FlowDock pipeline configured for WES execution
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
This commit is contained in:
6
.env.example
Normal file
6
.env.example
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# example of file for storing private and user specific environment variables, like keys or system paths
|
||||||
|
# rename it to ".env" (excluded from version control by default)
|
||||||
|
# .env is loaded by train.py automatically
|
||||||
|
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
|
||||||
|
|
||||||
|
PLINDER_MOUNT="$(pwd)/data/PLINDER"
|
||||||
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!--
|
||||||
|
Please include a summary of the change and which issue is fixed.
|
||||||
|
Please also include relevant motivation and context.
|
||||||
|
List any dependencies that are required for this change.
|
||||||
|
List all the breaking changes introduced by this pull request.
|
||||||
|
-->
|
||||||
|
|
||||||
|
Fixes #\<issue_number>
|
||||||
|
|
||||||
|
## Before submitting
|
||||||
|
|
||||||
|
- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**?
|
||||||
|
- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
|
||||||
|
- [ ] Did you list all the **breaking changes** introduced by this pull request?
|
||||||
|
- [ ] Did you **test your PR locally** with `pytest` command?
|
||||||
|
- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command?
|
||||||
|
|
||||||
|
## Did you have fun?
|
||||||
|
|
||||||
|
Make sure you had fun coding 🙃
|
||||||
15
.github/codecov.yml
vendored
Normal file
15
.github/codecov.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
coverage:
|
||||||
|
status:
|
||||||
|
# measures overall project coverage
|
||||||
|
project:
|
||||||
|
default:
|
||||||
|
threshold: 100% # how much decrease in coverage is needed to not consider success
|
||||||
|
|
||||||
|
# measures PR or single commit coverage
|
||||||
|
patch:
|
||||||
|
default:
|
||||||
|
threshold: 100% # how much decrease in coverage is needed to not consider success
|
||||||
|
|
||||||
|
|
||||||
|
# project: off
|
||||||
|
# patch: off
|
||||||
16
.github/dependabot.yml
vendored
Normal file
16
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# To get started with Dependabot version updates, you'll need to specify which
|
||||||
|
# package ecosystems to update and where the package manifests are located.
|
||||||
|
# Please see the documentation for all configuration options:
|
||||||
|
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||||
|
|
||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "pip" # See documentation for possible values
|
||||||
|
directory: "/" # Location of package manifests
|
||||||
|
schedule:
|
||||||
|
interval: "daily"
|
||||||
|
ignore:
|
||||||
|
- dependency-name: "pytorch-lightning"
|
||||||
|
update-types: ["version-update:semver-patch"]
|
||||||
|
- dependency-name: "torchmetrics"
|
||||||
|
update-types: ["version-update:semver-patch"]
|
||||||
44
.github/release-drafter.yml
vendored
Normal file
44
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
name-template: "v$RESOLVED_VERSION"
|
||||||
|
tag-template: "v$RESOLVED_VERSION"
|
||||||
|
|
||||||
|
categories:
|
||||||
|
- title: "🚀 Features"
|
||||||
|
labels:
|
||||||
|
- "feature"
|
||||||
|
- "enhancement"
|
||||||
|
- title: "🐛 Bug Fixes"
|
||||||
|
labels:
|
||||||
|
- "fix"
|
||||||
|
- "bugfix"
|
||||||
|
- "bug"
|
||||||
|
- title: "🧹 Maintenance"
|
||||||
|
labels:
|
||||||
|
- "maintenance"
|
||||||
|
- "dependencies"
|
||||||
|
- "refactoring"
|
||||||
|
- "cosmetic"
|
||||||
|
- "chore"
|
||||||
|
- title: "📝️ Documentation"
|
||||||
|
labels:
|
||||||
|
- "documentation"
|
||||||
|
- "docs"
|
||||||
|
|
||||||
|
change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
|
||||||
|
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
|
||||||
|
|
||||||
|
version-resolver:
|
||||||
|
major:
|
||||||
|
labels:
|
||||||
|
- "major"
|
||||||
|
minor:
|
||||||
|
labels:
|
||||||
|
- "minor"
|
||||||
|
patch:
|
||||||
|
labels:
|
||||||
|
- "patch"
|
||||||
|
default: patch
|
||||||
|
|
||||||
|
template: |
|
||||||
|
## Changes
|
||||||
|
|
||||||
|
$CHANGES
|
||||||
22
.github/workflows/code-quality-main.yaml
vendored
Normal file
22
.github/workflows/code-quality-main.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Same as `code-quality-pr.yaml` but triggered on commit to main branch
|
||||||
|
# and runs on all files (instead of only the changed ones)
|
||||||
|
|
||||||
|
name: Code Quality Main
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
code-quality:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
|
||||||
|
- name: Run pre-commits
|
||||||
|
uses: pre-commit/action@v2.0.3
|
||||||
36
.github/workflows/code-quality-pr.yaml
vendored
Normal file
36
.github/workflows/code-quality-pr.yaml
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# This workflow finds which files were changed, prints them,
|
||||||
|
# and runs `pre-commit` on those files.
|
||||||
|
|
||||||
|
# Inspired by the sktime library:
|
||||||
|
# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml
|
||||||
|
|
||||||
|
name: Code Quality PR
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: [main, "release/*", "dev"]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
code-quality:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
|
||||||
|
- name: Find modified files
|
||||||
|
id: file_changes
|
||||||
|
uses: trilom/file-changes-action@v1.2.4
|
||||||
|
with:
|
||||||
|
output: " "
|
||||||
|
|
||||||
|
- name: List modified files
|
||||||
|
run: echo '${{ steps.file_changes.outputs.files}}'
|
||||||
|
|
||||||
|
- name: Run pre-commits
|
||||||
|
uses: pre-commit/action@v2.0.3
|
||||||
|
with:
|
||||||
|
extra_args: --files ${{ steps.file_changes.outputs.files}}
|
||||||
27
.github/workflows/release-drafter.yml
vendored
Normal file
27
.github/workflows/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Release Drafter
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
# branches to consider in the event; optional, defaults to all
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update_release_draft:
|
||||||
|
permissions:
|
||||||
|
# write permission is required to create a github release
|
||||||
|
contents: write
|
||||||
|
# write permission is required for autolabeler
|
||||||
|
# otherwise, read permission is required at least
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||||
|
- uses: release-drafter/release-drafter@v5
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
139
.github/workflows/test.yml
vendored
Normal file
139
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
name: Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main, "release/*", "dev"]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_tests_ubuntu:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: ["ubuntu-latest"]
|
||||||
|
python-version: ["3.8", "3.9", "3.10"]
|
||||||
|
|
||||||
|
timeout-minutes: 20
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
conda env create -f environment.yml
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest
|
||||||
|
pip install sh
|
||||||
|
|
||||||
|
- name: List dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip list
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
run: |
|
||||||
|
pytest -v
|
||||||
|
|
||||||
|
run_tests_macos:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: ["macos-latest"]
|
||||||
|
python-version: ["3.8", "3.9", "3.10"]
|
||||||
|
|
||||||
|
timeout-minutes: 20
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
conda env create -f environment.yml
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest
|
||||||
|
pip install sh
|
||||||
|
|
||||||
|
- name: List dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip list
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
run: |
|
||||||
|
pytest -v
|
||||||
|
|
||||||
|
run_tests_windows:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: ["windows-latest"]
|
||||||
|
python-version: ["3.8", "3.9", "3.10"]
|
||||||
|
|
||||||
|
timeout-minutes: 20
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
conda env create -f environment.yml
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest
|
||||||
|
|
||||||
|
- name: List dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip list
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
run: |
|
||||||
|
pytest -v
|
||||||
|
|
||||||
|
# upload code coverage report
|
||||||
|
code-coverage:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
conda env create -f environment.yml
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest
|
||||||
|
pip install pytest-cov[toml]
|
||||||
|
pip install sh
|
||||||
|
|
||||||
|
- name: Run tests and collect coverage
|
||||||
|
run: pytest --cov flowdock # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v3
|
||||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
work/
|
||||||
|
.nextflow/
|
||||||
|
.nextflow.log*
|
||||||
|
*.log.*
|
||||||
|
results/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.tmp
|
||||||
|
*.swp
|
||||||
150
.pre-commit-config.yaml
Normal file
150
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
default_language_version:
|
||||||
|
python: python3
|
||||||
|
|
||||||
|
exclude: "^forks/"
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
# list of supported hooks: https://pre-commit.com/hooks.html
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-docstring-first
|
||||||
|
- id: check-yaml
|
||||||
|
- id: debug-statements
|
||||||
|
- id: detect-private-key
|
||||||
|
- id: check-executables-have-shebangs
|
||||||
|
- id: check-toml
|
||||||
|
- id: check-case-conflict
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ["--maxkb=20000"]
|
||||||
|
|
||||||
|
# python code formatting
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 25.1.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
args: [--line-length, "99"]
|
||||||
|
|
||||||
|
# python import sorting
|
||||||
|
- repo: https://github.com/PyCQA/isort
|
||||||
|
rev: 6.0.1
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args: ["--profile", "black", "--filter-files"]
|
||||||
|
|
||||||
|
# python upgrading syntax to newer version
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v3.19.1
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args: [--py38-plus]
|
||||||
|
|
||||||
|
# python docstring formatting
|
||||||
|
- repo: https://github.com/myint/docformatter
|
||||||
|
rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 # Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed
|
||||||
|
hooks:
|
||||||
|
- id: docformatter
|
||||||
|
args:
|
||||||
|
[
|
||||||
|
--in-place,
|
||||||
|
--wrap-summaries=99,
|
||||||
|
--wrap-descriptions=99,
|
||||||
|
--style=sphinx,
|
||||||
|
--black,
|
||||||
|
]
|
||||||
|
|
||||||
|
# python docstring coverage checking
|
||||||
|
- repo: https://github.com/econchick/interrogate
|
||||||
|
rev: 1.7.0 # or master if you're bold
|
||||||
|
hooks:
|
||||||
|
- id: interrogate
|
||||||
|
args:
|
||||||
|
[
|
||||||
|
--verbose,
|
||||||
|
--fail-under=80,
|
||||||
|
--ignore-init-module,
|
||||||
|
--ignore-init-method,
|
||||||
|
--ignore-module,
|
||||||
|
--ignore-nested-functions,
|
||||||
|
-vv,
|
||||||
|
]
|
||||||
|
|
||||||
|
# python check (PEP8), programming errors and code complexity
|
||||||
|
- repo: https://github.com/PyCQA/flake8
|
||||||
|
rev: 7.1.2
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
args:
|
||||||
|
[
|
||||||
|
"--extend-ignore",
|
||||||
|
"E203,E402,E501,F401,F841,RST2,RST301",
|
||||||
|
"--exclude",
|
||||||
|
"logs/*,data/*",
|
||||||
|
]
|
||||||
|
additional_dependencies: [flake8-rst-docstrings==0.3.0]
|
||||||
|
|
||||||
|
# python security linter
|
||||||
|
- repo: https://github.com/PyCQA/bandit
|
||||||
|
rev: "1.8.3"
|
||||||
|
hooks:
|
||||||
|
- id: bandit
|
||||||
|
args: ["-s", "B101"]
|
||||||
|
|
||||||
|
# yaml formatting
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||||
|
rev: v4.0.0-alpha.8
|
||||||
|
hooks:
|
||||||
|
- id: prettier
|
||||||
|
types: [yaml]
|
||||||
|
exclude: "environment.yaml"
|
||||||
|
|
||||||
|
# shell scripts linter
|
||||||
|
- repo: https://github.com/shellcheck-py/shellcheck-py
|
||||||
|
rev: v0.10.0.1
|
||||||
|
hooks:
|
||||||
|
- id: shellcheck
|
||||||
|
|
||||||
|
# md formatting
|
||||||
|
- repo: https://github.com/executablebooks/mdformat
|
||||||
|
rev: 0.7.22
|
||||||
|
hooks:
|
||||||
|
- id: mdformat
|
||||||
|
args: ["--number"]
|
||||||
|
additional_dependencies:
|
||||||
|
- mdformat-gfm
|
||||||
|
- mdformat-tables
|
||||||
|
- mdformat_frontmatter
|
||||||
|
# - mdformat-toc
|
||||||
|
# - mdformat-black
|
||||||
|
|
||||||
|
# word spelling linter
|
||||||
|
- repo: https://github.com/codespell-project/codespell
|
||||||
|
rev: v2.4.1
|
||||||
|
hooks:
|
||||||
|
- id: codespell
|
||||||
|
args:
|
||||||
|
- --skip=logs/**,data/**,*.ipynb,flowdock/data/components/constants.py,flowdock/data/components/process_mols.py,flowdock/data/components/residue_constants.py,flowdock/data/components/uff_parameters.csv,flowdock/data/components/chemical/*,flowdock/utils/data_utils.py
|
||||||
|
# - --ignore-words-list=abc,def
|
||||||
|
|
||||||
|
# jupyter notebook cell output clearing
|
||||||
|
- repo: https://github.com/kynan/nbstripout
|
||||||
|
rev: 0.8.1
|
||||||
|
hooks:
|
||||||
|
- id: nbstripout
|
||||||
|
|
||||||
|
# jupyter notebook linting
|
||||||
|
- repo: https://github.com/nbQA-dev/nbQA
|
||||||
|
rev: 1.9.1
|
||||||
|
hooks:
|
||||||
|
- id: nbqa-black
|
||||||
|
args: ["--line-length=99"]
|
||||||
|
- id: nbqa-isort
|
||||||
|
args: ["--profile=black"]
|
||||||
|
- id: nbqa-flake8
|
||||||
|
args:
|
||||||
|
[
|
||||||
|
"--extend-ignore=E203,E402,E501,F401,F841",
|
||||||
|
"--exclude=logs/*,data/*",
|
||||||
|
]
|
||||||
2
.project-root
Normal file
2
.project-root
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# this file is required for inferring the project root directory
|
||||||
|
# do not delete
|
||||||
49
Dockerfile
Normal file
49
Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-runtime
|
||||||
|
|
||||||
|
LABEL authors="BioinfoMachineLearning"
|
||||||
|
|
||||||
|
# Install system requirements
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --reinstall ca-certificates && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
git \
|
||||||
|
wget \
|
||||||
|
libxml2 \
|
||||||
|
libgl-dev \
|
||||||
|
libgl1 \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
procps && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
RUN mkdir -p /software/flowdock
|
||||||
|
WORKDIR /software/flowdock
|
||||||
|
|
||||||
|
# Clone FlowDock repository
|
||||||
|
RUN git clone https://github.com/BioinfoMachineLearning/FlowDock /software/flowdock
|
||||||
|
|
||||||
|
# Create conda environment
|
||||||
|
RUN conda env create -f /software/flowdock/environments/flowdock_environment.yaml
|
||||||
|
|
||||||
|
# Install local package and ProDy
|
||||||
|
RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \
|
||||||
|
conda activate FlowDock && \
|
||||||
|
pip install --no-cache-dir -e /software/flowdock && \
|
||||||
|
pip install --no-cache-dir --no-dependencies prody==2.4.1"
|
||||||
|
|
||||||
|
# Create checkpoints directory
|
||||||
|
RUN mkdir -p /software/flowdock/checkpoints
|
||||||
|
|
||||||
|
# Download pretrained weights
|
||||||
|
RUN wget -q https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz && \
|
||||||
|
tar -xzf flowdock_checkpoints.tar.gz && \
|
||||||
|
rm flowdock_checkpoints.tar.gz
|
||||||
|
|
||||||
|
# Activate conda environment by default
|
||||||
|
RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
||||||
|
echo "conda activate FlowDock" >> ~/.bashrc
|
||||||
|
|
||||||
|
# Default shell
|
||||||
|
SHELL ["/bin/bash", "-l", "-c"]
|
||||||
|
CMD ["/bin/bash"]
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 BioinfoMachineLearning
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
30
Makefile
Normal file
30
Makefile
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
|
||||||
|
help: ## Show help
|
||||||
|
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
|
clean: ## Clean autogenerated files
|
||||||
|
rm -rf dist
|
||||||
|
find . -type f -name "*.DS_Store" -ls -delete
|
||||||
|
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
|
||||||
|
find . | grep -E ".pytest_cache" | xargs rm -rf
|
||||||
|
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
|
||||||
|
rm -f .coverage
|
||||||
|
|
||||||
|
clean-logs: ## Clean logs
|
||||||
|
rm -rf logs/**
|
||||||
|
|
||||||
|
format: ## Run pre-commit hooks
|
||||||
|
pre-commit run -a
|
||||||
|
|
||||||
|
sync: ## Merge changes from main branch to your current branch
|
||||||
|
git pull
|
||||||
|
git pull origin main
|
||||||
|
|
||||||
|
test: ## Run not slow tests
|
||||||
|
pytest -k "not slow"
|
||||||
|
|
||||||
|
test-full: ## Run all tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
train: ## Train the model
|
||||||
|
python flowdock/train.py
|
||||||
471
README.md
Normal file
471
README.md
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
<div align="center">
|
||||||
|
|
||||||
|
# FlowDock
|
||||||
|
|
||||||
|
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
|
||||||
|
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
|
||||||
|
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
|
||||||
|
|
||||||
|
<!-- <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br> -->
|
||||||
|
|
||||||
|
[](https://arxiv.org/abs/2412.10966)
|
||||||
|
[](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)
|
||||||
|
[](https://doi.org/10.5281/zenodo.15066450)
|
||||||
|
|
||||||
|
<img src="./img/FlowDock.png" width="600">
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This is the official codebase of the paper
|
||||||
|
|
||||||
|
**FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction**
|
||||||
|
|
||||||
|
\[[arXiv](https://arxiv.org/abs/2412.10966)\] \[[ISMB](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)\] \[[Neurosnap](https://neurosnap.ai/service/FlowDock)\] \[[Tamarind Bio](https://app.tamarind.bio/tools/flowdock)\]
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
|
||||||
|
- [FlowDock](#flowdock)
|
||||||
|
- [Description](#description)
|
||||||
|
- [Contents](#contents)
|
||||||
|
- [Installation](#installation)
|
||||||
|
- [How to prepare data for `FlowDock`](#how-to-prepare-data-for-flowdock)
|
||||||
|
- [Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)](#generating-esm2-embeddings-for-each-protein-optional-cached-input-data-available-on-sharepoint)
|
||||||
|
- [Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)](#predicting-apo-protein-structures-using-esmfold-optional-cached-data-available-on-zenodo)
|
||||||
|
- [How to train `FlowDock`](#how-to-train-flowdock)
|
||||||
|
- [How to evaluate `FlowDock`](#how-to-evaluate-flowdock)
|
||||||
|
- [How to create comparative plots of benchmarking results](#how-to-create-comparative-plots-of-benchmarking-results)
|
||||||
|
- [How to predict new protein-ligand complex structures and their affinities using `FlowDock`](#how-to-predict-new-protein-ligand-complex-structures-and-their-affinities-using-flowdock)
|
||||||
|
- [For developers](#for-developers)
|
||||||
|
- [Docker](#docker)
|
||||||
|
- [Acknowledgements](#acknowledgements)
|
||||||
|
- [License](#license)
|
||||||
|
- [Citing this work](#citing-this-work)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
Install Mamba
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||||
|
bash Miniforge3-$(uname)-$(uname -m).sh # accept all terms and install to the default location
|
||||||
|
rm Miniforge3-$(uname)-$(uname -m).sh # (optionally) remove installer after using it
|
||||||
|
source ~/.bashrc # alternatively, one can restart their shell session to achieve the same result
|
||||||
|
```
|
||||||
|
|
||||||
|
Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# clone project
|
||||||
|
git clone https://github.com/BioinfoMachineLearning/FlowDock
|
||||||
|
cd FlowDock
|
||||||
|
|
||||||
|
# create conda environment
|
||||||
|
mamba env create -f environments/flowdock_environment.yaml
|
||||||
|
conda activate FlowDock # NOTE: one still needs to use `conda` to (de)activate environments
|
||||||
|
pip3 install -e . # install local project as package
|
||||||
|
pip3 install prody==2.4.1 --no-dependencies # install ProDy without NumPy dependency
|
||||||
|
```
|
||||||
|
|
||||||
|
Download checkpoints
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# pretrained NeuralPLexer weights
|
||||||
|
cd checkpoints/
|
||||||
|
wget https://zenodo.org/records/10373581/files/neuralplexermodels_downstream_datasets_predictions.zip
|
||||||
|
unzip neuralplexermodels_downstream_datasets_predictions.zip
|
||||||
|
rm neuralplexermodels_downstream_datasets_predictions.zip
|
||||||
|
cd ../
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# pretrained FlowDock weights
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz
|
||||||
|
tar -xzf flowdock_checkpoints.tar.gz
|
||||||
|
rm flowdock_checkpoints.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
Download preprocessed datasets
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# cached input data for training/validation/testing
|
||||||
|
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ER1hctIBhDVFjM7YepOI6WcBXNBm4_e6EBjFEHAM1A3y5g?download=1"
|
||||||
|
tar -xzf flowdock_data_cache.tar.gz
|
||||||
|
rm flowdock_data_cache.tar.gz
|
||||||
|
|
||||||
|
# cached data for PDBBind, Binding MOAD, DockGen, and the PDB-based van der Mers (vdM) dataset
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_pdbbind_data.tar.gz
|
||||||
|
tar -xzf flowdock_pdbbind_data.tar.gz
|
||||||
|
rm flowdock_pdbbind_data.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_moad_data.tar.gz
|
||||||
|
tar -xzf flowdock_moad_data.tar.gz
|
||||||
|
rm flowdock_moad_data.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_dockgen_data.tar.gz
|
||||||
|
tar -xzf flowdock_dockgen_data.tar.gz
|
||||||
|
rm flowdock_dockgen_data.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_pdbsidechain_data.tar.gz
|
||||||
|
tar -xzf flowdock_pdbsidechain_data.tar.gz
|
||||||
|
rm flowdock_pdbsidechain_data.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## How to prepare data for `FlowDock`
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
**NOTE:** The following steps (besides downloading PDBBind and Binding MOAD's PDB files) are only necessary if one wants to fully process each of the following datasets manually.
|
||||||
|
Otherwise, preprocessed versions of each dataset can be found on [Zenodo](https://zenodo.org/records/15066450).
|
||||||
|
|
||||||
|
Download data
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# fetch preprocessed PDBBind and Binding MOAD (as well as the optional DockGen and vdM datasets)
|
||||||
|
cd data/
|
||||||
|
|
||||||
|
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1"
|
||||||
|
wget https://zenodo.org/records/10656052/files/BindingMOAD_2020_processed.tar
|
||||||
|
wget https://zenodo.org/records/10656052/files/DockGen.tar
|
||||||
|
wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz
|
||||||
|
|
||||||
|
mv EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1 PDBBind.tar.gz
|
||||||
|
|
||||||
|
tar -xzf PDBBind.tar.gz
|
||||||
|
tar -xf BindingMOAD_2020_processed.tar
|
||||||
|
tar -xf DockGen.tar
|
||||||
|
tar -xzf pdb_2021aug02.tar.gz
|
||||||
|
|
||||||
|
rm PDBBind.tar.gz BindingMOAD_2020_processed.tar DockGen.tar pdb_2021aug02.tar.gz
|
||||||
|
|
||||||
|
mkdir pdbbind/ moad/ pdbsidechain/
|
||||||
|
mv PDBBind_processed/ pdbbind/
|
||||||
|
mv BindingMOAD_2020_processed/ moad/
|
||||||
|
mv pdb_2021aug02/ pdbsidechain/
|
||||||
|
|
||||||
|
cd ../
|
||||||
|
```
|
||||||
|
|
||||||
|
Lastly, to finetune `FlowDock` using the `PLINDER` dataset, one must first prepare this data for training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# fetch PLINDER data (NOTE: requires ~1 hour to download and ~750G of storage)
|
||||||
|
export PLINDER_MOUNT="$(pwd)/data/PLINDER"
|
||||||
|
mkdir -p "$PLINDER_MOUNT" # create the directory if it doesn't exist
|
||||||
|
|
||||||
|
plinder_download -y
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)
|
||||||
|
|
||||||
|
To generate the ESM2 embeddings for the protein inputs,
|
||||||
|
first create all the corresponding FASTA files for each protein sequence
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbbind --data_dir data/pdbbind/PDBBind_processed/ --out_file data/pdbbind/pdbbind_sequences.fasta
|
||||||
|
python flowdock/data/components/esm_embedding_preparation.py --dataset moad --data_dir data/moad/BindingMOAD_2020_processed/pdb_protein/ --out_file data/moad/moad_sequences.fasta
|
||||||
|
python flowdock/data/components/esm_embedding_preparation.py --dataset dockgen --data_dir data/DockGen/processed_files/ --out_file data/DockGen/dockgen_sequences.fasta
|
||||||
|
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbsidechain --data_dir data/pdbsidechain/pdb_2021aug02/pdb/ --out_file data/pdbsidechain/pdbsidechain_sequences.fasta
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, generate all ESM2 embeddings in batch using the ESM repository's helper script
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbbind/pdbbind_sequences.fasta data/pdbbind/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||||
|
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/moad/moad_sequences.fasta data/moad/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||||
|
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/DockGen/dockgen_sequences.fasta data/DockGen/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||||
|
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbsidechain/pdbsidechain_sequences.fasta data/pdbsidechain/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)
|
||||||
|
|
||||||
|
To generate the apo version of each protein structure,
|
||||||
|
first create ESMFold-ready versions of the combined FASTA files
|
||||||
|
prepared above by the script `esm_embedding_preparation.py`
|
||||||
|
for the PDBBind, Binding MOAD, DockGen, and PDBSidechain datasets, respectively
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbbind
|
||||||
|
python flowdock/data/components/esmfold_sequence_preparation.py dataset=moad
|
||||||
|
python flowdock/data/components/esmfold_sequence_preparation.py dataset=dockgen
|
||||||
|
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbsidechain
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, predict each apo protein structure using ESMFold's batch
|
||||||
|
inference script
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Note: Having a CUDA-enabled device available when running this script is highly recommended
|
||||||
|
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbbind/pdbbind_esmfold_sequences.fasta -o data/pdbbind/pdbbind_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||||
|
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/moad/moad_esmfold_sequences.fasta -o data/moad/moad_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||||
|
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/DockGen/dockgen_esmfold_sequences.fasta -o data/DockGen/dockgen_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||||
|
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbsidechain/pdbsidechain_esmfold_sequences.fasta -o data/pdbsidechain/pdbsidechain_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||||
|
```
|
||||||
|
|
||||||
|
Align each apo protein structure to its corresponding
|
||||||
|
holo protein structure counterpart in PDBBind, Binding MOAD, DockGen, and PDBSidechain,
|
||||||
|
taking ligand conformations into account during each alignment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbbind num_workers=1
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=moad num_workers=1
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=dockgen num_workers=1
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbsidechain num_workers=1
|
||||||
|
```
|
||||||
|
|
||||||
|
Lastly, assess the apo-to-holo alignments in terms of statistics and structural metrics
|
||||||
|
to enable runtime-dynamic dataset filtering using such information
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbbind usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=moad usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=dockgen usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||||
|
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbsidechain usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## How to train `FlowDock`
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
Train model with default configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# train on CPU
|
||||||
|
python flowdock/train.py trainer=cpu
|
||||||
|
|
||||||
|
# train on GPU
|
||||||
|
python flowdock/train.py trainer=gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/train.py experiment=experiment_name.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, reproduce `FlowDock`'s default model training run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/train.py experiment=flowdock_fm
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** You can override any parameter from command line like this
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/train.py experiment=flowdock_fm trainer.max_epochs=20 data.batch_size=8
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, override parameters to finetune `FlowDock`'s pretrained weights using a new dataset such as [PLINDER](https://www.plinder.sh/)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/train.py experiment=flowdock_fm data=plinder ckpt_path=checkpoints/esmfold_prior_paper_weights.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## How to evaluate `FlowDock`
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
To reproduce `FlowDock`'s evaluation results for structure prediction, please refer to its documentation in version `0.6.0-FlowDock` of the [PoseBench](https://github.com/BioinfoMachineLearning/PoseBench/tree/0.6.0-FlowDock?tab=readme-ov-file#how-to-run-inference-with-flowdock) GitHub repository.
|
||||||
|
|
||||||
|
To reproduce `FlowDock`'s evaluation results for binding affinity prediction using the PDBBind dataset
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/eval.py data.test_datasets=[pdbbind] ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt trainer=gpu
|
||||||
|
... # re-run two more times to gather triplicate results
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## How to create comparative plots of benchmarking results
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
Download baseline method predictions and results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# cached predictions and evaluation metrics for reproducing structure prediction paper results
|
||||||
|
wget https://zenodo.org/records/15066450/files/alphafold3_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf alphafold3_baseline_method_predictions.tar.gz
|
||||||
|
rm alphafold3_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/chai_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf chai_baseline_method_predictions.tar.gz
|
||||||
|
rm chai_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/diffdock_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf diffdock_baseline_method_predictions.tar.gz
|
||||||
|
rm diffdock_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/dynamicbind_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf dynamicbind_baseline_method_predictions.tar.gz
|
||||||
|
rm dynamicbind_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_aft_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_aft_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_aft_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_pft_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_pft_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_pft_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_chai_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_chai_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_chai_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/flowdock_hp_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf flowdock_hp_baseline_method_predictions.tar.gz
|
||||||
|
rm flowdock_hp_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/neuralplexer_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf neuralplexer_baseline_method_predictions.tar.gz
|
||||||
|
rm neuralplexer_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/vina_p2rank_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf vina_p2rank_baseline_method_predictions.tar.gz
|
||||||
|
rm vina_p2rank_baseline_method_predictions.tar.gz
|
||||||
|
|
||||||
|
wget https://zenodo.org/records/15066450/files/rfaa_baseline_method_predictions.tar.gz
|
||||||
|
tar -xzf rfaa_baseline_method_predictions.tar.gz
|
||||||
|
rm rfaa_baseline_method_predictions.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
Reproduce paper result figures
|
||||||
|
|
||||||
|
```bash
|
||||||
|
jupyter notebook notebooks/casp16_binding_affinity_prediction_results_plotting.ipynb
|
||||||
|
jupyter notebook notebooks/casp16_flowdock_vs_multicom_ligand_structure_prediction_results_plotting.ipynb
|
||||||
|
jupyter notebook notebooks/dockgen_structure_prediction_results_plotting.ipynb
|
||||||
|
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_chemical_similarity_analysis.ipynb
|
||||||
|
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_results_plotting.ipynb
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## How to predict new protein-ligand complex structures and their affinities using `FlowDock`
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
For example, generate new protein-ligand complexes for a pair of protein sequence and ligand SMILES strings such as those of the PDBBind 2020 test target `6i67`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' input_ligand='"c1cc2c(cc1O)CCCC2"' input_template=data/pdbbind/pdbbind_holo_aligned_esmfold_structures/6i67_holo_aligned_esmfold_protein.pdb sample_id='6i67' out_path='./6i67_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
Or, for example, generate new protein-ligand complexes for pairs of protein sequences and (multi-)ligand SMILES strings (delimited via `|`) such as those of the CASP15 target `T1152`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIPN' input_ligand='"CC(=O)NC1C(O)OC(CO)C(OC2OC(CO)C(OC3OC(CO)C(O)C(O)C3NC(C)=O)C(O)C2NC(C)=O)C1O"' input_template=data/test_cases/predicted_structures/T1152.pdb sample_id='T1152' out_path='./T1152_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
If you do not already have a template protein structure available for your target of interest, set `input_template=null` to instead have the sampling script predict the ESMFold structure of your provided `input_receptor` sequence before running the sampling pipeline. For more information regarding the input arguments available for sampling, please refer to the config at `configs/sample.yaml`.
|
||||||
|
|
||||||
|
**NOTE:** To optimize prediction runtimes, a `csv_path` can be specified instead of the `input_receptor`, `input_ligand`, and `input_template` CLI arguments to perform *batched* prediction for a collection of protein-ligand sequence pairs, each represented as a CSV row containing column values for `id`, `input_receptor`, `input_ligand`, and `input_template`. Additionally, disabling `visualize_sample_trajectories` may reduce storage requirements when predicting a large batch of inputs.
|
||||||
|
|
||||||
|
For instance, one can perform batched prediction as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling csv_path='./data/test_cases/prediction_inputs/flowdock_batched_inputs.csv' out_path='./T1152_batch_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=false auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## For developers
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
Set up `pre-commit` (one time only) for automatic code linting and formatting upon each `git commit`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
Manually reformat all files in the project, as desired
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pre-commit run -a
|
||||||
|
```
|
||||||
|
|
||||||
|
Update dependencies in a `*_environment.yml` file
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mamba env export > env.yaml # e.g., run this after installing new dependencies locally
|
||||||
|
diff environments/flowdock_environment.yaml env.yaml # note the differences and copy accepted changes back into e.g., `environments/flowdock_environment.yaml`
|
||||||
|
rm env.yaml # clean up temporary environment file
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Docker
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
Given that this tool has a number of dependencies, it may be easier to run it in a Docker container.
|
||||||
|
|
||||||
|
Pull from [Docker Hub](https://hub.docker.com/repository/docker/cford38/flowdock): `docker pull cford38/flowdock:latest`
|
||||||
|
|
||||||
|
Alternatively, build the Docker image locally:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build --platform linux/amd64 -t flowdock .
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, run the Docker container (and mount your local `checkpoints/` directory)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it flowdock /bin/bash
|
||||||
|
|
||||||
|
# docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it cford38/flowdock:latest /bin/bash
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
`FlowDock` builds upon the source code and data from the following projects:
|
||||||
|
|
||||||
|
- [DiffDock](https://github.com/gcorso/DiffDock)
|
||||||
|
- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
|
||||||
|
- [NeuralPLexer](https://github.com/zrqiao/NeuralPLexer)
|
||||||
|
|
||||||
|
We thank all their contributors and maintainers!
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is covered under the **MIT License**.
|
||||||
|
|
||||||
|
## Citing this work
|
||||||
|
|
||||||
|
If you use the code or data associated with this package or otherwise find this work useful, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{morehead2025flowdock,
|
||||||
|
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
|
||||||
|
author={Alex Morehead and Jianlin Cheng},
|
||||||
|
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
|
||||||
|
year=2025,
|
||||||
|
}
|
||||||
|
```
|
||||||
6
citation.bib
Normal file
6
citation.bib
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
@inproceedings{morehead2025flowdock,
|
||||||
|
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
|
||||||
|
author={Alex Morehead and Jianlin Cheng},
|
||||||
|
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
|
||||||
|
year=2025,
|
||||||
|
}
|
||||||
1
configs/__init__.py
Normal file
1
configs/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# this file is needed here to include configs when building project as a package
|
||||||
21
configs/callbacks/default.yaml
Normal file
21
configs/callbacks/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
defaults:
|
||||||
|
- ema
|
||||||
|
- last_model_checkpoint
|
||||||
|
- learning_rate_monitor
|
||||||
|
- model_checkpoint
|
||||||
|
- model_summary
|
||||||
|
- rich_progress_bar
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
last_model_checkpoint:
|
||||||
|
dirpath: ${paths.output_dir}/checkpoints
|
||||||
|
filename: "last"
|
||||||
|
monitor: null
|
||||||
|
verbose: True
|
||||||
|
auto_insert_metric_name: False
|
||||||
|
every_n_epochs: 1
|
||||||
|
save_on_train_epoch_end: True
|
||||||
|
enable_version_counter: False
|
||||||
|
|
||||||
|
model_summary:
|
||||||
|
max_depth: -1
|
||||||
15
configs/callbacks/early_stopping.yaml
Normal file
15
configs/callbacks/early_stopping.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
||||||
|
|
||||||
|
early_stopping:
|
||||||
|
_target_: lightning.pytorch.callbacks.EarlyStopping
|
||||||
|
monitor: ??? # quantity to be monitored, must be specified !!!
|
||||||
|
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
||||||
|
patience: 3 # number of checks with no improvement after which training will be stopped
|
||||||
|
verbose: False # verbosity mode
|
||||||
|
mode: "min" # "max" means higher metric value is better; can also be "min"
|
||||||
|
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
||||||
|
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
||||||
|
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
||||||
|
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
||||||
|
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
||||||
|
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
||||||
10
configs/callbacks/ema.yaml
Normal file
10
configs/callbacks/ema.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
|
||||||
|
|
||||||
|
# Maintains an exponential moving average (EMA) of model weights.
|
||||||
|
# Look at the above link for more detailed information regarding the original implementation.
|
||||||
|
ema:
|
||||||
|
_target_: flowdock.models.components.callbacks.ema.EMA
|
||||||
|
decay: 0.999
|
||||||
|
validate_original_weights: false
|
||||||
|
every_n_steps: 4
|
||||||
|
cpu_offload: false
|
||||||
21
configs/callbacks/last_model_checkpoint.yaml
Normal file
21
configs/callbacks/last_model_checkpoint.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
||||||
|
|
||||||
|
last_model_checkpoint:
|
||||||
|
# NOTE: this is a direct copy of `model_checkpoint`,
|
||||||
|
# which is necessary to work around the
|
||||||
|
# key-duplication limitations of YAML config files
|
||||||
|
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
|
||||||
|
dirpath: null # directory to save the model file
|
||||||
|
filename: null # checkpoint filename
|
||||||
|
monitor: null # name of the logged metric which determines when model is improving
|
||||||
|
verbose: False # verbosity mode
|
||||||
|
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
||||||
|
save_top_k: 1 # save k best models (determined by above metric)
|
||||||
|
mode: "min" # "max" means higher metric value is better; can also be "min"
|
||||||
|
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
||||||
|
save_weights_only: False # if True, then only the model’s weights will be saved
|
||||||
|
every_n_train_steps: null # number of training steps between checkpoints
|
||||||
|
train_time_interval: null # checkpoints are monitored at the specified time interval
|
||||||
|
every_n_epochs: null # number of epochs between checkpoints
|
||||||
|
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
||||||
|
enable_version_counter: True # enables versioning for checkpoint names
|
||||||
7
configs/callbacks/learning_rate_monitor.yaml
Normal file
7
configs/callbacks/learning_rate_monitor.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html
|
||||||
|
|
||||||
|
learning_rate_monitor:
|
||||||
|
_target_: lightning.pytorch.callbacks.LearningRateMonitor
|
||||||
|
logging_interval: null # set to `epoch` or `step` to log learning rate of all optimizers at the same interval, or set to `null` to log at individual interval according to the interval key of each scheduler
|
||||||
|
log_momentum: false # whether to also log the momentum values of the optimizer, if the optimizer has the `momentum` or `betas` attribute
|
||||||
|
log_weight_decay: false # whether to also log the weight decay values of the optimizer, if the optimizer has the `weight_decay` attribute
|
||||||
18
configs/callbacks/model_checkpoint.yaml
Normal file
18
configs/callbacks/model_checkpoint.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
||||||
|
|
||||||
|
model_checkpoint:
|
||||||
|
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
|
||||||
|
dirpath: null # directory to save the model file
|
||||||
|
filename: "best" # checkpoint filename
|
||||||
|
monitor: val_sampling/ligand_hit_score_2A_epoch # name of the logged metric which determines when model is improving
|
||||||
|
verbose: True # verbosity mode
|
||||||
|
save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
||||||
|
save_top_k: 1 # save k best models (determined by above metric)
|
||||||
|
mode: "max" # "max" means higher metric value is better; can also be "min"
|
||||||
|
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
||||||
|
save_weights_only: False # if True, then only the model’s weights will be saved
|
||||||
|
every_n_train_steps: null # number of training steps between checkpoints
|
||||||
|
train_time_interval: null # checkpoints are monitored at the specified time interval
|
||||||
|
every_n_epochs: null # number of epochs between checkpoints
|
||||||
|
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
||||||
|
enable_version_counter: False # enables versioning for checkpoint names
|
||||||
5
configs/callbacks/model_summary.yaml
Normal file
5
configs/callbacks/model_summary.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
||||||
|
|
||||||
|
model_summary:
|
||||||
|
_target_: lightning.pytorch.callbacks.RichModelSummary
|
||||||
|
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
||||||
0
configs/callbacks/none.yaml
Normal file
0
configs/callbacks/none.yaml
Normal file
4
configs/callbacks/rich_progress_bar.yaml
Normal file
4
configs/callbacks/rich_progress_bar.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
||||||
|
|
||||||
|
rich_progress_bar:
|
||||||
|
_target_: lightning.pytorch.callbacks.RichProgressBar
|
||||||
35
configs/debug/default.yaml
Normal file
35
configs/debug/default.yaml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# default debugging setup, runs 1 full epoch
|
||||||
|
# other debugging configs can inherit from this one
|
||||||
|
|
||||||
|
# overwrite task name so debugging logs are stored in separate folder
|
||||||
|
task_name: "debug"
|
||||||
|
|
||||||
|
# disable callbacks and loggers during debugging
|
||||||
|
callbacks: null
|
||||||
|
logger: null
|
||||||
|
|
||||||
|
extras:
|
||||||
|
ignore_warnings: False
|
||||||
|
enforce_tags: False
|
||||||
|
|
||||||
|
# sets level of all command line loggers to 'DEBUG'
|
||||||
|
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
||||||
|
hydra:
|
||||||
|
job_logging:
|
||||||
|
root:
|
||||||
|
level: DEBUG
|
||||||
|
|
||||||
|
# use this to also set hydra loggers to 'DEBUG'
|
||||||
|
# verbose: True
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 1
|
||||||
|
accelerator: cpu # debuggers don't like gpus
|
||||||
|
devices: 1 # debuggers don't like multiprocessing
|
||||||
|
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
||||||
|
|
||||||
|
data:
|
||||||
|
num_workers: 0 # debuggers don't like multiprocessing
|
||||||
|
pin_memory: False # disable gpu memory pin
|
||||||
9
configs/debug/fdr.yaml
Normal file
9
configs/debug/fdr.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# runs 1 train, 1 validation and 1 test step
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
fast_dev_run: true
|
||||||
12
configs/debug/limit.yaml
Normal file
12
configs/debug/limit.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# uses only 1% of the training data and 5% of validation/test data
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 3
|
||||||
|
limit_train_batches: 0.01
|
||||||
|
limit_val_batches: 0.05
|
||||||
|
limit_test_batches: 0.05
|
||||||
13
configs/debug/overfit.yaml
Normal file
13
configs/debug/overfit.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# overfits to 3 batches
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 20
|
||||||
|
overfit_batches: 3
|
||||||
|
|
||||||
|
# model ckpt and early stopping need to be disabled during overfitting
|
||||||
|
callbacks: null
|
||||||
12
configs/debug/profiler.yaml
Normal file
12
configs/debug/profiler.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# runs with execution time profiling
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 1
|
||||||
|
profiler: "simple"
|
||||||
|
# profiler: "advanced"
|
||||||
|
# profiler: "pytorch"
|
||||||
2
configs/environment/default.yaml
Normal file
2
configs/environment/default.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
defaults:
|
||||||
|
- _self_
|
||||||
1
configs/environment/lightning.yaml
Normal file
1
configs/environment/lightning.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
_target_: lightning.fabric.plugins.environments.LightningEnvironment
|
||||||
3
configs/environment/slurm.yaml
Normal file
3
configs/environment/slurm.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
_target_: lightning.fabric.plugins.environments.SLURMEnvironment
|
||||||
|
auto_requeue: true
|
||||||
|
requeue_signal: null
|
||||||
49
configs/eval.yaml
Normal file
49
configs/eval.yaml
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- data: combined # choose datamodule with `test_dataloader()` for evaluation
|
||||||
|
- model: flowdock_fm
|
||||||
|
- logger: null
|
||||||
|
- strategy: default
|
||||||
|
- trainer: default
|
||||||
|
- paths: default
|
||||||
|
- extras: default
|
||||||
|
- hydra: default
|
||||||
|
- environment: default
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
task_name: "eval"
|
||||||
|
|
||||||
|
tags: ["eval", "combined", "flowdock_fm"]
|
||||||
|
|
||||||
|
# passing checkpoint path is necessary for evaluation
|
||||||
|
ckpt_path: ???
|
||||||
|
|
||||||
|
# seed for random number generators in pytorch, numpy and python.random
|
||||||
|
seed: null
|
||||||
|
|
||||||
|
# model arguments
|
||||||
|
model:
|
||||||
|
cfg:
|
||||||
|
mol_encoder:
|
||||||
|
from_pretrained: false
|
||||||
|
protein_encoder:
|
||||||
|
from_pretrained: false
|
||||||
|
relational_reasoning:
|
||||||
|
from_pretrained: false
|
||||||
|
contact_predictor:
|
||||||
|
from_pretrained: false
|
||||||
|
score_head:
|
||||||
|
from_pretrained: false
|
||||||
|
confidence:
|
||||||
|
from_pretrained: false
|
||||||
|
affinity:
|
||||||
|
from_pretrained: false
|
||||||
|
task:
|
||||||
|
freeze_mol_encoder: true
|
||||||
|
freeze_protein_encoder: false
|
||||||
|
freeze_relational_reasoning: false
|
||||||
|
freeze_contact_predictor: false
|
||||||
|
freeze_score_head: false
|
||||||
|
freeze_confidence: true
|
||||||
|
freeze_affinity: false
|
||||||
35
configs/experiment/flowdock_fm.yaml
Normal file
35
configs/experiment/flowdock_fm.yaml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# to execute this experiment run:
|
||||||
|
# python train.py experiment=flowdock_fm
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /data: combined
|
||||||
|
- override /model: flowdock_fm
|
||||||
|
- override /callbacks: default
|
||||||
|
- override /trainer: default
|
||||||
|
|
||||||
|
# all parameters below will be merged with parameters from default configurations set above
|
||||||
|
# this allows you to overwrite only specified parameters
|
||||||
|
|
||||||
|
tags: ["flowdock_fm", "combined_dataset"]
|
||||||
|
|
||||||
|
seed: 496
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 300
|
||||||
|
check_val_every_n_epoch: 5 # NOTE: we increase this since validation steps involve full model sampling and evaluation
|
||||||
|
reload_dataloaders_every_n_epochs: 1
|
||||||
|
|
||||||
|
model:
|
||||||
|
optimizer:
|
||||||
|
lr: 2e-4
|
||||||
|
compile: false
|
||||||
|
|
||||||
|
data:
|
||||||
|
batch_size: 16
|
||||||
|
|
||||||
|
logger:
|
||||||
|
wandb:
|
||||||
|
tags: ${tags}
|
||||||
|
group: "FlowDock-FM"
|
||||||
8
configs/extras/default.yaml
Normal file
8
configs/extras/default.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# disable python warnings if they annoy you
|
||||||
|
ignore_warnings: False
|
||||||
|
|
||||||
|
# ask user for tags if none are provided in the config
|
||||||
|
enforce_tags: True
|
||||||
|
|
||||||
|
# pretty print config tree at the start of the run using Rich library
|
||||||
|
print_config: True
|
||||||
50
configs/hparams_search/combined_optuna.yaml
Normal file
50
configs/hparams_search/combined_optuna.yaml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# example hyperparameter optimization of some experiment with Optuna:
|
||||||
|
# python train.py -m hparams_search=mnist_optuna experiment=example
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /hydra/sweeper: optuna
|
||||||
|
|
||||||
|
# choose metric which will be optimized by Optuna
|
||||||
|
# make sure this is the correct name of some metric logged in lightning module!
|
||||||
|
optimized_metric: "val/loss"
|
||||||
|
|
||||||
|
# here we define Optuna hyperparameter search
|
||||||
|
# it optimizes for value returned from function with @hydra.main decorator
|
||||||
|
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
||||||
|
hydra:
|
||||||
|
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
||||||
|
|
||||||
|
sweeper:
|
||||||
|
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
||||||
|
|
||||||
|
# storage URL to persist optimization results
|
||||||
|
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
||||||
|
storage: null
|
||||||
|
|
||||||
|
# name of the study to persist optimization results
|
||||||
|
study_name: null
|
||||||
|
|
||||||
|
# number of parallel workers
|
||||||
|
n_jobs: 1
|
||||||
|
|
||||||
|
# 'minimize' or 'maximize' the objective
|
||||||
|
direction: minimize
|
||||||
|
|
||||||
|
# total number of runs that will be executed
|
||||||
|
n_trials: 20
|
||||||
|
|
||||||
|
# choose Optuna hyperparameter sampler
|
||||||
|
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
||||||
|
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
||||||
|
sampler:
|
||||||
|
_target_: optuna.samplers.TPESampler
|
||||||
|
seed: 1234
|
||||||
|
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
||||||
|
|
||||||
|
# define hyperparameter search space
|
||||||
|
params:
|
||||||
|
model.optimizer.lr: interval(0.0001, 0.1)
|
||||||
|
data.batch_size: choice(2, 4, 8, 16)
|
||||||
|
model.net.hidden_dim: choice(64, 128, 256)
|
||||||
19
configs/hydra/default.yaml
Normal file
19
configs/hydra/default.yaml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# https://hydra.cc/docs/configure_hydra/intro/
|
||||||
|
|
||||||
|
# enable color logging
|
||||||
|
defaults:
|
||||||
|
- override hydra_logging: colorlog
|
||||||
|
- override job_logging: colorlog
|
||||||
|
|
||||||
|
# output directory, generated dynamically on each run
|
||||||
|
run:
|
||||||
|
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
||||||
|
sweep:
|
||||||
|
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
||||||
|
subdir: ${hydra.job.num}
|
||||||
|
|
||||||
|
job_logging:
|
||||||
|
handlers:
|
||||||
|
file:
|
||||||
|
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
||||||
|
filename: ${hydra.runtime.output_dir}/${task_name}.log
|
||||||
0
configs/local/.gitkeep
Normal file
0
configs/local/.gitkeep
Normal file
28
configs/logger/aim.yaml
Normal file
28
configs/logger/aim.yaml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# https://aimstack.io/
|
||||||
|
|
||||||
|
# example usage in lightning module:
|
||||||
|
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
|
||||||
|
|
||||||
|
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
|
||||||
|
# `aim up`
|
||||||
|
|
||||||
|
aim:
|
||||||
|
_target_: aim.pytorch_lightning.AimLogger
|
||||||
|
repo: ${paths.root_dir} # .aim folder will be created here
|
||||||
|
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
|
||||||
|
|
||||||
|
# aim allows to group runs under experiment name
|
||||||
|
experiment: null # any string, set to "default" if not specified
|
||||||
|
|
||||||
|
train_metric_prefix: "train/"
|
||||||
|
val_metric_prefix: "val/"
|
||||||
|
test_metric_prefix: "test/"
|
||||||
|
|
||||||
|
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
|
||||||
|
system_tracking_interval: 10 # set to null to disable system metrics tracking
|
||||||
|
|
||||||
|
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
|
||||||
|
log_system_params: true
|
||||||
|
|
||||||
|
# enable/disable tracking console logs (default value is true)
|
||||||
|
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
|
||||||
12
configs/logger/comet.yaml
Normal file
12
configs/logger/comet.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# https://www.comet.ml
|
||||||
|
|
||||||
|
comet:
|
||||||
|
_target_: lightning.pytorch.loggers.comet.CometLogger
|
||||||
|
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
||||||
|
save_dir: "${paths.output_dir}"
|
||||||
|
project_name: "FlowDock_FM"
|
||||||
|
rest_api_key: null
|
||||||
|
# experiment_name: ""
|
||||||
|
experiment_key: null # set to resume experiment
|
||||||
|
offline: False
|
||||||
|
prefix: ""
|
||||||
7
configs/logger/csv.yaml
Normal file
7
configs/logger/csv.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# csv logger built in lightning
|
||||||
|
|
||||||
|
csv:
|
||||||
|
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
|
||||||
|
save_dir: "${paths.output_dir}"
|
||||||
|
name: "csv/"
|
||||||
|
prefix: ""
|
||||||
9
configs/logger/many_loggers.yaml
Normal file
9
configs/logger/many_loggers.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# train with many loggers at once
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
# - comet
|
||||||
|
- csv
|
||||||
|
# - mlflow
|
||||||
|
# - neptune
|
||||||
|
- tensorboard
|
||||||
|
- wandb
|
||||||
12
configs/logger/mlflow.yaml
Normal file
12
configs/logger/mlflow.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# https://mlflow.org
|
||||||
|
|
||||||
|
mlflow:
|
||||||
|
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
|
||||||
|
# experiment_name: ""
|
||||||
|
# run_name: ""
|
||||||
|
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
||||||
|
tags: null
|
||||||
|
# save_dir: "./mlruns"
|
||||||
|
prefix: ""
|
||||||
|
artifact_location: null
|
||||||
|
# run_id: ""
|
||||||
9
configs/logger/neptune.yaml
Normal file
9
configs/logger/neptune.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# https://neptune.ai
|
||||||
|
|
||||||
|
neptune:
|
||||||
|
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
|
||||||
|
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
||||||
|
project: username/FlowDock_FM
|
||||||
|
# name: ""
|
||||||
|
log_model_checkpoints: True
|
||||||
|
prefix: ""
|
||||||
10
configs/logger/tensorboard.yaml
Normal file
10
configs/logger/tensorboard.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# https://www.tensorflow.org/tensorboard/
|
||||||
|
|
||||||
|
tensorboard:
|
||||||
|
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
|
||||||
|
save_dir: "${paths.output_dir}/tensorboard/"
|
||||||
|
name: null
|
||||||
|
log_graph: False
|
||||||
|
default_hp_metric: True
|
||||||
|
prefix: ""
|
||||||
|
# version: ""
|
||||||
16
configs/logger/wandb.yaml
Normal file
16
configs/logger/wandb.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# https://wandb.ai
|
||||||
|
|
||||||
|
wandb:
|
||||||
|
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
||||||
|
# name: "" # name of the run (normally generated by wandb)
|
||||||
|
save_dir: "${paths.output_dir}"
|
||||||
|
offline: False
|
||||||
|
id: null # pass correct id to resume experiment!
|
||||||
|
anonymous: null # enable anonymous logging
|
||||||
|
project: "FlowDock_FM"
|
||||||
|
log_model: False # upload lightning ckpts
|
||||||
|
prefix: "" # a string to put at the beginning of metric keys
|
||||||
|
entity: "bml-lab" # set to name of your wandb team
|
||||||
|
group: ""
|
||||||
|
tags: []
|
||||||
|
job_type: ""
|
||||||
148
configs/model/flowdock_fm.yaml
Normal file
148
configs/model/flowdock_fm.yaml
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
_target_: flowdock.models.flowdock_fm_module.FlowDockFMLitModule
|
||||||
|
|
||||||
|
net:
|
||||||
|
_target_: flowdock.models.components.flowdock.FlowDock
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
_target_: torch.optim.Adam
|
||||||
|
_partial_: true
|
||||||
|
lr: 2e-4
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
|
||||||
|
_partial_: true
|
||||||
|
T_0: ${int_divide:${trainer.max_epochs},15}
|
||||||
|
T_mult: 2
|
||||||
|
eta_min: 1e-8
|
||||||
|
verbose: true
|
||||||
|
|
||||||
|
# compile model for faster training with pytorch 2.0
|
||||||
|
compile: false
|
||||||
|
|
||||||
|
# model arguments
|
||||||
|
cfg:
|
||||||
|
mol_encoder:
|
||||||
|
node_channels: 512
|
||||||
|
pair_channels: 64
|
||||||
|
n_atom_encodings: 23
|
||||||
|
n_bond_encodings: 4
|
||||||
|
n_atom_pos_encodings: 6
|
||||||
|
n_stereo_encodings: 14
|
||||||
|
n_attention_heads: 8
|
||||||
|
attention_head_dim: 8
|
||||||
|
hidden_dim: 2048
|
||||||
|
max_path_integral_length: 6
|
||||||
|
n_transformer_stacks: 8
|
||||||
|
n_heads: 8
|
||||||
|
n_patches: ${data.n_lig_patches}
|
||||||
|
checkpoint_file: ${oc.env:PROJECT_ROOT}/checkpoints/neuralplexermodels_downstream_datasets_predictions/models/complex_structure_prediction.ckpt
|
||||||
|
megamolbart: null
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
protein_encoder:
|
||||||
|
use_esm_embedding: true
|
||||||
|
esm_version: esm2_t33_650M_UR50D
|
||||||
|
esm_repr_layer: 33
|
||||||
|
residue_dim: 512
|
||||||
|
plm_embed_dim: 1280
|
||||||
|
n_aa_types: 21
|
||||||
|
atom_padding_dim: 37
|
||||||
|
n_atom_types: 4 # [C, N, O, S]
|
||||||
|
n_patches: ${data.n_protein_patches}
|
||||||
|
n_attention_heads: 8
|
||||||
|
scalar_dim: 16
|
||||||
|
point_dim: 4
|
||||||
|
pair_dim: 64
|
||||||
|
n_heads: 8
|
||||||
|
head_dim: 8
|
||||||
|
max_residue_degree: 32
|
||||||
|
n_encoder_stacks: 2
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
relational_reasoning:
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
contact_predictor:
|
||||||
|
n_stacks: 4
|
||||||
|
dropout: 0.01
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
score_head:
|
||||||
|
fiber_dim: 64
|
||||||
|
hidden_dim: 512
|
||||||
|
n_stacks: 4
|
||||||
|
max_atom_degree: 8
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
confidence:
|
||||||
|
enabled: true # whether the confidence prediction head is to be used e.g., during inference
|
||||||
|
fiber_dim: ${..score_head.fiber_dim}
|
||||||
|
hidden_dim: ${..score_head.hidden_dim}
|
||||||
|
n_stacks: ${..score_head.n_stacks}
|
||||||
|
from_pretrained: true
|
||||||
|
|
||||||
|
affinity:
|
||||||
|
enabled: true # whether the affinity prediction head is to be used e.g., during inference
|
||||||
|
fiber_dim: ${..score_head.fiber_dim}
|
||||||
|
hidden_dim: ${..score_head.hidden_dim}
|
||||||
|
n_stacks: ${..score_head.n_stacks}
|
||||||
|
ligand_pooling: sum # NOTE: must be a value in (`sum`, `mean`)
|
||||||
|
dropout: 0.01
|
||||||
|
from_pretrained: false
|
||||||
|
|
||||||
|
latent_model: default
|
||||||
|
prior_type: esmfold # NOTE: must be a value in (`gaussian`, `harmonic`, `esmfold`)
|
||||||
|
|
||||||
|
task:
|
||||||
|
pretrained: null
|
||||||
|
ligands: true
|
||||||
|
epoch_frac: ${data.epoch_frac}
|
||||||
|
label_name: null
|
||||||
|
sequence_crop_size: 1600
|
||||||
|
edge_crop_size: ${data.edge_crop_size} # NOTE: for dynamic batching via `max_n_edges`
|
||||||
|
max_masking_rate: 0.0
|
||||||
|
n_modes: 8
|
||||||
|
dropout: 0.01
|
||||||
|
# pretraining: true
|
||||||
|
freeze_mol_encoder: true
|
||||||
|
freeze_protein_encoder: false
|
||||||
|
freeze_relational_reasoning: false
|
||||||
|
freeze_contact_predictor: true
|
||||||
|
freeze_score_head: false
|
||||||
|
freeze_confidence: true
|
||||||
|
freeze_affinity: false
|
||||||
|
use_template: true
|
||||||
|
use_plddt: false
|
||||||
|
block_contact_decoding_scheme: "beam"
|
||||||
|
frozen_ligand_backbone: false
|
||||||
|
frozen_protein_backbone: false
|
||||||
|
single_protein_batch: true
|
||||||
|
contact_loss_weight: 0.2
|
||||||
|
global_score_loss_weight: 0.2
|
||||||
|
ligand_score_loss_weight: 0.1
|
||||||
|
clash_loss_weight: 10.0
|
||||||
|
local_distgeom_loss_weight: 10.0
|
||||||
|
drmsd_loss_weight: 2.0
|
||||||
|
distogram_loss_weight: 0.05
|
||||||
|
plddt_loss_weight: 1.0
|
||||||
|
affinity_loss_weight: 0.1
|
||||||
|
aux_batch_freq: 10 # NOTE: e.g., `10` means that auxiliary estimation losses will be calculated every 10th batch
|
||||||
|
global_max_sigma: 5.0
|
||||||
|
internal_max_sigma: 2.0
|
||||||
|
detect_covalent: true
|
||||||
|
# runtime configs
|
||||||
|
float32_matmul_precision: highest
|
||||||
|
# sampling configs
|
||||||
|
constrained_inpainting: false
|
||||||
|
visualize_generated_samples: true
|
||||||
|
# testing configs
|
||||||
|
loss_mode: auxiliary_estimation # NOTE: must be one of (`structure_prediction`, `auxiliary_estimation`, `auxiliary_estimation_without_structure_prediction`)
|
||||||
|
num_steps: 20
|
||||||
|
sampler: VDODE # NOTE: must be one of (`ODE`, `VDODE`)
|
||||||
|
sampler_eta: 1.0 # NOTE: this corresponds to the variance diminishing factor for the `VDODE` sampler, which offers a trade-off between exploration (1.0) and exploitation (> 1.0)
|
||||||
|
start_time: 1.0
|
||||||
|
eval_structure_prediction: false # whether to evaluate structure prediction performance (`true`) or instead only binding affinity performance (`false`) when running a test epoch
|
||||||
|
# overfitting configs
|
||||||
|
overfitting_example_name: ${data.overfitting_example_name}
|
||||||
21
configs/paths/default.yaml
Normal file
21
configs/paths/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# path to root directory
|
||||||
|
# this requires PROJECT_ROOT environment variable to exist
|
||||||
|
# you can replace it with "." if you want the root to be the current working directory
|
||||||
|
root_dir: ${oc.env:PROJECT_ROOT}
|
||||||
|
|
||||||
|
# path to data directory
|
||||||
|
data_dir: ${paths.root_dir}/data/
|
||||||
|
|
||||||
|
# path to logging directory
|
||||||
|
log_dir: ${paths.root_dir}/logs/
|
||||||
|
|
||||||
|
# path to output directory, created dynamically by hydra
|
||||||
|
# path generation pattern is specified in `configs/hydra/default.yaml`
|
||||||
|
# use it to store all files generated during the run, like ckpts and metrics
|
||||||
|
output_dir: ${hydra:runtime.output_dir}
|
||||||
|
|
||||||
|
# path to working directory
|
||||||
|
work_dir: ${hydra:runtime.cwd}
|
||||||
|
|
||||||
|
# path to the directory containing the model checkpoints
|
||||||
|
ckpt_dir: ${paths.root_dir}/checkpoints/
|
||||||
78
configs/sample.yaml
Normal file
78
configs/sample.yaml
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- data: combined # NOTE: this will not be referenced during sampling
|
||||||
|
- model: flowdock_fm
|
||||||
|
- logger: null
|
||||||
|
- strategy: default
|
||||||
|
- trainer: default
|
||||||
|
- paths: default
|
||||||
|
- extras: default
|
||||||
|
- hydra: default
|
||||||
|
- environment: default
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
task_name: "sample"
|
||||||
|
|
||||||
|
tags: ["sample", "combined", "flowdock_fm"]
|
||||||
|
|
||||||
|
# passing checkpoint path is necessary for sampling
|
||||||
|
ckpt_path: ???
|
||||||
|
|
||||||
|
# seed for random number generators in pytorch, numpy and python.random
|
||||||
|
seed: null
|
||||||
|
|
||||||
|
# sampling arguments
|
||||||
|
sampling_task: batched_structure_sampling # NOTE: must be one of (`batched_structure_sampling`)
|
||||||
|
sample_id: null # optional identifier for the sampling run
|
||||||
|
input_receptor: null # NOTE: must be either a protein sequence string (with chains separated by `|`) or a path to a PDB file (from which protein chain sequences will be parsed)
|
||||||
|
input_ligand: null # NOTE: must be either a ligand SMILES string (with chains/fragments separated by `|`) or a path to a ligand SDF file (from which ligand SMILES will be parsed)
|
||||||
|
input_template: null # path to a protein PDB file to use as a starting protein template for sampling (with an ESMFold prior model)
|
||||||
|
out_path: ??? # path to which to save the output PDB and SDF files
|
||||||
|
n_samples: 5 # number of structures to sample
|
||||||
|
chunk_size: 5 # number of structures to concurrently sample within each batch segment - NOTE: `n_samples` should be evenly divisible by `chunk_size` to produce the expected number of outputs
|
||||||
|
num_steps: 40 # number of sampling steps to perform
|
||||||
|
latent_model: null # if provided, the type of latent model to use
|
||||||
|
sampler: VDODE # sampling algorithm to use - NOTE: must be one of (`ODE`, `VDODE`)
|
||||||
|
sampler_eta: 1.0 # the variance diminishing factor for the `VDODE` sampler - NOTE: offers a trade-off between exploration (1.0) and exploitation (> 1.0)
|
||||||
|
start_time: "1.0" # time at which to start sampling
|
||||||
|
max_chain_encoding_k: -1 # maximum number of chains to encode in the chain encoding
|
||||||
|
exact_prior: false # whether to use the "ground-truth" binding site for sampling, if available
|
||||||
|
prior_type: esmfold # the type of prior to use for sampling - NOTE: must be one of (`gaussian`, `harmonic`, `esmfold`)
|
||||||
|
discard_ligand: false # whether to discard a given input ligand during sampling
|
||||||
|
discard_sdf_coords: true # whether to discard the input ligand's 3D structure during sampling, if available
|
||||||
|
detect_covalent: false # whether to detect covalent bonds between the input receptor and ligand
|
||||||
|
use_template: true # whether to use the input protein template for sampling if one is provided
|
||||||
|
separate_pdb: true # whether to save separate PDB files for each sampled structure instead of simply a single PDB file
|
||||||
|
rank_outputs_by_confidence: true # whether to rank the sampled structures by estimated confidence
|
||||||
|
plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`)
|
||||||
|
visualize_sample_trajectories: false # whether to visualize the generated samples' trajectories
|
||||||
|
auxiliary_estimation_only: false # whether to only estimate auxiliary outputs (e.g., confidence, affinity) for the input (generated) samples (potentially derived from external sources)
|
||||||
|
csv_path: null # if provided, the CSV file (with columns `id`, `input_receptor`, `input_ligand`, and `input_template`) from which to parse input receptors and ligands for sampling, overriding the `input_receptor` and `input_ligand` arguments in the process and ignoring the `input_template` for now
|
||||||
|
esmfold_chunk_size: null # chunks axial attention computation to reduce memory usage from O(L^2) to O(L); equivalent to running a for loop over chunks of of each dimension; lower values will result in lower memory usage at the cost of speed; recommended values: 128, 64, 32
|
||||||
|
|
||||||
|
# model arguments
|
||||||
|
model:
|
||||||
|
cfg:
|
||||||
|
mol_encoder:
|
||||||
|
from_pretrained: false
|
||||||
|
protein_encoder:
|
||||||
|
from_pretrained: false
|
||||||
|
relational_reasoning:
|
||||||
|
from_pretrained: false
|
||||||
|
contact_predictor:
|
||||||
|
from_pretrained: false
|
||||||
|
score_head:
|
||||||
|
from_pretrained: false
|
||||||
|
confidence:
|
||||||
|
from_pretrained: false
|
||||||
|
affinity:
|
||||||
|
from_pretrained: false
|
||||||
|
task:
|
||||||
|
freeze_mol_encoder: true
|
||||||
|
freeze_protein_encoder: false
|
||||||
|
freeze_relational_reasoning: false
|
||||||
|
freeze_contact_predictor: false
|
||||||
|
freeze_score_head: false
|
||||||
|
freeze_confidence: true
|
||||||
|
freeze_affinity: false
|
||||||
4
configs/strategy/ddp.yaml
Normal file
4
configs/strategy/ddp.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||||
|
static_graph: false
|
||||||
|
gradient_as_bucket_view: false
|
||||||
|
find_unused_parameters: true
|
||||||
5
configs/strategy/ddp_spawn.yaml
Normal file
5
configs/strategy/ddp_spawn.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||||
|
static_graph: false
|
||||||
|
gradient_as_bucket_view: false
|
||||||
|
find_unused_parameters: true
|
||||||
|
start_method: spawn
|
||||||
5
configs/strategy/deepspeed.yaml
Normal file
5
configs/strategy/deepspeed.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
|
||||||
|
stage: 2
|
||||||
|
offload_optimizer: false
|
||||||
|
allgather_bucket_size: 200_000_000
|
||||||
|
reduce_bucket_size: 200_000_000
|
||||||
2
configs/strategy/default.yaml
Normal file
2
configs/strategy/default.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
defaults:
|
||||||
|
- _self_
|
||||||
12
configs/strategy/fsdp.yaml
Normal file
12
configs/strategy/fsdp.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
_target_: lightning.pytorch.strategies.FSDPStrategy
|
||||||
|
sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
|
||||||
|
cpu_offload: null
|
||||||
|
activation_checkpointing: null
|
||||||
|
mixed_precision:
|
||||||
|
_target_: torch.distributed.fsdp.MixedPrecision
|
||||||
|
param_dtype: null
|
||||||
|
reduce_dtype: null
|
||||||
|
buffer_dtype: null
|
||||||
|
keep_low_precision_grads: false
|
||||||
|
cast_forward_inputs: false
|
||||||
|
cast_root_forward_inputs: true
|
||||||
4
configs/strategy/optimized_ddp.yaml
Normal file
4
configs/strategy/optimized_ddp.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||||
|
static_graph: true
|
||||||
|
gradient_as_bucket_view: true
|
||||||
|
find_unused_parameters: false
|
||||||
51
configs/train.yaml
Normal file
51
configs/train.yaml
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# specify here default configuration
|
||||||
|
# order of defaults determines the order in which configs override each other
|
||||||
|
defaults:
|
||||||
|
- _self_
|
||||||
|
- data: combined
|
||||||
|
- model: flowdock_fm
|
||||||
|
- callbacks: default
|
||||||
|
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
||||||
|
- strategy: default
|
||||||
|
- trainer: default
|
||||||
|
- paths: default
|
||||||
|
- extras: default
|
||||||
|
- hydra: default
|
||||||
|
- environment: default
|
||||||
|
|
||||||
|
# experiment configs allow for version control of specific hyperparameters
|
||||||
|
# e.g. best hyperparameters for given model and datamodule
|
||||||
|
- experiment: null
|
||||||
|
|
||||||
|
# config for hyperparameter optimization
|
||||||
|
- hparams_search: null
|
||||||
|
|
||||||
|
# optional local config for machine/user specific settings
|
||||||
|
# it's optional since it doesn't need to exist and is excluded from version control
|
||||||
|
- optional local: default
|
||||||
|
|
||||||
|
# debugging config (enable through command line, e.g. `python train.py debug=default)
|
||||||
|
- debug: null
|
||||||
|
|
||||||
|
# task name, determines output directory path
|
||||||
|
task_name: "train"
|
||||||
|
|
||||||
|
# tags to help you identify your experiments
|
||||||
|
# you can overwrite this in experiment configs
|
||||||
|
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
|
||||||
|
tags: ["train", "combined", "flowdock_fm"]
|
||||||
|
|
||||||
|
# set False to skip model training
|
||||||
|
train: True
|
||||||
|
|
||||||
|
# evaluate on test set, using best model weights achieved during training
|
||||||
|
# lightning chooses best weights based on the metric specified in checkpoint callback
|
||||||
|
test: False
|
||||||
|
|
||||||
|
# simply provide checkpoint path to resume training
|
||||||
|
ckpt_path: null
|
||||||
|
|
||||||
|
# seed for random number generators in pytorch, numpy and python.random
|
||||||
|
seed: null
|
||||||
5
configs/trainer/cpu.yaml
Normal file
5
configs/trainer/cpu.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
accelerator: cpu
|
||||||
|
devices: 1
|
||||||
9
configs/trainer/ddp.yaml
Normal file
9
configs/trainer/ddp.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
strategy: ddp
|
||||||
|
|
||||||
|
accelerator: gpu
|
||||||
|
devices: 4
|
||||||
|
num_nodes: 1
|
||||||
|
sync_batchnorm: True
|
||||||
7
configs/trainer/ddp_sim.yaml
Normal file
7
configs/trainer/ddp_sim.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
# simulate DDP on CPU, useful for debugging
|
||||||
|
accelerator: cpu
|
||||||
|
devices: 2
|
||||||
|
strategy: ddp_spawn
|
||||||
9
configs/trainer/ddp_spawn.yaml
Normal file
9
configs/trainer/ddp_spawn.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
strategy: ddp_spawn
|
||||||
|
|
||||||
|
accelerator: gpu
|
||||||
|
devices: 4
|
||||||
|
num_nodes: 1
|
||||||
|
sync_batchnorm: True
|
||||||
29
configs/trainer/default.yaml
Normal file
29
configs/trainer/default.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
_target_: lightning.pytorch.trainer.Trainer
|
||||||
|
|
||||||
|
default_root_dir: ${paths.output_dir}
|
||||||
|
|
||||||
|
min_epochs: 1 # prevents early stopping
|
||||||
|
max_epochs: 10
|
||||||
|
|
||||||
|
accelerator: cpu
|
||||||
|
devices: 1
|
||||||
|
|
||||||
|
# mixed precision for extra speed-up
|
||||||
|
# precision: 16
|
||||||
|
|
||||||
|
# perform a validation loop every N training epochs
|
||||||
|
check_val_every_n_epoch: 1
|
||||||
|
|
||||||
|
# set True to to ensure deterministic results
|
||||||
|
# makes training slower but gives more reproducibility than just setting seeds
|
||||||
|
deterministic: False
|
||||||
|
|
||||||
|
# determine the frequency of how often to reload the dataloaders
|
||||||
|
reload_dataloaders_every_n_epochs: 1
|
||||||
|
|
||||||
|
# if `gradient_clip_val` is not `null`, gradients will be norm-clipped during training
|
||||||
|
gradient_clip_algorithm: norm
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
|
||||||
|
# if `num_sanity_val_steps` is > 0, then specifically that many validation steps will be run during the first call to `trainer.fit`
|
||||||
|
num_sanity_val_steps: 0
|
||||||
5
configs/trainer/gpu.yaml
Normal file
5
configs/trainer/gpu.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
accelerator: gpu
|
||||||
|
devices: 1
|
||||||
5
configs/trainer/mps.yaml
Normal file
5
configs/trainer/mps.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
|
||||||
|
accelerator: mps
|
||||||
|
devices: 1
|
||||||
540
environments/flowdock_environment.yaml
Normal file
540
environments/flowdock_environment.yaml
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
name: FlowDock
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- pyg
|
||||||
|
- senyan.dev
|
||||||
|
- nvidia
|
||||||
|
- conda-forge
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=conda_forge
|
||||||
|
- _openmp_mutex=4.5=3_kmp_llvm
|
||||||
|
- aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
|
||||||
|
- aiohttp=3.11.13=py310h89163eb_0
|
||||||
|
- aiosignal=1.3.2=pyhd8ed1ab_0
|
||||||
|
- alsa-lib=1.2.13=hb9d3cd8_0
|
||||||
|
- ambertools=24.8=cuda_None_nompi_py310h834fefc_101
|
||||||
|
- annotated-types=0.7.0=pyhd8ed1ab_1
|
||||||
|
- anyio=4.8.0=pyhd8ed1ab_0
|
||||||
|
- aom=3.9.1=hac33072_0
|
||||||
|
- argon2-cffi=23.1.0=pyhd8ed1ab_1
|
||||||
|
- argon2-cffi-bindings=21.2.0=py310ha75aee5_5
|
||||||
|
- arpack=3.9.1=nompi_hf03ea27_102
|
||||||
|
- arrow=1.3.0=pyhd8ed1ab_1
|
||||||
|
- asttokens=3.0.0=pyhd8ed1ab_1
|
||||||
|
- async-lru=2.0.4=pyhd8ed1ab_1
|
||||||
|
- async-timeout=5.0.1=pyhd8ed1ab_1
|
||||||
|
- attr=2.5.1=h166bdaf_1
|
||||||
|
- attrs=25.3.0=pyh71513ae_0
|
||||||
|
- babel=2.17.0=pyhd8ed1ab_0
|
||||||
|
- beautifulsoup4=4.13.3=pyha770c72_0
|
||||||
|
- blas=2.116=mkl
|
||||||
|
- blas-devel=3.9.0=16_linux64_mkl
|
||||||
|
- bleach=6.2.0=pyh29332c3_4
|
||||||
|
- bleach-with-css=6.2.0=h82add2a_4
|
||||||
|
- blosc=1.21.6=he440d0b_1
|
||||||
|
- brotli=1.1.0=hb9d3cd8_2
|
||||||
|
- brotli-bin=1.1.0=hb9d3cd8_2
|
||||||
|
- brotli-python=1.1.0=py310hf71b8c6_2
|
||||||
|
- bson=0.5.10=pyhd8ed1ab_0
|
||||||
|
- bzip2=1.0.8=h4bc722e_7
|
||||||
|
- c-ares=1.34.4=hb9d3cd8_0
|
||||||
|
- c-blosc2=2.15.2=h3122c55_1
|
||||||
|
- ca-certificates=2025.1.31=hbcca054_0
|
||||||
|
- cached-property=1.5.2=hd8ed1ab_1
|
||||||
|
- cached_property=1.5.2=pyha770c72_1
|
||||||
|
- cachetools=5.5.2=pyhd8ed1ab_0
|
||||||
|
- cairo=1.18.4=h3394656_0
|
||||||
|
- certifi=2025.1.31=pyhd8ed1ab_0
|
||||||
|
- cffi=1.17.1=py310h8deb56e_0
|
||||||
|
- chardet=5.2.0=pyhd8ed1ab_3
|
||||||
|
- charset-normalizer=3.4.1=pyhd8ed1ab_0
|
||||||
|
- colorama=0.4.6=pyhd8ed1ab_1
|
||||||
|
- comm=0.2.2=pyhd8ed1ab_1
|
||||||
|
- contourpy=1.3.1=py310h3788b33_0
|
||||||
|
- cpython=3.10.16=py310hd8ed1ab_1
|
||||||
|
- cuda-cudart=11.8.89=0
|
||||||
|
- cuda-cupti=11.8.87=0
|
||||||
|
- cuda-libraries=11.8.0=0
|
||||||
|
- cuda-nvrtc=11.8.89=0
|
||||||
|
- cuda-nvtx=11.8.86=0
|
||||||
|
- cuda-runtime=11.8.0=0
|
||||||
|
- cuda-version=11.8=h70ddcb2_3
|
||||||
|
- cudatoolkit=11.8.0=h4ba93d1_13
|
||||||
|
- cudatoolkit-dev=11.8.0=h1fa729e_6
|
||||||
|
- cycler=0.12.1=pyhd8ed1ab_1
|
||||||
|
- cyrus-sasl=2.1.27=h54b06d7_7
|
||||||
|
- dav1d=1.2.1=hd590300_0
|
||||||
|
- dbus=1.13.6=h5008d03_3
|
||||||
|
- debugpy=1.8.13=py310hf71b8c6_0
|
||||||
|
- decorator=5.2.1=pyhd8ed1ab_0
|
||||||
|
- defusedxml=0.7.1=pyhd8ed1ab_0
|
||||||
|
- deprecated=1.2.18=pyhd8ed1ab_0
|
||||||
|
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
||||||
|
- expat=2.6.4=h5888daf_0
|
||||||
|
- ffmpeg=7.1.1=gpl_h24e5c1d_701
|
||||||
|
- fftw=3.3.10=nompi_hf1063bd_110
|
||||||
|
- filelock=3.18.0=pyhd8ed1ab_0
|
||||||
|
- flexcache=0.3=pyhd8ed1ab_1
|
||||||
|
- flexparser=0.4=pyhd8ed1ab_1
|
||||||
|
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
||||||
|
- font-ttf-inconsolata=3.000=h77eed37_0
|
||||||
|
- font-ttf-source-code-pro=2.038=h77eed37_0
|
||||||
|
- font-ttf-ubuntu=0.83=h77eed37_3
|
||||||
|
- fontconfig=2.15.0=h7e30c49_1
|
||||||
|
- fonts-conda-ecosystem=1=0
|
||||||
|
- fonts-conda-forge=1=0
|
||||||
|
- fonttools=4.56.0=py310h89163eb_0
|
||||||
|
- fqdn=1.5.1=pyhd8ed1ab_1
|
||||||
|
- freetype=2.13.3=h48d6fc4_0
|
||||||
|
- freetype-py=2.3.0=pyhd8ed1ab_0
|
||||||
|
- fribidi=1.0.10=h36c2ea0_0
|
||||||
|
- frozenlist=1.5.0=py310h89163eb_1
|
||||||
|
- fsspec=2025.3.0=pyhd8ed1ab_0
|
||||||
|
- gdk-pixbuf=2.42.12=hb9ae30d_0
|
||||||
|
- gettext=0.23.1=h5888daf_0
|
||||||
|
- gettext-tools=0.23.1=h5888daf_0
|
||||||
|
- gmp=6.3.0=hac33072_2
|
||||||
|
- gmpy2=2.1.5=py310he8512ff_3
|
||||||
|
- graphite2=1.3.13=h59595ed_1003
|
||||||
|
- greenlet=3.1.1=py310hf71b8c6_1
|
||||||
|
- h11=0.14.0=pyhd8ed1ab_1
|
||||||
|
- h2=4.2.0=pyhd8ed1ab_0
|
||||||
|
- harfbuzz=10.4.0=h76408a6_0
|
||||||
|
- hdf4=4.2.15=h2a13503_7
|
||||||
|
- hdf5=1.14.4=nompi_h2d575fe_105
|
||||||
|
- hpack=4.1.0=pyhd8ed1ab_0
|
||||||
|
- httpcore=1.0.7=pyh29332c3_1
|
||||||
|
- httpx=0.28.1=pyhd8ed1ab_0
|
||||||
|
- hyperframe=6.1.0=pyhd8ed1ab_0
|
||||||
|
- icu=75.1=he02047a_0
|
||||||
|
- idna=3.10=pyhd8ed1ab_1
|
||||||
|
- importlib-metadata=8.6.1=pyha770c72_0
|
||||||
|
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
||||||
|
- ipykernel=6.29.5=pyh3099207_0
|
||||||
|
- ipython=8.34.0=pyh907856f_0
|
||||||
|
- isoduration=20.11.0=pyhd8ed1ab_1
|
||||||
|
- jack=1.9.22=h7c63dc7_2
|
||||||
|
- jedi=0.19.2=pyhd8ed1ab_1
|
||||||
|
- jinja2=3.1.6=pyhd8ed1ab_0
|
||||||
|
- joblib=1.4.2=pyhd8ed1ab_1
|
||||||
|
- json5=0.10.0=pyhd8ed1ab_1
|
||||||
|
- jsonpointer=3.0.0=py310hff52083_1
|
||||||
|
- jsonschema=4.23.0=pyhd8ed1ab_1
|
||||||
|
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
|
||||||
|
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
|
||||||
|
- jupyter-lsp=2.2.5=pyhd8ed1ab_1
|
||||||
|
- jupyter_client=8.6.3=pyhd8ed1ab_1
|
||||||
|
- jupyter_core=5.7.2=pyh31011fe_1
|
||||||
|
- jupyter_events=0.12.0=pyh29332c3_0
|
||||||
|
- jupyter_server=2.15.0=pyhd8ed1ab_0
|
||||||
|
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
|
||||||
|
- jupyterlab=4.3.6=pyhd8ed1ab_0
|
||||||
|
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
|
||||||
|
- jupyterlab_server=2.27.3=pyhd8ed1ab_1
|
||||||
|
- jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
|
||||||
|
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
||||||
|
- keyutils=1.6.1=h166bdaf_0
|
||||||
|
- kiwisolver=1.4.7=py310h3788b33_0
|
||||||
|
- krb5=1.21.3=h659f571_0
|
||||||
|
- lame=3.100=h166bdaf_1003
|
||||||
|
- lcms2=2.17=h717163a_0
|
||||||
|
- ld_impl_linux-64=2.43=h712a8e2_4
|
||||||
|
- lerc=4.0.0=h27087fc_0
|
||||||
|
- level-zero=1.21.2=h84d6215_0
|
||||||
|
- libabseil=20250127.0=cxx17_hbbce691_0
|
||||||
|
- libaec=1.1.3=h59595ed_0
|
||||||
|
- libasprintf=0.23.1=h8e693c7_0
|
||||||
|
- libasprintf-devel=0.23.1=h8e693c7_0
|
||||||
|
- libass=0.17.3=hba53ac1_1
|
||||||
|
- libblas=3.9.0=16_linux64_mkl
|
||||||
|
- libboost=1.86.0=h6c02f8c_3
|
||||||
|
- libboost-python=1.86.0=py310ha2bacc8_3
|
||||||
|
- libbrotlicommon=1.1.0=hb9d3cd8_2
|
||||||
|
- libbrotlidec=1.1.0=hb9d3cd8_2
|
||||||
|
- libbrotlienc=1.1.0=hb9d3cd8_2
|
||||||
|
- libcap=2.75=h39aace5_0
|
||||||
|
- libcblas=3.9.0=16_linux64_mkl
|
||||||
|
- libcublas=11.11.3.6=0
|
||||||
|
- libcufft=10.9.0.58=0
|
||||||
|
- libcufile=1.9.1.3=0
|
||||||
|
- libcurand=10.3.5.147=0
|
||||||
|
- libcurl=8.12.1=h332b0f4_0
|
||||||
|
- libcusolver=11.4.1.48=0
|
||||||
|
- libcusparse=11.7.5.86=0
|
||||||
|
- libdb=6.2.32=h9c3ff4c_0
|
||||||
|
- libdeflate=1.23=h4ddbbb0_0
|
||||||
|
- libdrm=2.4.124=hb9d3cd8_0
|
||||||
|
- libedit=3.1.20250104=pl5321h7949ede_0
|
||||||
|
- libegl=1.7.0=ha4b6fd6_2
|
||||||
|
- libev=4.33=hd590300_2
|
||||||
|
- libexpat=2.6.4=h5888daf_0
|
||||||
|
- libffi=3.4.6=h2dba641_0
|
||||||
|
- libflac=1.4.3=h59595ed_0
|
||||||
|
- libgcc=14.2.0=h767d61c_2
|
||||||
|
- libgcc-ng=14.2.0=h69a702a_2
|
||||||
|
- libgcrypt-lib=1.11.0=hb9d3cd8_2
|
||||||
|
- libgettextpo=0.23.1=h5888daf_0
|
||||||
|
- libgettextpo-devel=0.23.1=h5888daf_0
|
||||||
|
- libgfortran=14.2.0=h69a702a_2
|
||||||
|
- libgfortran-ng=14.2.0=h69a702a_2
|
||||||
|
- libgfortran5=14.2.0=hf1ad2bd_2
|
||||||
|
- libgl=1.7.0=ha4b6fd6_2
|
||||||
|
- libglib=2.82.2=h2ff4ddf_1
|
||||||
|
- libglvnd=1.7.0=ha4b6fd6_2
|
||||||
|
- libglx=1.7.0=ha4b6fd6_2
|
||||||
|
- libgomp=14.2.0=h767d61c_2
|
||||||
|
- libgpg-error=1.51=hbd13f7d_1
|
||||||
|
- libhwloc=2.11.2=default_h0d58e46_1001
|
||||||
|
- libiconv=1.18=h4ce23a2_1
|
||||||
|
- libjpeg-turbo=3.0.0=hd590300_1
|
||||||
|
- liblapack=3.9.0=16_linux64_mkl
|
||||||
|
- liblapacke=3.9.0=16_linux64_mkl
|
||||||
|
- liblzma=5.6.4=hb9d3cd8_0
|
||||||
|
- libnetcdf=4.9.2=nompi_h5ddbaa4_116
|
||||||
|
- libnghttp2=1.64.0=h161d5f1_0
|
||||||
|
- libnpp=11.8.0.86=0
|
||||||
|
- libnsl=2.0.1=hd590300_0
|
||||||
|
- libntlm=1.8=hb9d3cd8_0
|
||||||
|
- libnvjpeg=11.9.0.86=0
|
||||||
|
- libogg=1.3.5=h4ab18f5_0
|
||||||
|
- libopenvino=2025.0.0=hdc3f47d_3
|
||||||
|
- libopenvino-auto-batch-plugin=2025.0.0=h4d9b6c2_3
|
||||||
|
- libopenvino-auto-plugin=2025.0.0=h4d9b6c2_3
|
||||||
|
- libopenvino-hetero-plugin=2025.0.0=h981d57b_3
|
||||||
|
- libopenvino-intel-cpu-plugin=2025.0.0=hdc3f47d_3
|
||||||
|
- libopenvino-intel-gpu-plugin=2025.0.0=hdc3f47d_3
|
||||||
|
- libopenvino-intel-npu-plugin=2025.0.0=hdc3f47d_3
|
||||||
|
- libopenvino-ir-frontend=2025.0.0=h981d57b_3
|
||||||
|
- libopenvino-onnx-frontend=2025.0.0=h0e684df_3
|
||||||
|
- libopenvino-paddle-frontend=2025.0.0=h0e684df_3
|
||||||
|
- libopenvino-pytorch-frontend=2025.0.0=h5888daf_3
|
||||||
|
- libopenvino-tensorflow-frontend=2025.0.0=h684f15b_3
|
||||||
|
- libopenvino-tensorflow-lite-frontend=2025.0.0=h5888daf_3
|
||||||
|
- libopus=1.3.1=h7f98852_1
|
||||||
|
- libpciaccess=0.18=hd590300_0
|
||||||
|
- libpng=1.6.47=h943b412_0
|
||||||
|
- libpq=17.4=h27ae623_0
|
||||||
|
- libprotobuf=5.29.3=h501fc15_0
|
||||||
|
- librdkit=2024.09.6=h84b0b3c_0
|
||||||
|
- librsvg=2.58.4=h49af25d_2
|
||||||
|
- libsndfile=1.2.2=hc60ed4a_1
|
||||||
|
- libsodium=1.0.20=h4ab18f5_0
|
||||||
|
- libsqlite=3.49.1=hee588c1_1
|
||||||
|
- libssh2=1.11.1=hf672d98_0
|
||||||
|
- libstdcxx=14.2.0=h8f9b012_2
|
||||||
|
- libstdcxx-ng=14.2.0=h4852527_2
|
||||||
|
- libsystemd0=257.4=h4e0b6ca_1
|
||||||
|
- libtiff=4.7.0=hd9ff511_3
|
||||||
|
- libudev1=257.4=hbe16f8c_1
|
||||||
|
- libunwind=1.6.2=h9c3ff4c_0
|
||||||
|
- liburing=2.9=h84d6215_0
|
||||||
|
- libusb=1.0.27=hb9d3cd8_101
|
||||||
|
- libuuid=2.38.1=h0b41bf4_0
|
||||||
|
- libva=2.22.0=h4f16b4b_2
|
||||||
|
- libvorbis=1.3.7=h9c3ff4c_0
|
||||||
|
- libvpx=1.14.1=hac33072_0
|
||||||
|
- libwebp-base=1.5.0=h851e524_0
|
||||||
|
- libxcb=1.17.0=h8a09558_0
|
||||||
|
- libxcrypt=4.4.36=hd590300_1
|
||||||
|
- libxkbcommon=1.8.1=hc4a0caf_0
|
||||||
|
- libxml2=2.13.6=h8d12d68_0
|
||||||
|
- libxslt=1.1.39=h76b75d6_0
|
||||||
|
- libzip=1.11.2=h6991a6a_0
|
||||||
|
- libzlib=1.3.1=hb9d3cd8_2
|
||||||
|
- llvm-openmp=15.0.7=h0cdce71_0
|
||||||
|
- lxml=5.3.1=py310h6ee67d5_0
|
||||||
|
- lz4-c=1.10.0=h5888daf_1
|
||||||
|
- markupsafe=3.0.2=py310h89163eb_1
|
||||||
|
- matplotlib-base=3.10.1=py310h68603db_0
|
||||||
|
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
||||||
|
- mda-xdrlib=0.2.0=pyhd8ed1ab_1
|
||||||
|
- mdtraj=1.10.3=py310h4cdbd58_0
|
||||||
|
- mendeleev=0.20.1=pymin39_ha308f57_3
|
||||||
|
- mistune=3.1.2=pyhd8ed1ab_0
|
||||||
|
- mkl=2022.1.0=h84fe81f_915
|
||||||
|
- mkl-devel=2022.1.0=ha770c72_916
|
||||||
|
- mkl-include=2022.1.0=h84fe81f_915
|
||||||
|
- mpc=1.3.1=h24ddda3_1
|
||||||
|
- mpfr=4.2.1=h90cbb55_3
|
||||||
|
- mpg123=1.32.9=hc50e24c_0
|
||||||
|
- mpmath=1.3.0=pyhd8ed1ab_1
|
||||||
|
- multidict=6.1.0=py310h89163eb_2
|
||||||
|
- munkres=1.1.4=pyh9f0ad1d_0
|
||||||
|
- nbclient=0.10.2=pyhd8ed1ab_0
|
||||||
|
- nbconvert-core=7.16.6=pyh29332c3_0
|
||||||
|
- nbformat=5.10.4=pyhd8ed1ab_1
|
||||||
|
- ncurses=6.5=h2d0b736_3
|
||||||
|
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
||||||
|
- netcdf-fortran=4.6.1=nompi_ha5d1325_108
|
||||||
|
- networkx=3.4.2=pyh267e887_2
|
||||||
|
- notebook=7.3.3=pyhd8ed1ab_0
|
||||||
|
- notebook-shim=0.2.4=pyhd8ed1ab_1
|
||||||
|
- numexpr=2.7.3=py310hb5077e9_1
|
||||||
|
- ocl-icd=2.3.2=hb9d3cd8_2
|
||||||
|
- ocl-icd-system=1.0.0=1
|
||||||
|
- opencl-headers=2024.10.24=h5888daf_0
|
||||||
|
- openff-amber-ff-ports=0.0.4=pyhca7485f_0
|
||||||
|
- openff-forcefields=2024.09.0=pyhff2d567_0
|
||||||
|
- openff-interchange=0.4.2=pyhd8ed1ab_2
|
||||||
|
- openff-interchange-base=0.4.2=pyhd8ed1ab_2
|
||||||
|
- openff-toolkit=0.16.8=pyhd8ed1ab_2
|
||||||
|
- openff-toolkit-base=0.16.8=pyhd8ed1ab_2
|
||||||
|
- openff-units=0.3.0=pyhd8ed1ab_1
|
||||||
|
- openff-utilities=0.1.15=pyhd8ed1ab_0
|
||||||
|
- openh264=2.6.0=hc22cd8d_0
|
||||||
|
- openjpeg=2.5.3=h5fbd93e_0
|
||||||
|
- openldap=2.6.9=he970967_0
|
||||||
|
- openmm=8.2.0=py310h30bdd6a_2
|
||||||
|
- openmmforcefields=0.14.2=pyhd8ed1ab_0
|
||||||
|
- openssl=3.4.1=h7b32b05_0
|
||||||
|
- overrides=7.7.0=pyhd8ed1ab_1
|
||||||
|
- packaging=24.2=pyhd8ed1ab_2
|
||||||
|
- panedr=0.8.0=pyhd8ed1ab_1
|
||||||
|
- pango=1.56.2=h861ebed_0
|
||||||
|
- parmed=4.3.0=py310h78e4988_1
|
||||||
|
- parso=0.8.4=pyhd8ed1ab_1
|
||||||
|
- pcre2=10.44=hba22ea6_2
|
||||||
|
- pdbfixer=1.11=pyhd8ed1ab_0
|
||||||
|
- perl=5.32.1=7_hd590300_perl5
|
||||||
|
- pexpect=4.9.0=pyhd8ed1ab_1
|
||||||
|
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
||||||
|
- pillow=11.1.0=py310h7e6dc6c_0
|
||||||
|
- pint=0.24.4=pyhd8ed1ab_1
|
||||||
|
- pip=25.0.1=pyh8b19718_0
|
||||||
|
- pixman=0.44.2=h29eaf8c_0
|
||||||
|
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
|
||||||
|
- platformdirs=4.3.6=pyhd8ed1ab_1
|
||||||
|
- prometheus_client=0.21.1=pyhd8ed1ab_0
|
||||||
|
- prompt-toolkit=3.0.50=pyha770c72_0
|
||||||
|
- psutil=7.0.0=py310ha75aee5_0
|
||||||
|
- pthread-stubs=0.4=hb9d3cd8_1002
|
||||||
|
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
||||||
|
- pugixml=1.15=h3f63f65_0
|
||||||
|
- pulseaudio-client=17.0=hb77b528_0
|
||||||
|
- pure_eval=0.2.3=pyhd8ed1ab_1
|
||||||
|
- py-cpuinfo=9.0.0=pyhd8ed1ab_1
|
||||||
|
- pycairo=1.27.0=py310h25ff670_0
|
||||||
|
- pycparser=2.22=pyh29332c3_1
|
||||||
|
- pydantic=2.10.6=pyh3cfb1c2_0
|
||||||
|
- pydantic-core=2.27.2=py310h505e2c1_0
|
||||||
|
- pyedr=0.8.0=pyhd8ed1ab_1
|
||||||
|
- pyfiglet=0.8.post1=py_0
|
||||||
|
- pyg=2.5.2=py310_torch_2.2.0_cu118
|
||||||
|
- pygments=2.19.1=pyhd8ed1ab_0
|
||||||
|
- pyparsing=3.2.1=pyhd8ed1ab_0
|
||||||
|
- pysocks=1.7.1=pyha55dd90_7
|
||||||
|
- pytables=3.10.1=py310h431dcdc_4
|
||||||
|
- python=3.10.16=he725a3c_1_cpython
|
||||||
|
- python-constraint=1.4.0=pyhff2d567_1
|
||||||
|
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
||||||
|
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
|
||||||
|
- python-tzdata=2025.1=pyhd8ed1ab_0
|
||||||
|
- python_abi=3.10=5_cp310
|
||||||
|
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
|
||||||
|
- pytorch-cuda=11.8=h7e8668a_6
|
||||||
|
- pytorch-mutex=1.0=cuda
|
||||||
|
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
|
||||||
|
- pytz=2025.1=pyhd8ed1ab_0
|
||||||
|
- pyyaml=6.0.2=py310h89163eb_2
|
||||||
|
- pyzmq=26.3.0=py310h71f11fc_0
|
||||||
|
- qhull=2020.2=h434a139_5
|
||||||
|
- rdkit=2024.09.6=py310hcd13295_0
|
||||||
|
- readline=8.2=h8c095d6_2
|
||||||
|
- referencing=0.36.2=pyh29332c3_0
|
||||||
|
- reportlab=4.3.1=py310ha75aee5_0
|
||||||
|
- requests=2.32.3=pyhd8ed1ab_1
|
||||||
|
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
|
||||||
|
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
||||||
|
- rlpycairo=0.2.0=pyhd8ed1ab_0
|
||||||
|
- rpds-py=0.23.1=py310hc1293b2_0
|
||||||
|
- scikit-learn=1.6.1=py310h27f47ee_0
|
||||||
|
- scipy=1.15.2=py310h1d65ade_0
|
||||||
|
- sdl2=2.32.50=h9b8e6db_1
|
||||||
|
- sdl3=3.2.8=h3083f51_0
|
||||||
|
- send2trash=1.8.3=pyh0d859eb_1
|
||||||
|
- setuptools=75.8.2=pyhff2d567_0
|
||||||
|
- six=1.17.0=pyhd8ed1ab_0
|
||||||
|
- smirnoff99frosst=1.1.0=pyh44b312d_0
|
||||||
|
- snappy=1.2.1=h8bd8927_1
|
||||||
|
- sniffio=1.3.1=pyhd8ed1ab_1
|
||||||
|
- sqlalchemy=2.0.39=py310ha75aee5_1
|
||||||
|
- stack_data=0.6.3=pyhd8ed1ab_1
|
||||||
|
- svt-av1=3.0.1=h5888daf_0
|
||||||
|
- sympy=1.13.3=pyh2585a3b_105
|
||||||
|
- sysroot_linux-64=2.17=h0157908_18
|
||||||
|
- tbb=2021.13.0=hceb3a55_1
|
||||||
|
- terminado=0.18.1=pyh0d859eb_0
|
||||||
|
- threadpoolctl=3.6.0=pyhecae5ae_0
|
||||||
|
- tinycss2=1.4.0=pyhd8ed1ab_0
|
||||||
|
- tinydb=4.8.2=pyhd8ed1ab_1
|
||||||
|
- tk=8.6.13=noxft_h4845f30_101
|
||||||
|
- tomli=2.2.1=pyhd8ed1ab_1
|
||||||
|
- torchaudio=2.2.1=py310_cu118
|
||||||
|
- torchtriton=2.2.0=py310
|
||||||
|
- torchvision=0.17.1=py310_cu118
|
||||||
|
- tornado=6.4.2=py310ha75aee5_0
|
||||||
|
- tqdm=4.67.1=pyhd8ed1ab_1
|
||||||
|
- traitlets=5.14.3=pyhd8ed1ab_1
|
||||||
|
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
|
||||||
|
- typing-extensions=4.12.2=hd8ed1ab_1
|
||||||
|
- typing_extensions=4.12.2=pyha770c72_1
|
||||||
|
- typing_utils=0.1.0=pyhd8ed1ab_1
|
||||||
|
- tzdata=2025a=h78e105d_0
|
||||||
|
- unicodedata2=16.0.0=py310ha75aee5_0
|
||||||
|
- uri-template=1.3.0=pyhd8ed1ab_1
|
||||||
|
- urllib3=2.3.0=pyhd8ed1ab_0
|
||||||
|
- validators=0.34.0=pyhd8ed1ab_1
|
||||||
|
- wayland=1.23.1=h3e06ad9_0
|
||||||
|
- wayland-protocols=1.41=hd8ed1ab_0
|
||||||
|
- wcwidth=0.2.13=pyhd8ed1ab_1
|
||||||
|
- webcolors=24.11.1=pyhd8ed1ab_0
|
||||||
|
- webencodings=0.5.1=pyhd8ed1ab_3
|
||||||
|
- websocket-client=1.8.0=pyhd8ed1ab_1
|
||||||
|
- wheel=0.45.1=pyhd8ed1ab_1
|
||||||
|
- wrapt=1.17.2=py310ha75aee5_0
|
||||||
|
- x264=1!164.3095=h166bdaf_2
|
||||||
|
- x265=3.5=h924138e_3
|
||||||
|
- xkeyboard-config=2.43=hb9d3cd8_0
|
||||||
|
- xmltodict=0.14.2=pyhd8ed1ab_1
|
||||||
|
- xorg-libice=1.1.2=hb9d3cd8_0
|
||||||
|
- xorg-libsm=1.2.6=he73a12e_0
|
||||||
|
- xorg-libx11=1.8.12=h4f16b4b_0
|
||||||
|
- xorg-libxau=1.0.12=hb9d3cd8_0
|
||||||
|
- xorg-libxcursor=1.2.3=hb9d3cd8_0
|
||||||
|
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
|
||||||
|
- xorg-libxext=1.3.6=hb9d3cd8_0
|
||||||
|
- xorg-libxfixes=6.0.1=hb9d3cd8_0
|
||||||
|
- xorg-libxrender=0.9.12=hb9d3cd8_0
|
||||||
|
- xorg-libxscrnsaver=1.2.4=hb9d3cd8_0
|
||||||
|
- xorg-libxt=1.3.1=hb9d3cd8_0
|
||||||
|
- yaml=0.2.5=h7f98852_2
|
||||||
|
- yarl=1.18.3=py310h89163eb_1
|
||||||
|
- zeromq=4.3.5=h3b0a872_7
|
||||||
|
- zipp=3.21.0=pyhd8ed1ab_1
|
||||||
|
- zlib=1.3.1=hb9d3cd8_2
|
||||||
|
- zlib-ng=2.2.4=h7955e40_0
|
||||||
|
- zstandard=0.23.0=py310ha75aee5_1
|
||||||
|
- zstd=1.5.7=hb8e6e7a_1
|
||||||
|
- pip:
|
||||||
|
- absl-py==2.1.0
|
||||||
|
- alembic==1.15.1
|
||||||
|
- amberutils==21.0
|
||||||
|
- antlr4-python3-runtime==4.9.3
|
||||||
|
- autopage==0.5.2
|
||||||
|
- beartype==0.20.0
|
||||||
|
- biopandas==0.5.1
|
||||||
|
- biopython==1.79
|
||||||
|
- biotite==1.1.0
|
||||||
|
- biotraj==1.2.2
|
||||||
|
- cfgv==3.4.0
|
||||||
|
- cftime==1.6.4.post1
|
||||||
|
- click==8.1.8
|
||||||
|
- cliff==4.9.1
|
||||||
|
- cloudpathlib==0.21.0
|
||||||
|
- cmaes==0.11.1
|
||||||
|
- cmd2==2.5.11
|
||||||
|
- colorlog==6.9.0
|
||||||
|
- distlib==0.3.9
|
||||||
|
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
|
||||||
|
- dm-tree==0.1.9
|
||||||
|
- docker-pycreds==0.4.0
|
||||||
|
- duckdb==1.2.1
|
||||||
|
- edgembar==3.0
|
||||||
|
- einops==0.8.1
|
||||||
|
- eval-type-backport==0.2.2
|
||||||
|
- executing==2.2.0
|
||||||
|
- fair-esm==2.0.0
|
||||||
|
- fairscale==0.4.13
|
||||||
|
- fastcore==1.7.29
|
||||||
|
- future==1.0.0
|
||||||
|
- fvcore==0.1.5.post20221221
|
||||||
|
- gcsfs==2025.3.0
|
||||||
|
- gemmi==0.7.0
|
||||||
|
- gitdb==4.0.12
|
||||||
|
- gitpython==3.1.44
|
||||||
|
- google-api-core==2.24.2
|
||||||
|
- google-auth==2.38.0
|
||||||
|
- google-auth-oauthlib==1.2.1
|
||||||
|
- google-cloud-core==2.4.3
|
||||||
|
- google-cloud-storage==3.1.0
|
||||||
|
- google-crc32c==1.6.0
|
||||||
|
- google-resumable-media==2.7.2
|
||||||
|
- googleapis-common-protos==1.69.1
|
||||||
|
- hydra-colorlog==1.2.0
|
||||||
|
- hydra-core==1.3.2
|
||||||
|
- hydra-optuna-sweeper==1.2.0
|
||||||
|
- identify==2.6.9
|
||||||
|
- iniconfig==2.0.0
|
||||||
|
- iopath==0.1.10
|
||||||
|
- ipython-genutils==0.2.0
|
||||||
|
- ipywidgets==7.8.5
|
||||||
|
- jupyterlab-widgets==1.1.11
|
||||||
|
- lightning==2.5.0.post0
|
||||||
|
- lightning-utilities==0.14.1
|
||||||
|
- looseversion==1.1.2
|
||||||
|
- lovely-numpy==0.2.13
|
||||||
|
- lovely-tensors==0.1.18
|
||||||
|
- mako==1.3.9
|
||||||
|
- markdown-it-py==3.0.0
|
||||||
|
- mdurl==0.1.2
|
||||||
|
- ml-collections==1.0.0
|
||||||
|
- mmcif==0.91.0
|
||||||
|
- mmpbsa-py==16.0
|
||||||
|
- mmtf-python==1.1.3
|
||||||
|
- mols2grid==2.0.0
|
||||||
|
- msgpack==1.1.0
|
||||||
|
- msgpack-numpy==0.4.8
|
||||||
|
- narwhals==1.30.0
|
||||||
|
- netcdf4==1.7.2
|
||||||
|
- nodeenv==1.9.1
|
||||||
|
- numpy==1.26.4
|
||||||
|
- oauthlib==3.2.2
|
||||||
|
- omegaconf==2.3.0
|
||||||
|
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
|
||||||
|
- optuna==2.10.1
|
||||||
|
- packmol-memgen==2025.1.29
|
||||||
|
- pandas==2.2.3
|
||||||
|
- pandocfilters==1.5.1
|
||||||
|
- pbr==6.1.1
|
||||||
|
- pdb4amber==22.0
|
||||||
|
- plinder==0.2.24
|
||||||
|
- plotly==6.0.0
|
||||||
|
- pluggy==1.5.0
|
||||||
|
- portalocker==3.1.1
|
||||||
|
- posebusters==0.2.13
|
||||||
|
- git+https://git@github.com/zrqiao/power_spherical.git@290b1630c5f84e3bb0d61711046edcf6e47200d4
|
||||||
|
- pre-commit==4.1.0
|
||||||
|
- prettytable==3.15.1
|
||||||
|
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
|
||||||
|
- propcache==0.3.0
|
||||||
|
- proto-plus==1.26.1
|
||||||
|
- protobuf==5.29.3
|
||||||
|
- pyarrow==19.0.1
|
||||||
|
- pyasn1==0.6.1
|
||||||
|
- pyasn1-modules==0.4.1
|
||||||
|
- pymsmt==22.0
|
||||||
|
- pyperclip==1.9.0
|
||||||
|
- pytest==8.3.5
|
||||||
|
- python-dotenv==1.0.1
|
||||||
|
- python-json-logger==3.3.0
|
||||||
|
- pytorch-lightning==2.5.0.post0
|
||||||
|
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
|
||||||
|
- pytraj==2.0.6
|
||||||
|
- requests-oauthlib==2.0.0
|
||||||
|
- rich==13.9.4
|
||||||
|
- rootutils==1.0.7
|
||||||
|
- rsa==4.9
|
||||||
|
- sander==22.0
|
||||||
|
- seaborn==0.13.2
|
||||||
|
- sentry-sdk==2.22.0
|
||||||
|
- setproctitle==1.3.5
|
||||||
|
- smmap==5.0.2
|
||||||
|
- soupsieve==2.6
|
||||||
|
- stevedore==5.4.1
|
||||||
|
- tabulate==0.9.0
|
||||||
|
- termcolor==2.5.0
|
||||||
|
- torchmetrics==1.6.3
|
||||||
|
- virtualenv==20.29.3
|
||||||
|
- wandb==0.19.8
|
||||||
|
- widgetsnbextension==3.6.10
|
||||||
|
- yacs==0.1.8
|
||||||
58
environments/flowdock_environment_docker.yaml
Normal file
58
environments/flowdock_environment_docker.yaml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
name: flowdock
|
||||||
|
channels:
|
||||||
|
- pyg
|
||||||
|
- pytorch
|
||||||
|
- nvidia
|
||||||
|
- defaults
|
||||||
|
- conda-forge
|
||||||
|
dependencies:
|
||||||
|
- mendeleev=0.20.1=pymin39_ha308f57_3
|
||||||
|
- networkx=3.4.2=pyh267e887_2
|
||||||
|
- python=3.10.16=he725a3c_1_cpython
|
||||||
|
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
|
||||||
|
- pytorch-cuda=11.8=h7e8668a_6
|
||||||
|
- pytorch-mutex=1.0=cuda
|
||||||
|
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
|
||||||
|
- rdkit=2024.09.6=py310hcd13295_0
|
||||||
|
- scikit-learn=1.6.1=py310h27f47ee_0
|
||||||
|
- scipy=1.15.2=py310h1d65ade_0
|
||||||
|
- torchaudio=2.2.1=py310_cu118
|
||||||
|
- torchtriton=2.2.0=py310
|
||||||
|
- torchvision=0.17.1=py310_cu118
|
||||||
|
- tqdm=4.67.1=pyhd8ed1ab_1
|
||||||
|
- pip:
|
||||||
|
- beartype==0.20.0
|
||||||
|
- biopandas==0.5.1
|
||||||
|
- biopython==1.79
|
||||||
|
- biotite==1.1.0
|
||||||
|
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
|
||||||
|
- dm-tree==0.1.9
|
||||||
|
- einops==0.8.1
|
||||||
|
- fair-esm==2.0.0
|
||||||
|
- fairscale==0.4.13
|
||||||
|
- gemmi==0.7.0
|
||||||
|
- hydra-colorlog==1.2.0
|
||||||
|
- hydra-core==1.3.2
|
||||||
|
- hydra-optuna-sweeper==1.2.0
|
||||||
|
- lightning==2.5.0.post0
|
||||||
|
- lightning-utilities==0.14.1
|
||||||
|
- lovely-numpy==0.2.13
|
||||||
|
- lovely-tensors==0.1.18
|
||||||
|
- ml-collections==1.0.0
|
||||||
|
- msgpack==1.1.0
|
||||||
|
- msgpack-numpy==0.4.8
|
||||||
|
- numpy==1.26.4
|
||||||
|
- omegaconf==2.3.0
|
||||||
|
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
|
||||||
|
- pandas==2.2.3
|
||||||
|
- plinder==0.2.24
|
||||||
|
- plotly==6.0.0
|
||||||
|
- posebusters==0.2.13
|
||||||
|
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
|
||||||
|
- pytorch-lightning==2.5.0.post0
|
||||||
|
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
|
||||||
|
- rich==13.9.4
|
||||||
|
- rootutils==1.0.7
|
||||||
|
- seaborn==0.13.2
|
||||||
|
- torchmetrics==1.6.3
|
||||||
|
- wandb==0.19.8
|
||||||
120
flowdock/__init__.py
Normal file
120
flowdock/__init__.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
from beartype.typing import Any
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
# Mapping from lowercase method identifiers (as used in configs and filenames)
# to their canonical display titles (also used as forks' directory names below).
METHOD_TITLE_MAPPING = {
    "diffdock": "DiffDock",
    "flowdock": "FlowDock",
    "neuralplexer": "NeuralPLexer",
}

# NOTE(review): presumably methods whose fork directories follow a standardized
# layout — used only as an allow-list in the path resolvers below; confirm.
STANDARDIZED_DIR_METHODS = ["diffdock"]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_omegaconf_variable(variable_path: str) -> Any:
    """Resolve a dotted OmegaConf variable path to the Python object it names.

    :param variable_path: Dotted path such as ``"pkg.module.ATTRIBUTE"``.
    :return: The attribute found at the end of the path.
    """
    # Separate the trailing attribute name from the (dotted) module prefix.
    module_path, _, attribute_name = variable_path.rpartition(".")

    try:
        # Common case: the prefix is itself an importable module.
        target_module = importlib.import_module(module_path)
        return getattr(target_module, attribute_name)
    except Exception:
        # Fallback: the last component of the prefix is an attribute
        # (e.g. a class or object) living on its parent module.
        parent_path, _, inner_name = module_path.rpartition(".")
        parent_module = importlib.import_module(parent_path)
        return getattr(getattr(parent_module, inner_name), attribute_name)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_dataset_path_dirname(dataset: str) -> str:
    """Resolve the dataset path directory name based on the dataset's name.

    :param dataset: Name of the dataset.
    :return: Directory name for the dataset path.
    """
    # Only ``dockgen`` has a directory name differing from its dataset name.
    special_dirnames = {"dockgen": "DockGen"}
    return special_dirnames.get(dataset, dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_method_input_csv_path(method: str, dataset: str) -> str:
    """Resolve the input CSV path for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :return: The input CSV path for the given method.
    :raises ValueError: If ``method`` is not a recognized method.
    """
    # Guard clause: reject unknown methods up front.
    supported = method in STANDARDIZED_DIR_METHODS or method in ("flowdock", "neuralplexer")
    if not supported:
        raise ValueError(f"Invalid method: {method}")

    # Fork directories are named after the method's display title when one exists.
    method_dirname = METHOD_TITLE_MAPPING.get(method, method)
    return os.path.join(
        "forks",
        method_dirname,
        "inference",
        f"{method}_{dataset}_inputs.csv",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_method_title(method: str) -> str:
    """Resolve the method title for a given method.

    :param method: The method name.
    :return: The method title for the given method.
    """
    # Fall back to the raw method name when no title has been registered.
    try:
        return METHOD_TITLE_MAPPING[method]
    except KeyError:
        return method
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_method_output_dir(
    method: str,
    dataset: str,
    repeat_index: int,
) -> str:
    """Resolve the output directory for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :param repeat_index: The repeat index for the method.
    :return: The output directory for the given method.
    :raises ValueError: If ``method`` is not a recognized method.
    """
    # Guard clause: reject unknown methods up front.
    supported = method in STANDARDIZED_DIR_METHODS or method in ("flowdock", "neuralplexer")
    if not supported:
        raise ValueError(f"Invalid method: {method}")

    # These two methods pluralize "output" in their directory names.
    plural_suffix = "s" if method in ("flowdock", "neuralplexer") else ""
    return os.path.join(
        "forks",
        METHOD_TITLE_MAPPING.get(method, method),
        "inference",
        f"{method}_{dataset}_output{plural_suffix}_{repeat_index}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def register_custom_omegaconf_resolvers():
    """Register this project's custom OmegaConf resolvers.

    Makes ``${resolve_variable:...}``, ``${resolve_dataset_path_dirname:...}``,
    ``${resolve_method_input_csv_path:...,...}``, ``${resolve_method_title:...}``,
    ``${resolve_method_output_dir:...,...,...}``, and ``${int_divide:...,...}``
    available for interpolation in Hydra/OmegaConf configs.
    """
    # Pass the named functions directly instead of wrapping each in a redundant
    # single-expression lambda; resolver call signatures are unchanged.
    OmegaConf.register_new_resolver("resolve_variable", resolve_omegaconf_variable)
    OmegaConf.register_new_resolver(
        "resolve_dataset_path_dirname", resolve_dataset_path_dirname
    )
    OmegaConf.register_new_resolver(
        "resolve_method_input_csv_path", resolve_method_input_csv_path
    )
    OmegaConf.register_new_resolver("resolve_method_title", resolve_method_title)
    OmegaConf.register_new_resolver("resolve_method_output_dir", resolve_method_output_dir)
    # Integer division for config arithmetic, e.g. ``${int_divide:${total},${workers}}``.
    OmegaConf.register_new_resolver(
        "int_divide", lambda dividend, divisor: int(dividend) // int(divisor)
    )
|
||||||
165
flowdock/eval.py
Normal file
165
flowdock/eval.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import lightning as L
|
||||||
|
import lovely_tensors as lt
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, List, Tuple
|
||||||
|
from lightning import LightningDataModule, LightningModule, Trainer
|
||||||
|
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from lightning.pytorch.strategies.strategy import Strategy
|
||||||
|
from omegaconf import DictConfig, open_dict
|
||||||
|
|
||||||
|
lt.monkey_patch()
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
# the setup_root above is equivalent to:
|
||||||
|
# - adding project root dir to PYTHONPATH
|
||||||
|
# (so you don't need to force user to install project as a package)
|
||||||
|
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||||
|
# - setting up PROJECT_ROOT environment variable
|
||||||
|
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||||
|
# (this way all filepaths are the same no matter where you run the code)
|
||||||
|
# - loading environment variables from ".env" in root dir
|
||||||
|
#
|
||||||
|
# you can remove it if you:
|
||||||
|
# 1. either install project as a package or move entry files to project root dir
|
||||||
|
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||||
|
#
|
||||||
|
# more info: https://github.com/ashleve/rootutils
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||||
|
from flowdock.utils import (
|
||||||
|
RankedLogger,
|
||||||
|
extras,
|
||||||
|
instantiate_loggers,
|
||||||
|
log_hyperparameters,
|
||||||
|
task_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra. Expected to carry
        `ckpt_path`, `data`, `model`, `trainer`, `environment`, and `strategy`
        sections (see `configs/eval.yaml`) — TODO confirm against the configs.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    # Fail fast on a missing/invalid checkpoint before any expensive setup.
    assert cfg.ckpt_path, "Please provide a checkpoint path to evaluate!"
    assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    # `stage="test"` restricts datamodule setup to the test split.
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="test")

    # Establish model input arguments.
    # `open_dict` temporarily lifts OmegaConf struct mode so the resolved
    # `start_time` can be written back into the (otherwise read-only) config.
    with open_dict(cfg):
        if cfg.model.cfg.task.start_time == "auto":
            cfg.model.cfg.task.start_time = 1.0
        else:
            cfg.model.cfg.task.start_time = float(cfg.model.cfg.task.start_time)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    # Optional cluster-environment plugin (e.g. for SLURM/MPI launches);
    # a config without `_target_` leaves `plugins` as None.
    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    # A strategy config with `_target_` overrides whatever string strategy
    # the trainer config may declare.
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        # For strategies exposing a `mixed_precision` sub-config (e.g. FSDP),
        # resolve the dtype strings from the config into real torch dtypes.
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # Only pass `strategy=` when one was resolved; passing `strategy=None`
    # explicitly is avoided so the Trainer keeps its own default.
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
        )
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # Metrics accumulated by the Trainer's callbacks during `test`.
    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: evaluate a checkpoint using the composed config.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Run optional pre-flight utilities first (tag prompting, config-tree
    # printing, warning filters, etc.), then hand off to the evaluation task.
    extras(cfg)
    evaluate(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Register project-specific OmegaConf resolvers before Hydra composes the
    # config, so interpolations that rely on them can be resolved.
    register_custom_omegaconf_resolvers()
    main()
|
||||||
0
flowdock/models/__init__.py
Normal file
0
flowdock/models/__init__.py
Normal file
0
flowdock/models/components/__init__.py
Normal file
0
flowdock/models/components/__init__.py
Normal file
452
flowdock/models/components/callbacks/ema.py
Normal file
452
flowdock/models/components/callbacks/ema.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
# Adapted from https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
|
||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import contextlib
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
|
import lightning.pytorch as pl
|
||||||
|
import torch
|
||||||
|
from lightning.pytorch import Callback
|
||||||
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||||
|
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||||
|
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||||
|
|
||||||
|
|
||||||
|
class EMA(Callback):
    """Implements Exponential Moving Averaging (EMA).

    When training a model, this callback will maintain moving averages of the trained parameters.
    When evaluating, we use the moving averages copy of the trained parameters.
    When saving, we save an additional set of parameters with the prefix `ema`.

    Args:
        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
        validate_original_weights: Validate the original weights, as opposed to the EMA weights.
        every_n_steps: Apply EMA every N steps.
        cpu_offload: Offload weights to CPU.
    """

    def __init__(
        self,
        decay: float,
        validate_original_weights: bool = False,
        every_n_steps: int = 1,
        cpu_offload: bool = False,
    ):
        """Validate and store the EMA configuration.

        :raises MisconfigurationException: If ``decay`` is outside [0, 1].
        """
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self.decay = decay
        self.validate_original_weights = validate_original_weights
        self.every_n_steps = every_n_steps
        self.cpu_offload = cpu_offload

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Add the EMA optimizer to the trainer."""
        # EMA shadow weights live on the module's device unless CPU offload
        # was requested.
        device = pl_module.device if not self.cpu_offload else torch.device("cpu")
        # Wrap each optimizer in an EMAOptimizer; already-wrapped optimizers
        # are skipped so repeated fit calls do not double-wrap.
        trainer.optimizers = [
            EMAOptimizer(
                optim,
                device=device,
                decay=self.decay,
                every_n_steps=self.every_n_steps,
                current_step=trainer.global_step,
            )
            for optim in trainer.optimizers
            if not isinstance(optim, EMAOptimizer)
        ]

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights with the EMA weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights back to the original weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights with the EMA weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the model weights back to the original weights."""
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
        """Check if the EMA weights should be validated."""
        return not self.validate_original_weights and self._ema_initialized(trainer)

    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
        """Check if the EMA weights have been initialized."""
        return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)

    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
        """Swaps the model weights with the EMA weights.

        :param trainer: The Lightning trainer whose optimizers hold the EMA state.
        :param saving_ema_model: Whether the swap is done for saving an EMA checkpoint.
        """
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.switch_main_parameter_weights(saving_ema_model)

    @contextlib.contextmanager
    def save_ema_model(self, trainer: "pl.Trainer"):
        """Saves an EMA copy of the model + EMA optimizer states for resume."""
        # Swap EMA weights in for the duration of the save, and always swap
        # back even if saving raises.
        self.swap_model_weights(trainer, saving_ema_model=True)
        try:
            yield
        finally:
            self.swap_model_weights(trainer, saving_ema_model=False)

    @contextlib.contextmanager
    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
        """Save the original optimizer state."""
        # While this context is active, EMAOptimizer.state_dict() returns the
        # wrapped optimizer's plain state instead of the EMA-augmented one.
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.save_original_optimizer_state = True
        try:
            yield
        finally:
            for optimizer in trainer.optimizers:
                optimizer.save_original_optimizer_state = False

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Load the EMA state from the checkpoint if it exists.

        :raises MisconfigurationException: If a companion ``-EMA`` checkpoint is
            expected but cannot be found next to the loaded checkpoint.
        """
        checkpoint_callback = trainer.checkpoint_callback

        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
        # NOTE(review): `connector` is assigned but never used below — likely a
        # leftover from the upstream NeMo implementation; confirm before removing.
        connector = trainer._checkpoint_connector
        # Replace connector._ckpt_path with below to avoid calling into lightning's protected API
        ckpt_path = trainer.ckpt_path

        if (
            ckpt_path
            and checkpoint_callback is not None
            and "EMA" in type(checkpoint_callback).__name__
        ):
            ext = checkpoint_callback.FILE_EXTENSION
            if ckpt_path.endswith(f"-EMA{ext}"):
                # Resuming directly from an EMA checkpoint: treat the loaded
                # EMA weights as the main weights and start a fresh EMA copy.
                rank_zero_info(
                    "loading EMA based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            # Otherwise look for the companion "-EMA" file saved alongside the
            # regular checkpoint and restore optimizer states from it.
            ema_path = ckpt_path.replace(ext, f"-EMA{ext}")
            if os.path.exists(ema_path):
                ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu"))

                checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"]
                del ema_state_dict
                rank_zero_info("EMA state has been restored.")
            else:
                raise MisconfigurationException(
                    "Unable to find the associated EMA weights when re-loading, "
                    f"training will start with new EMA weights. Expected them to be at: {ema_path}",
                )
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
    """In-place EMA step: ``ema <- decay * ema + (1 - decay) * current``.

    :param ema_model_tuple: Tuple of EMA shadow tensors (updated in place).
    :param current_model_tuple: Tuple of current model tensors (read only).
    :param decay: EMA decay factor in [0, 1].
    """
    # Scale each running average, then accumulate the matching current tensor.
    for shadow, live in zip(ema_model_tuple, current_model_tuple):
        shadow.mul_(decay).add_(live, alpha=(1.0 - decay))
|
||||||
|
|
||||||
|
|
||||||
|
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
    """Apply an EMA update on the CPU, draining pending GPU work first.

    :param ema_model_tuple: Tuple of EMA shadow tensors (updated in place).
    :param current_model_tuple: Tuple of current model tensors.
    :param decay: EMA decay factor in [0, 1].
    :param pre_sync_stream: Optional CUDA stream to synchronize before the
        update, ensuring the device-to-host copies have completed.
    """
    # Block until the producing stream finishes so the CPU sees final values.
    if pre_sync_stream is not None:
        pre_sync_stream.synchronize()
    ema_update(ema_model_tuple, current_model_tuple, decay)
|
||||||
|
|
||||||
|
|
||||||
|
class EMAOptimizer(torch.optim.Optimizer):
    r"""EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average
    of parameters registered in the optimizer.

    EMA parameters are automatically updated after every step of the optimizer
    with the following formula:

        ema_weight = decay * ema_weight + (1 - decay) * training_weight

    To access EMA parameters, use ``swap_ema_weights()`` context manager to
    perform a temporary in-place swap of regular parameters with EMA
    parameters.

    Notes:
        - EMAOptimizer is not compatible with APEX AMP O2.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap
        device (torch.device): device for EMA parameters
        decay (float): decay factor

    Returns:
        returns an instance of torch.optim.Optimizer that computes EMA of
        parameters

    Example:
        model = Model().to(device)
        opt = torch.optim.Adam(model.parameters())

        opt = EMAOptimizer(opt, device, 0.9999)

        for epoch in range(epochs):
            training_loop(model, opt)

            regular_eval_accuracy = evaluate(model)

            with opt.swap_ema_weights():
                ema_eval_accuracy = evaluate(model)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        current_step: int = 0,
    ):
        # NOTE: torch.optim.Optimizer.__init__ is deliberately not called;
        # this wrapper delegates everything to the wrapped optimizer instead.
        self.optimizer = optimizer
        self.decay = decay
        self.device = device
        self.current_step = current_step
        self.every_n_steps = every_n_steps
        # When True, state_dict() returns the wrapped optimizer's plain state.
        self.save_original_optimizer_state = False

        # Lazy-initialized pieces: the CUDA stream is created on the first
        # step, and EMA shadow tensors are (re)built when params change.
        self.first_iteration = True
        self.rebuild_ema_params = True
        self.stream = None
        self.thread = None

        self.ema_params = ()
        self.in_saving_ema_model_context = False

    def all_parameters(self) -> Iterable[torch.Tensor]:
        """Return an iterator over all parameters in the optimizer."""
        return (param for group in self.param_groups for param in group["params"])

    def step(self, closure=None, grad_scaler=None, **kwargs):
        """Perform a single optimization step.

        Delegates the actual parameter update to the wrapped optimizer, then
        applies the EMA update every ``every_n_steps`` steps.
        """
        # Wait for any in-flight asynchronous EMA update before stepping.
        self.join()

        if self.first_iteration:
            # Only allocate a CUDA stream when parameters actually live on GPU.
            if any(p.is_cuda for p in self.all_parameters()):
                self.stream = torch.cuda.Stream()

            self.first_iteration = False

        if self.rebuild_ema_params:
            opt_params = list(self.all_parameters())

            # Append shadow copies only for newly added parameters, keeping
            # existing EMA state intact.
            self.ema_params += tuple(
                copy.deepcopy(param.data.detach()).to(self.device)
                for param in opt_params[len(self.ema_params) :]
            )
            self.rebuild_ema_params = False

        # Pass the grad scaler through only when the wrapped optimizer
        # advertises fused AMP-scaling support.
        if (
            getattr(self.optimizer, "_step_supports_amp_scaling", False)
            and grad_scaler is not None
        ):
            loss = self.optimizer.step(closure=closure, grad_scaler=grad_scaler)
        else:
            loss = self.optimizer.step(closure)

        if self._should_update_at_step():
            self.update()
        self.current_step += 1
        return loss

    def _should_update_at_step(self) -> bool:
        """Check if the EMA parameters should be updated at the current step."""
        return self.current_step % self.every_n_steps == 0

    @torch.no_grad()
    def update(self):
        """Update the EMA parameters.

        On CUDA the update runs on a side stream; on CPU it runs in a
        background thread after the device-to-host copies complete.
        """
        if self.stream is not None:
            # Ensure the optimizer step's writes are visible to the side stream.
            self.stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.stream):
            # Asynchronously stage the current weights on the EMA device.
            current_model_state = tuple(
                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
            )

            if self.device.type == "cuda":
                ema_update(self.ema_params, current_model_state, self.decay)

        if self.device.type == "cpu":
            # Offload the EMA math to a thread; it synchronizes on the copy
            # stream before touching the staged tensors.
            self.thread = threading.Thread(
                target=run_ema_update_cpu,
                args=(
                    self.ema_params,
                    current_model_state,
                    self.decay,
                    self.stream,
                ),
            )
            self.thread.start()

    def swap_tensors(self, tensor1, tensor2):
        """Swaps the tensors in-place."""
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
        """Switches the main parameter weights with the EMA weights.

        :param saving_ema_model: Marks that the swap is for saving an EMA
            checkpoint, which changes what ``state_dict`` reports as EMA params.
        """
        self.join()
        self.in_saving_ema_model_context = saving_ema_model
        for param, ema_param in zip(self.all_parameters(), self.ema_params):
            self.swap_tensors(param.data, ema_param)

    @contextlib.contextmanager
    def swap_ema_weights(self, enabled: bool = True):
        r"""A context manager to in-place swap regular parameters with EMA parameters. It swaps back
        to the original regular parameters on context manager exit.

        Args:
            enabled (bool): whether the swap should be performed
        """

        if enabled:
            self.switch_main_parameter_weights()
        try:
            yield
        finally:
            if enabled:
                self.switch_main_parameter_weights()

    def __getattr__(self, name):
        """Forward all other attribute calls to the optimizer."""
        # Called only when normal attribute lookup fails, so the wrapper's own
        # attributes set in __init__ take precedence.
        return getattr(self.optimizer, name)

    def join(self):
        """Wait for the update to complete."""
        if self.stream is not None:
            self.stream.synchronize()

        if self.thread is not None:
            self.thread.join()

    def state_dict(self):
        """Return the state dict for the optimizer."""
        self.join()

        if self.save_original_optimizer_state:
            return self.optimizer.state_dict()

        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
        ema_params = (
            self.ema_params
            if not self.in_saving_ema_model_context
            else list(self.all_parameters())
        )
        state_dict = {
            "opt": self.optimizer.state_dict(),
            "ema": ema_params,
            "current_step": self.current_step,
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
        }
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the state dict for the optimizer."""
        self.join()

        self.optimizer.load_state_dict(state_dict["opt"])
        # Deep-copy so the restored EMA tensors do not alias the checkpoint's.
        self.ema_params = tuple(
            param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
        )
        self.current_step = state_dict["current_step"]
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.rebuild_ema_params = False

    def add_param_group(self, param_group):
        """Add a param group to the optimizer."""
        self.optimizer.add_param_group(param_group)
        # Defer shadow-tensor creation for the new group to the next step().
        self.rebuild_ema_params = True
|
||||||
|
|
||||||
|
|
||||||
|
class EMAModelCheckpoint(ModelCheckpoint):
    """EMA version of ModelCheckpoint that saves EMA checkpoints as well.

    For every regular checkpoint written, a companion file with a ``-EMA``
    suffix (before the extension) is saved holding the EMA weights.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
        """Returns the EMA callback if it exists."""
        # If multiple EMA callbacks were registered, the last one wins.
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Saves the checkpoint file and the EMA checkpoint file if it exists."""
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # Save the regular checkpoint with the original (non-EMA)
            # optimizer state.
            with ema_callback.save_original_optimizer_state(trainer):
                super()._save_checkpoint(trainer, filepath)

            # save EMA copy of the model as well.
            with ema_callback.save_ema_model(trainer):
                filepath = self._ema_format_filepath(filepath)
                if self.verbose:
                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                super()._save_checkpoint(trainer, filepath)
        else:
            super()._save_checkpoint(trainer, filepath)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Removes the checkpoint file and the EMA checkpoint file if it exists."""
        super()._remove_checkpoint(trainer, filepath)
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # remove EMA copy of the state dict as well.
            filepath = self._ema_format_filepath(filepath)
            super()._remove_checkpoint(trainer, filepath)

    def _ema_format_filepath(self, filepath: str) -> str:
        """Appends '-EMA' to the filepath."""
        return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        """Checks if any of the checkpoints are EMA checkpoints."""
        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        """Checks if the filepath is an EMA checkpoint."""
        return str(filepath).endswith(f"-EMA{self.FILE_EXTENSION}")

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """Returns all the saved checkpoint paths in the directory."""
        return Path(self.dirpath).rglob("*.ckpt")
|
||||||
857
flowdock/models/components/cpm.py
Normal file
857
flowdock/models/components/cpm.py
Normal file
@@ -0,0 +1,857 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, Optional, Tuple
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.models.components.embedding import (
|
||||||
|
GaussianFourierEncoding1D,
|
||||||
|
RelativeGeometryEncoding,
|
||||||
|
)
|
||||||
|
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||||
|
from flowdock.models.components.modules import (
|
||||||
|
BiDirectionalTriangleAttention,
|
||||||
|
TransformerLayer,
|
||||||
|
)
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.frame_utils import cartesian_to_internal, get_frame_matrix
|
||||||
|
from flowdock.utils.model_utils import GELUMLP
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
STATE_DICT = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class ProtFormer(torch.nn.Module):
|
||||||
|
"""Protein relational reasoning with downsampled edges."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
pair_dim: int,
|
||||||
|
n_blocks: int = 4,
|
||||||
|
n_heads: int = 8,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
"""Initialize the ProtFormer model."""
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.pair_dim = pair_dim
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_blocks = n_blocks
|
||||||
|
self.time_encoding = GaussianFourierEncoding1D(16)
|
||||||
|
self.res_in_mlp = GELUMLP(dim + 32, dim, dropout=dropout)
|
||||||
|
self.chain_pos_encoding = GaussianFourierEncoding1D(self.pair_dim // 4)
|
||||||
|
|
||||||
|
self.rel_geom_enc = RelativeGeometryEncoding(15, self.pair_dim)
|
||||||
|
self.template_nenc = GELUMLP(64 + 37 * 3, self.dim, n_hidden_feats=128)
|
||||||
|
self.template_eenc = RelativeGeometryEncoding(15, self.pair_dim)
|
||||||
|
self.template_binding_site_enc = torch.nn.Linear(1, 64, bias=False)
|
||||||
|
self.pp_edge_embed = GELUMLP(
|
||||||
|
pair_dim + self.pair_dim // 4 * 2 + dim * 2,
|
||||||
|
self.pair_dim,
|
||||||
|
n_hidden_feats=dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.graph_stacks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerLayer(
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
head_dim=pair_dim // n_heads,
|
||||||
|
edge_channels=pair_dim,
|
||||||
|
edge_update=True,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
for _ in range(self.n_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ABab_mha = TransformerLayer(
|
||||||
|
pair_dim,
|
||||||
|
n_heads,
|
||||||
|
bidirectional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.triangle_stacks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
BiDirectionalTriangleAttention(pair_dim, pair_dim // n_heads, n_heads)
|
||||||
|
for _ in range(self.n_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.graph_relations = [
|
||||||
|
(
|
||||||
|
"residue_to_residue",
|
||||||
|
"gather_idx_ab_a",
|
||||||
|
"gather_idx_ab_b",
|
||||||
|
"prot_res",
|
||||||
|
"prot_res",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sampled_residue_to_sampled_residue",
|
||||||
|
"gather_idx_AB_a",
|
||||||
|
"gather_idx_AB_b",
|
||||||
|
"prot_res",
|
||||||
|
"prot_res",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def compute_chain_pe(
|
||||||
|
self,
|
||||||
|
residue_index,
|
||||||
|
res_chain_index,
|
||||||
|
src_rid,
|
||||||
|
dst_rid,
|
||||||
|
):
|
||||||
|
"""Compute chain positional encoding for a pair of residues."""
|
||||||
|
chain_disp = residue_index[src_rid] - residue_index[dst_rid]
|
||||||
|
chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
|
||||||
|
chain_disp.div(8).abs().add(1).unsqueeze(-1)
|
||||||
|
)
|
||||||
|
# Mask cross-chain entries
|
||||||
|
chain_mask = res_chain_index[src_rid] == res_chain_index[dst_rid]
|
||||||
|
chain_rope = chain_rope * chain_mask[..., None]
|
||||||
|
return chain_rope
|
||||||
|
|
||||||
|
def compute_chain_pair_pe(
|
||||||
|
self,
|
||||||
|
residue_index,
|
||||||
|
res_chain_index,
|
||||||
|
AB_broadcasted_rid,
|
||||||
|
ab_rid,
|
||||||
|
AB_broadcasted_cid,
|
||||||
|
ab_cid,
|
||||||
|
):
|
||||||
|
"""Compute chain positional encoding for a pair of residues."""
|
||||||
|
chain_disp_row = residue_index[AB_broadcasted_rid] - residue_index[ab_rid]
|
||||||
|
chain_disp_col = residue_index[AB_broadcasted_cid] - residue_index[ab_cid]
|
||||||
|
chain_disp = torch.stack([chain_disp_row, chain_disp_col], dim=-1)
|
||||||
|
chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
|
||||||
|
chain_disp.div(8).abs().add(1).unsqueeze(-1)
|
||||||
|
)
|
||||||
|
# Mask cross-chain entries
|
||||||
|
chain_mask_row = res_chain_index[AB_broadcasted_rid] == res_chain_index[ab_rid]
|
||||||
|
chain_mask_col = res_chain_index[AB_broadcasted_cid] == res_chain_index[ab_cid]
|
||||||
|
chain_mask = torch.stack([chain_mask_row, chain_mask_col], dim=-1)
|
||||||
|
chain_rope = (chain_rope * chain_mask[..., None]).flatten(-2, -1)
|
||||||
|
return chain_rope
|
||||||
|
|
||||||
|
def eval_protein_template_encodings(self, batch, edge_idx, use_plddt=False):
|
||||||
|
"""Evaluate template encodings for protein residues."""
|
||||||
|
with torch.no_grad():
|
||||||
|
template_bb_coords = batch["features"]["apo_res_atom_positions"][:, :3]
|
||||||
|
template_bb_frames = get_frame_matrix(
|
||||||
|
template_bb_coords[:, 0, :],
|
||||||
|
template_bb_coords[:, 1, :],
|
||||||
|
template_bb_coords[:, 2, :],
|
||||||
|
)
|
||||||
|
# Add template local representations & lddt
|
||||||
|
template_local_coords = cartesian_to_internal(
|
||||||
|
batch["features"]["apo_res_atom_positions"],
|
||||||
|
template_bb_frames.unsqueeze(1),
|
||||||
|
)
|
||||||
|
template_local_coords[~batch["features"]["apo_res_atom_mask"].bool()] = 0
|
||||||
|
# if use_plddt:
|
||||||
|
# template_plddt_enc = F.one_hot(
|
||||||
|
# torch.bucketize(
|
||||||
|
# batch["features"]["apo_pLDDT"],
|
||||||
|
# torch.linspace(0, 1, 65, device=template_bb_coords.device)[:-1],
|
||||||
|
# right=True,
|
||||||
|
# )
|
||||||
|
# - 1,
|
||||||
|
# num_classes=64,
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# template_plddt_enc = torch.zeros(
|
||||||
|
# template_local_coords.shape[0], 64, device=template_bb_coords.device
|
||||||
|
# )
|
||||||
|
if self.training:
|
||||||
|
use_sidechain_coords = random.randint(0, 1) # nosec
|
||||||
|
template_local_coords = template_local_coords * use_sidechain_coords
|
||||||
|
# use_plddt_input = random.randint(0, 1)
|
||||||
|
# template_plddt_enc = template_plddt_enc * use_plddt_input
|
||||||
|
if "binding_site_mask" in batch["features"].keys():
|
||||||
|
# Externally-specified binding residue list
|
||||||
|
binding_site_enc = self.template_binding_site_enc(
|
||||||
|
batch["features"]["binding_site_mask"][:, None].float()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
binding_site_enc = torch.zeros(
|
||||||
|
template_local_coords.shape[0], 64, device=template_bb_coords.device
|
||||||
|
)
|
||||||
|
template_nfeat = self.template_nenc(
|
||||||
|
torch.cat([template_local_coords.flatten(-2, -1), binding_site_enc], dim=-1)
|
||||||
|
)
|
||||||
|
template_efeat = self.template_eenc(template_bb_frames, edge_idx)
|
||||||
|
template_alignment_mask = batch["features"]["apo_res_alignment_mask"].float()
|
||||||
|
if self.training:
|
||||||
|
# template_alignment_mask = template_alignment_mask * use_template
|
||||||
|
nomasking_rate = random.randint(9, 10) / 10 # nosec
|
||||||
|
template_alignment_mask = template_alignment_mask * (
|
||||||
|
torch.rand_like(template_alignment_mask) < nomasking_rate
|
||||||
|
)
|
||||||
|
template_nfeat = template_nfeat * template_alignment_mask.unsqueeze(-1)
|
||||||
|
template_efeat = (
|
||||||
|
template_efeat
|
||||||
|
* template_alignment_mask[edge_idx[0]].unsqueeze(-1)
|
||||||
|
* template_alignment_mask[edge_idx[1]].unsqueeze(-1)
|
||||||
|
)
|
||||||
|
return template_nfeat, template_efeat
|
||||||
|
|
||||||
|
def forward(self, batch, **kwargs):
|
||||||
|
"""Forward pass of the ProtFormer model."""
|
||||||
|
return self.forward_prot_sample(batch, **kwargs)
|
||||||
|
|
||||||
|
    def forward_prot_sample(
        self,
        batch,
        embed_coords=True,
        in_attr_suffix="",
        out_attr_suffix="",
        use_template=False,
        use_plddt=False,
        **kwargs,
    ):
        """Forward pass of the ProtFormer model for a single protein sample.

        :param batch: Batch dict with `features`, `indexer`, and `metadata` sub-dicts.
        :param embed_coords: If False, zero out all coordinate-derived encodings
            (time encoding and relative geometry encoding), making the pass
            effectively sequence-only.
        :param in_attr_suffix: Unused here; kept for interface symmetry with
            `BindingFormer.forward`.
        :param out_attr_suffix: Suffix appended to the output feature keys.
        :param use_template: Whether to add apo-template encodings when the
            template features are present in the batch.
        :param use_plddt: Forwarded to `eval_protein_template_encodings`.
        :param kwargs: Ignored extra keyword arguments.
        :return: The same `batch` dict, updated in place with residue, pair, and
            patch-grid representations plus the patch-level indexers.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device

        # Diffusion/flow timestep conditioning; zeroed when coordinates are not embedded.
        time_encoding = self.time_encoding(features["timestep_encoding_prot"])
        if not embed_coords:
            time_encoding = torch.zeros_like(time_encoding)

        # Residual update of the incoming residue embeddings with the time signal.
        residue_rep = (
            self.res_in_mlp(
                torch.cat(
                    [
                        features["res_embedding_in"],
                        time_encoding,
                    ],
                    dim=1,
                )
            )
            + features["res_embedding_in"]
        )
        batch_size = metadata["num_structid"]

        # Prepare indexers
        # Use max to ensure segmentation faults are 100% invoked
        # in case there are any bad indices
        # NOTE(review): the result is deliberately discarded — this is an eager
        # sanity check on the metadata, not a computation.
        max(metadata["num_a_per_sample"])
        n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]

        # Patch (pid) indices are shared between the row (a) and column (b) roles.
        indexer["gather_idx_pid_b"] = indexer["gather_idx_pid_a"]
        # Evaluate gather_idx_AB_a and gather_idx_AB_b
        # Assign a to rows and b to columns
        # Simple broadcasting for single-structure batches
        indexer["gather_idx_AB_a"] = (
            indexer["gather_idx_pid_a"]
            .view(batch_size, n_protein_patches)[:, :, None]
            .expand(-1, -1, n_protein_patches)
            .contiguous()
            .flatten()
        )
        indexer["gather_idx_AB_b"] = (
            indexer["gather_idx_pid_b"]
            .view(batch_size, n_protein_patches)[:, None, :]
            .expand(-1, n_protein_patches, -1)
            .contiguous()
            .flatten()
        )

        # Handle all batch offsets here
        graph_batcher = make_multi_relation_graph_batcher(self.graph_relations, indexer, metadata)
        merged_edge_idx = graph_batcher.collate_idx_list(indexer)

        # Backbone frames from the first three atom positions per residue
        # (presumably N, CA, C — confirm against the featurizer).
        input_protein_coords_padded = features["input_protein_coords"]
        backbone_frames = get_frame_matrix(
            input_protein_coords_padded[:, 0, :],
            input_protein_coords_padded[:, 1, :],
            input_protein_coords_padded[:, 2, :],
        )
        batch["features"]["backbone_frames"] = backbone_frames
        # Adding geometrical info to pair representations

        # Relative positional encoding within/between chains, per merged edge.
        chain_pe = self.compute_chain_pe(
            features["residue_index"],
            features["res_chain_id"],
            merged_edge_idx[0],
            merged_edge_idx[1],
        )
        geometry_pe = self.rel_geom_enc(backbone_frames, merged_edge_idx)
        if not embed_coords:
            geometry_pe = torch.zeros_like(geometry_pe)
        # Initial edge representations: geometry + chain PE + both endpoint residues.
        merged_edge_reps = self.pp_edge_embed(
            torch.cat(
                [
                    geometry_pe,
                    chain_pe,
                    residue_rep[merged_edge_idx[0]],
                    residue_rep[merged_edge_idx[1]],
                ],
                dim=-1,
            )
        )
        # Optionally inject apo-template node/edge encodings (additive residuals).
        if use_template and "apo_res_atom_positions" in features.keys():
            (
                template_res_encodings,
                template_geom_encodings,
            ) = self.eval_protein_template_encodings(batch, merged_edge_idx, use_plddt=use_plddt)
            residue_rep = residue_rep + template_res_encodings
            merged_edge_reps = merged_edge_reps + template_geom_encodings
        edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        node_reps = {"prot_res": residue_rep}

        gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
        # Pointer: AB->AB, ab->AB
        # Maps each residue-residue (ab) edge to its flat patch-grid (AB) cell.
        gather_idx_ab_AB = (
            indexer["gather_idx_ab_structid"] * n_protein_patches**2
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_a"]]
            * n_protein_patches
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_b"]]
        )

        # Intertwine graph iterations and triangle iterations
        for block_id in range(self.n_blocks):
            # Communicate between atomistic and patch resolutions
            # Up-sampling for interface edge embeddings
            rec_pair_rep = edge_reps["residue_to_residue"]
            AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
            # Upper-left block: intra-window visual-attention
            # Cross-attention between random and grid edges
            rec_pair_rep, AB_grid_attr_flat = self.ABab_mha(
                rec_pair_rep,
                AB_grid_attr_flat,
                (
                    torch.arange(metadata["num_ab"], device=device),
                    gather_idx_ab_AB,
                ),
            )
            AB_grid_attr = AB_grid_attr_flat.view(
                batch_size,
                n_protein_patches,
                n_protein_patches,
                self.pair_dim,
            )

            # Inter-patch triangle attentions, refining intermolecular edges
            _, AB_grid_attr = self.triangle_stacks[block_id](
                AB_grid_attr,
                AB_grid_attr,
                AB_grid_attr.unsqueeze(-4),
            )

            # Transfer grid-formatted representations to edges
            edge_reps["residue_to_residue"] = rec_pair_rep
            edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)
            merged_node_reps = graph_batcher.collate_node_attr(node_reps)
            merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)

            # Graph transformer iteration
            _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
                merged_node_reps,
                merged_node_reps,
                merged_edge_idx,
                merged_edge_reps,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        # Publish outputs under suffixed keys; downstream stacks (e.g.
        # BindingFormer) read them back via `in_attr_suffix`.
        batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
        batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
        batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
            "sampled_residue_to_sampled_residue"
        ]
        batch["indexer"]["gather_idx_AB_a"] = indexer["gather_idx_AB_a"]
        batch["indexer"]["gather_idx_AB_b"] = indexer["gather_idx_AB_b"]
        batch["indexer"]["gather_idx_ab_AB"] = gather_idx_ab_AB
        return batch
|
||||||
|
|
||||||
|
|
||||||
|
class BindingFormer(ProtFormer):
    """Edge inference on protein-ligand graphs.

    Extends `ProtFormer` with ligand "triplet" (lig_trp) nodes and the
    protein-ligand relations required to refine residue-ligand pair
    representations via interleaved triangle- and graph-attention blocks.
    """

    def __init__(
        self,
        dim: int,
        pair_dim: int,
        n_blocks: int = 4,
        n_heads: int = 8,
        n_ligand_patches: int = 16,
        dropout: float = 0.0,
    ):
        """Initialize the BindingFormer model.

        :param dim: Node (residue / ligand triplet) feature dimension.
        :param pair_dim: Pair/edge feature dimension.
        :param n_blocks: Number of interleaved triangle + graph attention blocks.
        :param n_heads: Number of attention heads.
        :param n_ligand_patches: Number of sampled ligand patches per sample.
        :param dropout: Dropout rate for the protein-ligand edge MLP.
        """
        super().__init__(
            dim,
            pair_dim,
            n_blocks,
            n_heads,
            dropout,
        )
        self.dim = dim
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.n_ligand_patches = n_ligand_patches
        # Embeds concatenated (residue, ligand-triplet) node pairs into pair space.
        self.pl_edge_embed = GELUMLP(dim * 2, self.pair_dim, n_hidden_feats=dim, dropout=dropout)
        # Bidirectional cross-attention between patch-level (AJ) and
        # residue-level (aJ) protein-ligand grids.
        self.AaJ_mha = TransformerLayer(pair_dim, n_heads, bidirectional=True)

        # Relation table consumed by the multi-relation graph batcher:
        # (relation name, source index key, destination index key,
        #  source node type, destination node type).
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_lig_triplet",
                "gather_idx_AJ_a",
                "gather_idx_AJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_sampled_residue",
                "gather_idx_AJ_J",
                "gather_idx_AJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "residue_to_sampled_lig_triplet",
                "gather_idx_aJ_a",
                "gather_idx_aJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_residue",
                "gather_idx_aJ_J",
                "gather_idx_aJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "sampled_lig_triplet_to_sampled_lig_triplet",
                "gather_idx_IJ_I",
                "gather_idx_IJ_J",
                "lig_trp",
                "lig_trp",
            ),
        ]

    def forward(
        self,
        batch,
        observed_block_contacts=None,
        in_attr_suffix="",
        out_attr_suffix="",
    ):
        """Forward pass of the BindingFormer model.

        :param batch: Batch dict with `features`, `indexer`, `metadata`, and `misc`;
            the input representations are read from keys suffixed with `in_attr_suffix`
            (typically written by `ProtFormer.forward_prot_sample`).
        :param observed_block_contacts: Optional tensor added to the zero-initialized
            patch-level protein-ligand (AJ) grid as generative contact feedback.
        :param in_attr_suffix: Suffix of the input feature keys.
        :param out_attr_suffix: Suffix appended to the output feature keys.
        :return: The same `batch` dict, updated in place with refined node, pair,
            and grid representations.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device
        # Synchronize with a language model
        residue_rep = features[f"rec_res_attr{in_attr_suffix}"]
        rec_pair_rep = features[f"res_res_pair_attr{in_attr_suffix}"]
        # Inherit the last-layer pair representations from protein encoder
        AB_grid_attr_flat = features[f"res_res_grid_attr_flat{in_attr_suffix}"]

        # Prepare indexers
        batch_size = metadata["num_structid"]
        n_a_per_sample = max(metadata["num_a_per_sample"])
        n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]

        if not batch["misc"]["protein_only"]:
            n_ligand_patches = max(metadata["num_I_per_sample"])
            # NOTE(review): result discarded — eager sanity check on the metadata only.
            max(metadata["num_molid_per_sample"])
            lig_frame_rep = features[f"lig_trp_attr{in_attr_suffix}"]
            UI_grid_attr = features["lig_af_grid_attr_projected"]
            # Symmetrize the ligand patch-patch grid.
            IJ_grid_attr = (UI_grid_attr + UI_grid_attr.transpose(1, 2)) / 2

            # Dense residue x ligand-patch edge features from endpoint node reps.
            aJ_grid_attr = self.pl_edge_embed(
                torch.cat(
                    [
                        residue_rep.view(batch_size, n_a_per_sample, self.dim)[:, :, None].expand(
                            -1, -1, n_ligand_patches, -1
                        ),
                        lig_frame_rep.view(batch_size, n_ligand_patches, self.dim)[
                            :, None, :
                        ].expand(-1, n_a_per_sample, -1, -1),
                    ],
                    dim=-1,
                )
            )
            AJ_grid_attr = IJ_grid_attr.new_zeros(
                batch_size, n_protein_patches, n_ligand_patches, self.pair_dim
            )
            gather_idx_I_I = torch.arange(
                batch_size * n_ligand_patches, device=AJ_grid_attr.device
            )
            gather_idx_a_a = torch.arange(batch_size * n_a_per_sample, device=AJ_grid_attr.device)
            # Note: off-diagonal (AJ) blocks are zero-initialized in the prior stack
            indexer["gather_idx_IJ_I"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_IJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_ligand_patches, -1)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_AJ_a"] = (
                indexer["gather_idx_pid_a"]
                .view(batch_size, n_protein_patches)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_AJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_protein_patches, -1)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_aJ_a"] = (
                gather_idx_a_a.view(batch_size, n_a_per_sample)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_aJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_a_per_sample, -1)
                .contiguous()
                .flatten()
            )
            batch["indexer"] = indexer

            if observed_block_contacts is not None:
                # Generative feedback from block one-hot sampling
                # AJ_grid_attr = (
                #     AJ_grid_attr
                #     + observed_block_contacts.transpose(1, 2)
                #     .contiguous()
                #     .flatten(0, 1)[indexer["gather_idx_I_molid"]]
                #     .view(batch_size, n_ligand_patches, n_protein_patches, -1)
                #     .transpose(1, 2)
                #     .contiguous()
                # )
                AJ_grid_attr = AJ_grid_attr + observed_block_contacts

            graph_batcher = make_multi_relation_graph_batcher(
                self.graph_relations, indexer, metadata
            )
            merged_edge_idx = graph_batcher.collate_idx_list(indexer)
            node_reps = {
                "prot_res": residue_rep,
                "lig_trp": lig_frame_rep,
            }
            # Forward and reverse protein-ligand relations share the same
            # (symmetric) initial grid attributes.
            edge_reps = {
                "residue_to_residue": rec_pair_rep,
                "sampled_residue_to_sampled_residue": AB_grid_attr_flat,
                "sampled_lig_triplet_to_sampled_residue": AJ_grid_attr.flatten(0, 2),
                "sampled_residue_to_sampled_lig_triplet": AJ_grid_attr.flatten(0, 2),
                "sampled_lig_triplet_to_residue": aJ_grid_attr.flatten(0, 2),
                "residue_to_sampled_lig_triplet": aJ_grid_attr.flatten(0, 2),
                "sampled_lig_triplet_to_sampled_lig_triplet": IJ_grid_attr.flatten(0, 2),
            }
            edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)
        else:
            # Protein-only fallback: restrict to the first two (protein-protein)
            # relations declared in `self.graph_relations`.
            graph_batcher = make_multi_relation_graph_batcher(
                self.graph_relations[:2], indexer, metadata
            )
            merged_edge_idx = graph_batcher.collate_idx_list(indexer)

            node_reps = {
                "prot_res": residue_rep,
            }
            edge_reps = {
                "residue_to_residue": rec_pair_rep,
                "sampled_residue_to_sampled_residue": AB_grid_attr_flat,
            }
            edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)

        # Intertwine graph iterations and triangle iterations
        gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
        for block_id in range(self.n_blocks):
            # Communicate between atomistic and patch resolutions
            # Up-sampling for interface edge embeddings
            rec_pair_rep = edge_reps["residue_to_residue"]
            AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
            AB_grid_attr = AB_grid_attr_flat.view(
                batch_size,
                n_protein_patches,
                n_protein_patches,
                self.pair_dim,
            )

            if not batch["misc"]["protein_only"]:
                # Symmetrize off-diagonal blocks
                AJ_grid_attr_flat_ = (
                    edge_reps["sampled_residue_to_sampled_lig_triplet"]
                    + edge_reps["sampled_lig_triplet_to_sampled_residue"]
                ) / 2
                AJ_grid_attr = AJ_grid_attr_flat_.contiguous().view(
                    batch_size, n_protein_patches, n_ligand_patches, -1
                )
                aJ_grid_attr_flat_ = (
                    edge_reps["residue_to_sampled_lig_triplet"]
                    + edge_reps["sampled_lig_triplet_to_residue"]
                ) / 2
                aJ_grid_attr = aJ_grid_attr_flat_.contiguous().view(
                    batch_size, n_a_per_sample, n_ligand_patches, -1
                )
                IJ_grid_attr = (
                    edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"]
                    .contiguous()
                    .view(batch_size, n_ligand_patches, n_ligand_patches, -1)
                )
                # Cross-attention between patch-level (AJ) and residue-level (aJ)
                # protein-ligand grids, keyed by residue -> patch membership.
                AJ_grid_attr_temp_, aJ_grid_attr_temp_ = self.AaJ_mha(
                    AJ_grid_attr.flatten(0, 1),
                    aJ_grid_attr.flatten(0, 1),
                    (
                        gather_idx_res_protpatch,
                        torch.arange(gather_idx_res_protpatch.shape[0], device=device),
                    ),
                )
                AJ_grid_attr = AJ_grid_attr_temp_.contiguous().view(
                    batch_size, n_protein_patches, n_ligand_patches, -1
                )
                aJ_grid_attr = aJ_grid_attr_temp_.contiguous().view(
                    batch_size, n_a_per_sample, n_ligand_patches, -1
                )
                # Assemble the joint (protein + ligand) patch grid:
                # [[AB, AJ], [AJ^T, IJ]].
                merged_grid_rep = torch.cat(
                    [
                        torch.cat([AB_grid_attr, AJ_grid_attr], dim=2),
                        torch.cat([AJ_grid_attr.transpose(1, 2), IJ_grid_attr], dim=2),
                    ],
                    dim=1,
                )
            else:
                merged_grid_rep = AB_grid_attr

            # Inter-patch triangle attentions
            _, merged_grid_rep = self.triangle_stacks[block_id](
                merged_grid_rep,
                merged_grid_rep,
                merged_grid_rep.unsqueeze(-4),
            )

            # Dis-assemble the grid representation
            AB_grid_attr = merged_grid_rep[:, :n_protein_patches, :n_protein_patches]
            # Transfer grid-formatted representations to edges
            edge_reps["residue_to_residue"] = rec_pair_rep
            edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)

            if not batch["misc"]["protein_only"]:
                AJ_grid_attr = merged_grid_rep[
                    :, :n_protein_patches, n_protein_patches:
                ].contiguous()
                IJ_grid_attr = merged_grid_rep[
                    :, n_protein_patches:, n_protein_patches:
                ].contiguous()

                edge_reps["sampled_residue_to_sampled_lig_triplet"] = AJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_sampled_residue"] = AJ_grid_attr.flatten(0, 2)
                edge_reps["residue_to_sampled_lig_triplet"] = aJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_residue"] = aJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"] = IJ_grid_attr.flatten(
                    0, 2
                )
            merged_node_reps = graph_batcher.collate_node_attr(node_reps)
            merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)

            # Graph transformer iteration
            _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
                merged_node_reps,
                merged_node_reps,
                merged_edge_idx,
                merged_edge_reps,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        # Publish outputs under suffixed keys for downstream consumers.
        batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
        batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
        batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
            "sampled_residue_to_sampled_residue"
        ]
        if not batch["misc"]["protein_only"]:
            batch["features"][f"lig_trp_attr{out_attr_suffix}"] = node_reps["lig_trp"]
            batch["features"][f"res_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
                "sampled_residue_to_sampled_lig_triplet"
            ]
            batch["features"][f"res_trp_pair_attr_flat{out_attr_suffix}"] = edge_reps[
                "residue_to_sampled_lig_triplet"
            ]
            batch["features"][f"trp_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
                "sampled_lig_triplet_to_sampled_lig_triplet"
            ]
            batch["metadata"]["n_lig_patches_per_sample"] = n_ligand_patches
        return batch
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_protein_encoder(
    protein_model_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Instantiates a ProtFormer model for protein encoding.

    :param protein_model_cfg: Protein model configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Protein encoder model and residue input projector.
    """

    def _submodule_state_dict(prefix: str) -> STATE_DICT:
        """Collect `state_dict` entries under `prefix` with the prefix stripped."""
        return {
            ".".join(k.split(".")[1:]): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }

    node_dim = protein_model_cfg.residue_dim
    model = ProtFormer(
        node_dim,
        protein_model_cfg.pair_dim,
        n_heads=protein_model_cfg.n_heads,
        n_blocks=protein_model_cfg.n_encoder_stacks,
        dropout=task_cfg.dropout,
    )
    if protein_model_cfg.use_esm_embedding:
        # Project protein sequence language model embeddings to the residue dimension.
        res_in_projector = torch.nn.Linear(protein_model_cfg.plm_embed_dim, node_dim, bias=False)
    else:
        # Project one-hot amino acid types to the residue dimension.
        res_in_projector = torch.nn.Linear(
            protein_model_cfg.n_aa_types,
            node_dim,
            bias=False,
        )
    if protein_model_cfg.from_pretrained and state_dict is not None:
        try:
            # NOTE: we must avoid enforcing strict key matching
            # due to the (new) weights `template_binding_site_enc.weight`
            model.load_state_dict(
                _submodule_state_dict("protein_encoder"),
                strict=False,
            )
            log.info("Successfully loaded pretrained protein encoder weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained protein encoder weights due to: {e}.")

        try:
            # The pretrained projector is named differently depending on the
            # embedding type used at pretraining time.
            res_in_projector.load_state_dict(
                _submodule_state_dict(
                    "plm_adapter"
                    if protein_model_cfg.use_esm_embedding
                    else "res_in_projector"
                )
            )
            log.info("Successfully loaded pretrained protein input projector weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained protein input projector weights due to: {e}."
            )
    return model, res_in_projector
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_pl_contact_stack(
    protein_model_cfg: DictConfig,
    ligand_model_cfg: DictConfig,
    contact_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates a BindingFormer model for protein-ligand contact prediction.

    :param protein_model_cfg: Protein model configuration.
    :param ligand_model_cfg: Ligand model configuration.
    :param contact_cfg: Contact prediction configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Protein-ligand contact prediction model, contact code embedding, distance bins, and
        distogram head.
    """

    def _submodule_state_dict(prefix: str) -> STATE_DICT:
        """Collect `state_dict` entries under `prefix` with the prefix stripped."""
        return {
            ".".join(k.split(".")[1:]): v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }

    # BUGFIX: an explicit `dropout: 0.0` in `contact_cfg` was previously
    # treated as unset (falsy) and silently replaced by the task-level default;
    # test for presence instead of truthiness.
    contact_dropout = contact_cfg.get("dropout")
    pl_contact_stack = BindingFormer(
        protein_model_cfg.residue_dim,
        protein_model_cfg.pair_dim,
        n_heads=protein_model_cfg.n_heads,
        n_blocks=contact_cfg.n_stacks,
        n_ligand_patches=ligand_model_cfg.n_patches,
        dropout=contact_dropout if contact_dropout is not None else task_cfg.dropout,
    )
    # Two-code (contact / no-contact) embedding in pair-feature space.
    contact_code_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim)
    # Distogram heads
    # 32 fixed (non-trainable) distance bin centers spanning 2-22.
    dist_bins = torch.nn.Parameter(torch.linspace(2, 22, 32), requires_grad=False)
    dgram_head = GELUMLP(
        protein_model_cfg.pair_dim,
        32,
        n_hidden_feats=protein_model_cfg.pair_dim,
        zero_init=True,
    )
    if contact_cfg.from_pretrained and state_dict is not None:
        try:
            # NOTE: we must avoid enforcing strict key matching
            # due to the (new) weights `template_binding_site_enc.weight`
            pl_contact_stack.load_state_dict(
                _submodule_state_dict("pl_contact_stack"),
                strict=False,
            )
            log.info("Successfully loaded pretrained protein-ligand contact prediction weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained protein-ligand contact prediction weights due to: {e}."
            )

        try:
            contact_code_embed.load_state_dict(_submodule_state_dict("contact_code_embed"))
            log.info("Successfully loaded pretrained contact code embedding weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained contact code embedding weights due to: {e}."
            )

        try:
            dgram_head.load_state_dict(_submodule_state_dict("dgram_head"))
            log.info("Successfully loaded pretrained distogram head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained distogram head weights due to: {e}.")
    return pl_contact_stack, contact_code_embed, dist_bins, dgram_head
|
||||||
105
flowdock/models/components/embedding.py
Normal file
105
flowdock/models/components/embedding.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Tuple
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils.frame_utils import RigidTransform
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianFourierEncoding1D(torch.nn.Module):
    """Random Fourier feature encoding for scalar (1D) inputs.

    Frequencies are drawn once from a zero-mean Gaussian scaled by pi and kept
    as trainable parameters; the forward pass returns the concatenated sine
    and cosine responses along the last dimension.
    """

    def __init__(
        self,
        n_basis: int,
        eps: float = 1e-2,
    ):
        """Initialize Gaussian Fourier Encoding.

        :param n_basis: Number of random Fourier frequencies.
        :param eps: Numerical-stability constant (stored; not used by `forward`).
        """
        super().__init__()
        self.eps = eps
        # One Gaussian-sampled frequency per basis function, scaled by pi.
        self.fourier_freqs = torch.nn.Parameter(torch.randn(n_basis) * math.pi)

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Return `[sin(f * x), cos(f * x)]` concatenated on the last axis.

        :param x: Input tensor; broadcast against the frequency vector
            (a trailing dimension of size 1 yields `n_basis` phases).
        :return: Tensor whose last dimension is twice that of `f * x`.
        """
        phases = self.fourier_freqs.mul(x)
        return torch.cat([phases.sin(), phases.cos()], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianRBFEncoding1D(torch.nn.Module):
    """Gaussian radial-basis-function encoding for scalar (1D) inputs.

    Places `n_basis` fixed centers evenly on `[0, x_max]` and encodes each
    input value by its Gaussian response to every center.
    """

    def __init__(
        self,
        n_basis: int,
        x_max: float,
        sigma: float = 1.0,
    ):
        """Initialize Gaussian RBF Encoding.

        :param n_basis: Number of RBF centers.
        :param x_max: Upper end of the (inclusive) center range.
        :param sigma: Width of each Gaussian basis function.
        """
        super().__init__()
        self.sigma = sigma
        # Centers are registered as a non-trainable parameter so they follow
        # the module across devices.
        centers = torch.linspace(0, x_max, n_basis)
        self.rbf_centers = torch.nn.Parameter(centers, requires_grad=False)

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Return `exp(-((x - c) / sigma)^2)` for every center `c`.

        :param x: Input tensor; a trailing axis of size `n_basis` is appended.
        :return: RBF responses of shape `x.shape + (n_basis,)`.
        """
        scaled_dev = (x.unsqueeze(-1) - self.rbf_centers) / self.sigma
        return torch.exp(-scaled_dev.square())
|
||||||
|
|
||||||
|
|
||||||
|
class RelativeGeometryEncoding(torch.nn.Module):
    """Compute radial basis functions and inter-residue/pseudo-residue orientations."""

    def __init__(self, n_basis: int, out_dim: int, d_max: float = 20.0):
        """Initialize RelativeGeometryEncoding.

        :param n_basis: Number of distance RBF basis functions.
        :param out_dim: Output feature dimension of the linear projection.
        :param d_max: Largest distance covered by the RBF centers.
        """
        super().__init__()
        self.rbf_encoding = GaussianRBFEncoding1D(n_basis, d_max)
        # n_basis distance features + 2 x 3 direction components + 9 rotation entries.
        self.rel_geom_projector = torch.nn.Linear(n_basis + 15, out_dim, bias=False)

    def forward(self, frames: RigidTransform, merged_edge_idx: Tuple[torch.Tensor, torch.Tensor]):
        """Project pairwise geometric features for every edge in `merged_edge_idx`.

        :param frames: Rigid frames exposing translations `t` and rotations `R`.
        :param merged_edge_idx: (source, destination) index tensors of equal length.
        :return: Edge features of shape `(num_edges, out_dim)`.
        """
        src, dst = merged_edge_idx
        frame_t, frame_R = frames.t, frames.R
        displacement = frame_t[dst] - frame_t[src]
        pair_dists = displacement.norm(dim=-1)
        # Soft normalization sqrt(d^2 + 1) avoids division blow-up near d = 0.
        soft_norm = pair_dists.square().add(1).sqrt().unsqueeze(-1)
        # Displacement expressed in each endpoint's local frame.
        pair_directions_l = (
            torch.matmul(displacement.unsqueeze(-2), frame_R[src]).squeeze(-2) / soft_norm
        )
        pair_directions_r = (
            torch.matmul((-displacement).unsqueeze(-2), frame_R[dst]).squeeze(-2) / soft_norm
        )
        # Relative rotation between the two frames.
        pair_orientations = torch.matmul(
            frame_R.transpose(-2, -1).contiguous()[src],
            frame_R[dst],
        )
        return self.rel_geom_projector(
            torch.cat(
                [
                    self.rbf_encoding(pair_dists),
                    pair_directions_l,
                    pair_directions_r,
                    pair_orientations.flatten(-2, -1),
                ],
                dim=-1,
            )
        )
|
||||||
884
flowdock/models/components/esdm.py
Normal file
884
flowdock/models/components/esdm.py
Normal file
@@ -0,0 +1,884 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, Optional, Tuple
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.models.components.embedding import (
|
||||||
|
GaussianFourierEncoding1D,
|
||||||
|
RelativeGeometryEncoding,
|
||||||
|
)
|
||||||
|
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||||
|
from flowdock.models.components.modules import PointSetAttention
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.frame_utils import RigidTransform, get_frame_matrix
|
||||||
|
from flowdock.utils.model_utils import GELUMLP, AveragePooling, SumPooling, segment_mean
|
||||||
|
|
||||||
|
STATE_DICT = Dict[str, Any]
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalUpdateUsingReferenceRotations(torch.nn.Module):
    """Update local geometric (scalar + vector) node representations using reference rotations.

    Node features are laid out as ``[..., 1 + 3, fiber_dim]``: one scalar channel
    followed by three vector-component channels. Vector channels are rotated into
    each node's reference frame, mixed by an MLP together with rotation-invariant
    scalars, and the updated vector channels are rotated back to the global frame.
    """

    def __init__(
        self,
        fiber_dim: int,
        extra_feat_dim: int = 0,
        eps: float = 1e-4,
        dropout: float = 0.0,
        hidden_dim: Optional[int] = None,
        zero_init: bool = False,
    ):
        """Initialize the LocalUpdateUsingReferenceRotations module.

        :param fiber_dim: Width of each scalar/vector fiber channel.
        :param extra_feat_dim: Width of optional extra invariant features appended to the MLP input.
        :param eps: Stabilizer added under the square root when computing vector norms.
        :param dropout: Dropout probability inside the MLP.
        :param hidden_dim: Optional hidden width of the MLP.
        :param zero_init: Whether to zero-initialize the MLP output layer.
        """
        super().__init__()
        # MLP input: scalars + 3 locally-rotated vector components + vector norms (+ extras).
        self.dim = fiber_dim * 5 + extra_feat_dim
        self.fiber_dim = fiber_dim
        self.mlp = GELUMLP(
            self.dim,
            fiber_dim * 4,
            dropout=dropout,
            zero_init=zero_init,
            n_hidden_feats=hidden_dim,
        )
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        rotation_mats: torch.Tensor,
        extra_feats=None,
    ):
        """Forward pass of the LocalUpdateUsingReferenceRotations module.

        :param x: Node representations of shape ``[N, 4, fiber_dim]`` (scalar + 3 vector rows).
        :param rotation_mats: Per-node reference rotation matrices applied to the vector channels.
        :param extra_feats: Optional extra invariant features concatenated to the MLP input.
        :return: Updated representations of shape ``[N, 4, fiber_dim]``.
        """
        scalars = x[:, 0]
        vectors = x[:, 1:]
        # Rotate vector channels into the local reference frame.
        local_vecs = torch.matmul(vectors.transpose(-2, -1), rotation_mats)
        # Vector norms are evaluated without applying the rigid transform
        # (they are rotation-invariant by construction).
        vec_norms = vectors.square().sum(dim=-2).add(self.eps).sqrt()
        invariant_inputs = [scalars, local_vecs.flatten(-2, -1), vec_norms]
        if extra_feats is not None:
            invariant_inputs.append(extra_feats)
        mixed = self.mlp(torch.cat(invariant_inputs, dim=-1))
        mixed = mixed.view(-1, 4, self.fiber_dim)
        # Rotate the updated vector channels back to the global frame.
        global_vecs = torch.matmul(rotation_mats, mixed[:, 1:])
        return torch.cat([mixed[:, :1], global_vecs], dim=-2)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalUpdateUsingChannelWiseGating(torch.nn.Module):
    """Update local geometric (scalar + vector) node representations using channel-wise gating.

    Rotation-invariant inputs (scalars and vector norms) drive an MLP whose output
    provides new scalar channels plus a sigmoid gate that modulates a linear map
    of the vector channels, keeping the update equivariant.
    """

    def __init__(
        self,
        fiber_dim: int,
        eps: float = 1e-4,
        dropout: float = 0.0,
        hidden_dim: Optional[int] = None,
        zero_init: bool = False,
    ):
        """Initialize the LocalUpdateUsingChannelWiseGating module.

        :param fiber_dim: Width of each scalar/vector fiber channel.
        :param eps: Stabilizer added under the square root when computing vector norms.
        :param dropout: Dropout probability inside the MLP.
        :param hidden_dim: Optional hidden width of the MLP.
        :param zero_init: Whether to zero-initialize the MLP output and the vector linear map.
        """
        super().__init__()
        # MLP input/output: scalar channels + per-channel gate logits.
        self.dim = fiber_dim * 2
        self.fiber_dim = fiber_dim
        self.mlp = GELUMLP(
            self.dim,
            self.dim,
            dropout=dropout,
            n_hidden_feats=hidden_dim,
            zero_init=zero_init,
        )
        self.gate = torch.nn.Sigmoid()
        self.lin_out = torch.nn.Linear(fiber_dim, fiber_dim, bias=False)
        if zero_init:
            self.lin_out.weight.data.fill_(0.0)
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Forward pass of the LocalUpdateUsingChannelWiseGating module.

        :param x: Node representations of shape ``[N, 4, fiber_dim]`` (scalar + 3 vector rows).
        :return: Updated representations of shape ``[N, 4, fiber_dim]``.
        """
        scalars = x[:, 0]
        vectors = x[:, 1:]
        vec_norms = vectors.square().sum(dim=-2).add(self.eps).sqrt()
        mixed = self.mlp(torch.cat([scalars, vec_norms], dim=-1))
        # Gated nonlinear operation on the l=1 (vector) representations.
        new_scalars, gate_logits = torch.split(mixed, self.fiber_dim, dim=-1)
        channel_gate = self.gate(gate_logits).unsqueeze(-2)
        gated_vecs = self.lin_out(vectors) * channel_gate
        return torch.cat([new_scalars.unsqueeze(-2), gated_vecs], dim=-2)
|
||||||
|
|
||||||
|
|
||||||
|
class EquivariantTransformerBlock(torch.nn.Module):
    """Equivariant Transformer Block module.

    One residual block combining point-set attention over a graph with a local
    per-node update of the scalar/vector fiber representations.
    """

    def __init__(
        self,
        fiber_dim: int,
        heads: int = 8,
        point_dim: int = 4,
        eps: float = 1e-4,
        edge_dim: Optional[int] = None,
        target_frames: bool = False,
        edge_update: bool = False,
        dropout: float = 0.0,
    ):
        """Initialize the EquivariantTransformerBlock module.

        :param fiber_dim: Width of each scalar/vector fiber channel.
        :param heads: Number of attention heads.
        :param point_dim: Number of attention points per head.
        :param eps: Stabilizer forwarded to the local update module.
        :param edge_dim: Optional edge feature dimension for the attention.
        :param target_frames: If True, reference rotations are supplied to the local update.
        :param edge_update: If True, the attention also returns an edge-feature update.
        :param dropout: Dropout probability inside the local update MLP.
        """
        super().__init__()
        # NOTE: submodule creation order (attention first, then local update)
        # is kept so parameter initialization matches the original block.
        self.attn_conv = PointSetAttention(
            fiber_dim,
            heads=heads,
            point_dim=point_dim,
            edge_dim=edge_dim,
            edge_update=edge_update,
        )
        self.fiber_dim = fiber_dim
        self.target_frames = target_frames
        self.eps = eps
        self.edge_update = edge_update
        # With reference frames available use the frame-aware update;
        # otherwise fall back to channel-wise gating.
        local_update_cls = (
            LocalUpdateUsingReferenceRotations
            if target_frames
            else LocalUpdateUsingChannelWiseGating
        )
        self.local_update = local_update_cls(fiber_dim, eps=eps, dropout=dropout, zero_init=True)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.LongTensor,
        t: torch.Tensor,
        R: torch.Tensor = None,
        x_edge: torch.Tensor = None,
    ):
        """Forward pass of the EquivariantTransformerBlock module.

        :param x: Node representations of shape ``[N, 4, fiber_dim]``.
        :param edge_index: Edge indices of the (merged) graph.
        :param t: Node translations (point coordinates) used by the attention.
        :param R: Optional per-node reference rotations (required when ``target_frames``).
        :param x_edge: Optional edge representations (updated when ``edge_update``).
        :return: Tuple of updated node representations and (possibly updated) edge representations.
        """
        attn_result = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge)
        if self.edge_update:
            node_delta, edge_delta = attn_result
            x_edge = x_edge + edge_delta
        else:
            node_delta = attn_result
        x = x + node_delta
        # Residual local update; the frame-aware variant also consumes R.
        local_args = (x, R) if self.target_frames else (x,)
        x = self.local_update(*local_args) + x
        return x, x_edge
|
||||||
|
|
||||||
|
|
||||||
|
class EquivariantStructureDenoisingModule(torch.nn.Module):
    """Equivariant Structure Denoising Module.

    Iteratively refines protein-atom (and, when present, ligand-atom) coordinates
    with a stack of equivariant transformer blocks over a merged multi-relation
    heterogeneous graph whose node types are protein residues (``prot_res``),
    protein atoms (``prot_atm``), ligand atoms (``lig_atm``), and ligand frame
    triplets (``lig_trp``).
    """

    def __init__(
        self,
        fiber_dim: int,
        input_dim: int,
        input_pair_dim: int,
        hidden_dim: int = 1024,
        n_stacks: int = 4,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize the EquivariantStructureDenoisingModule module.

        :param fiber_dim: Width of each scalar/vector fiber channel.
        :param input_dim: Dimension of incoming node (residue/triplet) features.
        :param input_pair_dim: Dimension of incoming pair (edge) features.
        :param n_stacks: Number of transformer blocks (coordinate-update iterations).
        :param n_heads: Number of attention heads per block.
        :param dropout: Dropout probability.
        """
        super().__init__()
        self.input_dim = input_dim
        self.input_pair_dim = input_pair_dim
        self.fiber_dim = fiber_dim
        # atom37 representation width used to pad/one-hot protein atom types.
        self.protatm_padding_dim = 37
        self.n_blocks = n_stacks
        self.input_node_projector = torch.nn.Linear(input_dim, fiber_dim, bias=False)
        self.input_node_vec_projector = torch.nn.Linear(input_dim, fiber_dim * 3, bias=False)
        self.input_pair_projector = torch.nn.Linear(input_pair_dim, fiber_dim, bias=False)
        # Inherit the residue representations; +32 for the timestep encoding
        # (GaussianFourierEncoding1D(16) emits 32 features -- TODO confirm).
        self.atm_embed = GELUMLP(input_dim + 32, fiber_dim)
        self.ipa_modules = torch.nn.ModuleList(
            [
                EquivariantTransformerBlock(
                    fiber_dim,
                    heads=n_heads,
                    point_dim=fiber_dim // (n_heads * 2),
                    edge_dim=fiber_dim,
                    target_frames=True,
                    edge_update=True,
                    dropout=dropout,
                )
                for _ in range(n_stacks)
            ]
        )
        # Per-block adapters re-injecting the original residue/triplet inputs.
        self.res_adapters = torch.nn.ModuleList(
            [
                LocalUpdateUsingReferenceRotations(
                    fiber_dim,
                    extra_feat_dim=input_dim,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    zero_init=True,
                )
                for _ in range(n_stacks)
            ]
        )
        self.protatm_type_encoding = GELUMLP(self.protatm_padding_dim + input_dim, input_pair_dim)
        self.time_encoding = GaussianFourierEncoding1D(16)
        self.rel_geom_enc = RelativeGeometryEncoding(15, fiber_dim)
        self.rel_geom_embed = GELUMLP(fiber_dim, fiber_dim, n_hidden_feats=fiber_dim)
        # Per-block coordinate-update heads: [displacement, scale].
        # NOTE(review): drift projections use default init; only the scale gates
        # are zero-initialized (sigmoid(0) = 0.5 at the start of training).
        self.out_drift_res = torch.nn.ModuleList(
            [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
        )
        self.out_scale_res = torch.nn.ModuleList(
            [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
        )
        self.out_drift_atm = torch.nn.ModuleList(
            [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
        )
        self.out_scale_atm = torch.nn.ModuleList(
            [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
        )

        # Pre-tabulated edge relations:
        # (edge_type, src_index_key, dst_index_key, src_node_type, dst_node_type).
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
            (
                "prot_atm_to_prot_atm_graph",
                "protatm_protatm_idx_src",
                "protatm_protatm_idx_dst",
                "prot_atm",
                "prot_atm",
            ),
            (
                "prot_atm_to_prot_atm_knn",
                "knn_idx_protatm_protatm_src",
                "knn_idx_protatm_protatm_dst",
                "prot_atm",
                "prot_atm",
            ),
            (
                "prot_atm_to_residue",
                "protatm_res_idx_protatm",
                "protatm_res_idx_res",
                "prot_atm",
                "prot_res",
            ),
            (
                "residue_to_prot_atm",
                "protatm_res_idx_res",
                "protatm_res_idx_protatm",
                "prot_res",
                "prot_atm",
            ),
            (
                "sampled_lig_triplet_to_lig_atm",
                "gather_idx_UI_I",
                "gather_idx_UI_u",
                "lig_trp",
                "lig_atm",
            ),
            (
                "lig_atm_to_sampled_lig_triplet",
                "gather_idx_UI_u",
                "gather_idx_UI_I",
                "lig_atm",
                "lig_trp",
            ),
            (
                "lig_atm_to_lig_atm_graph",
                "gather_idx_uv_u",
                "gather_idx_uv_v",
                "lig_atm",
                "lig_atm",
            ),
            (
                "sampled_residue_to_sampled_lig_triplet",
                "gather_idx_AJ_a",
                "gather_idx_AJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_sampled_residue",
                "gather_idx_AJ_J",
                "gather_idx_AJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "residue_to_sampled_lig_triplet",
                "gather_idx_aJ_a",
                "gather_idx_aJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_residue",
                "gather_idx_aJ_J",
                "gather_idx_aJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "sampled_lig_triplet_to_sampled_lig_triplet",
                "gather_idx_IJ_I",
                "gather_idx_IJ_J",
                "lig_trp",
                "lig_trp",
            ),
            (
                "prot_atm_to_lig_atm_knn",
                "knn_idx_protatm_ligatm_src",
                "knn_idx_protatm_ligatm_dst",
                "prot_atm",
                "lig_atm",
            ),
            (
                "lig_atm_to_prot_atm_knn",
                "knn_idx_ligatm_protatm_src",
                "knn_idx_ligatm_protatm_dst",
                "lig_atm",
                "prot_atm",
            ),
            (
                "lig_atm_to_lig_atm_knn",
                "knn_idx_ligatm_ligatm_src",
                "knn_idx_ligatm_ligatm_dst",
                "lig_atm",
                "lig_atm",
            ),
        ]
        # The first six relations involve protein node types only.
        self.graph_relations_no_ligand = self.graph_relations[:6]

    def init_scalar_vec_rep(self, x, x_v=None, frame=None):
        """Initialize a joint scalar+vector representation of shape ``[..., 4, fiber_dim]``.

        :param x: Scalar features ``[..., fiber_dim]``.
        :param x_v: Optional flattened vector features ``[..., 3 * fiber_dim]``, expressed
            in local frames (required when ``frame`` is given).
        :param frame: Optional rigid frames whose rotations map ``x_v`` to global coordinates.
        :return: Stacked representation with 1 scalar row and 3 vector rows.
        """
        if frame is None:
            # Zero-initialize the vector channels
            vec_shape = (*x.shape[:-1], 3, x.shape[-1])
            res = torch.cat([x.unsqueeze(-2), torch.zeros(vec_shape, device=x.device)], dim=-2)
        else:
            x_v = x_v.view(*x.shape[:-1], 3, x.shape[-1])
            x_v_glob = torch.matmul(frame.R, x_v)
            res = torch.cat([x.unsqueeze(-2), x_v_glob], dim=-2)
        return res

    def _collate_node_frames(
        self,
        graph_batcher,
        backbone_frames,
        prot_atm_coords_flat,
        protein_only,
        lig_atm_coords=None,
        ligand_trp_frames=None,
    ):
        """Collate per-node-type rigid frames into merged translation/rotation tensors.

        Atom node types get "dummy" frames anchored at the atom coordinates
        (constructed with ``R=None``); residue and triplet node types use their
        computed frames.

        :return: Tuple ``(merged_t, merged_R)`` in the batcher's merged node order.
        """
        dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None)
        node_t = {
            "prot_res": backbone_frames.t,
            "prot_atm": dummy_prot_atm_frames.t,
        }
        node_R = {
            "prot_res": backbone_frames.R,
            "prot_atm": dummy_prot_atm_frames.R,
        }
        if not protein_only:
            dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None)
            node_t["lig_atm"] = dummy_lig_atm_frames.t
            node_t["lig_trp"] = ligand_trp_frames.t
            node_R["lig_atm"] = dummy_lig_atm_frames.R
            node_R["lig_trp"] = ligand_trp_frames.R
        return (
            graph_batcher.collate_node_attr(node_t),
            graph_batcher.collate_node_attr(node_R),
        )

    def _frame_drift(self, block_id: int, rep: torch.Tensor) -> torch.Tensor:
        """Residue/triplet-level drift: gated vector projection, scaled by a factor of 10."""
        magnitude_gate = torch.sigmoid(self.out_scale_res[block_id](rep[:, 0]))
        return self.out_drift_res[block_id](rep[:, 1:]).squeeze(-1) * magnitude_gate * 10

    def _atom_drift(self, block_id: int, rep: torch.Tensor) -> torch.Tensor:
        """Atom-level drift: gated vector projection (unscaled)."""
        magnitude_gate = torch.sigmoid(self.out_scale_atm[block_id](rep[:, 0]))
        return self.out_drift_atm[block_id](rep[:, 1:]).squeeze(-1) * magnitude_gate

    def forward(
        self,
        batch,
        frozen_lig=False,
        frozen_prot=False,
        **kwargs,
    ):
        """Forward pass of the EquivariantStructureDenoisingModule module.

        :param batch: Batch dictionary with "features", "indexer", "metadata", and "misc".
        :param frozen_lig: If True, ligand coordinates are left unchanged.
        :param frozen_prot: If True, protein coordinates are left unchanged.
        :return: Dictionary of final embeddings and denoised coordinates; ligand entries
            are ``None`` when ``batch["misc"]["protein_only"]`` is set.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        protein_only = batch["misc"]["protein_only"]

        prot_res_rep_in = features["rec_res_attr_decin"]
        timestep_prot = features["timestep_encoding_prot"]
        device = features["res_type"].device

        # Protein all-atom representation initialization
        protatm_padding_mask = features["res_atom_mask"]
        protatm_atom37_onehot = torch.nn.functional.one_hot(
            features["protatm_to_atom37_index"], num_classes=self.protatm_padding_dim
        )
        # Pair encoding of each protein atom with its parent residue.
        protatm_res_pair_encoding = self.protatm_type_encoding(
            torch.cat(
                [
                    prot_res_rep_in[indexer["protatm_res_idx_res"]],
                    protatm_atom37_onehot,
                ],
                dim=-1,
            )
        )
        # Gathered AA features from individual graphs, conditioned on the timestep.
        prot_atm_rep_in = features["prot_atom_attr_projected"]
        prot_atm_rep_int = self.atm_embed(
            torch.cat(
                [
                    prot_atm_rep_in,
                    self.time_encoding(timestep_prot)[indexer["protatm_res_idx_res"]],
                ],
                dim=-1,
            )
        )

        prot_atm_coords_padded = features["input_protein_coords"]
        prot_atm_coords_flat = prot_atm_coords_padded[protatm_padding_mask]

        # Embed the rigid body node representations; backbone frames are built
        # from the first three atom37 slots (presumably N, CA, C -- TODO confirm).
        backbone_frames = get_frame_matrix(
            prot_atm_coords_padded[:, 0],
            prot_atm_coords_padded[:, 1],
            prot_atm_coords_padded[:, 2],
        )
        prot_res_rep = self.init_scalar_vec_rep(
            self.input_node_projector(prot_res_rep_in),
            x_v=self.input_node_vec_projector(prot_res_rep_in),
            frame=backbone_frames,
        )
        prot_atm_rep = self.init_scalar_vec_rep(prot_atm_rep_int)
        node_reps = {
            "prot_res": prot_res_rep,
            "prot_atm": prot_atm_rep,
        }
        # Embed pair representations
        edge_reps = {
            "residue_to_residue": features["res_res_pair_attr_decin"],
            "prot_atm_to_prot_atm_graph": features["prot_atom_pair_attr_projected"],
            "prot_atm_to_prot_atm_knn": features["knn_feat_protatm_protatm"],
            "prot_atm_to_residue": protatm_res_pair_encoding,
            "residue_to_prot_atm": protatm_res_pair_encoding,
            "sampled_residue_to_sampled_residue": features["res_res_grid_attr_flat_decin"],
        }

        lig_atm_coords = None
        ligand_trp_frames = None
        lig_frame_atm_idx = None
        lig_frame_rep_in = None
        if not protein_only:
            timestep_lig = features["timestep_encoding_lig"]
            lig_atm_rep_in = features["lig_atom_attr_projected"]
            lig_frame_rep_in = features["lig_trp_attr_decin"]
            # Ligand atom embedding conditioned on the ligand timestep.
            lig_atm_rep_int = self.atm_embed(
                torch.cat(
                    [lig_atm_rep_in, self.time_encoding(timestep_lig)],
                    dim=-1,
                )
            )
            lig_atm_rep = self.init_scalar_vec_rep(lig_atm_rep_int)

            # Initialize coordinate features (cloned: updated in-place per block).
            lig_atm_coords = features["input_ligand_coords"].clone()

            # Atom-index triplets defining each sampled ligand frame.
            lig_frame_atm_idx = (
                indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]],
                indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]],
                indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]],
            )
            ligand_trp_frames = get_frame_matrix(
                lig_atm_coords[lig_frame_atm_idx[0]],
                lig_atm_coords[lig_frame_atm_idx[1]],
                lig_atm_coords[lig_frame_atm_idx[2]],
            )
            lig_frame_rep = self.init_scalar_vec_rep(
                self.input_node_projector(lig_frame_rep_in),
                x_v=self.input_node_vec_projector(lig_frame_rep_in),
                frame=ligand_trp_frames,
            )
            node_reps.update(
                {
                    "lig_atm": lig_atm_rep,
                    "lig_trp": lig_frame_rep,
                }
            )
            edge_reps.update(
                {
                    "lig_atm_to_lig_atm_graph": features["lig_atom_pair_attr_projected"],
                    "sampled_lig_triplet_to_lig_atm": features["lig_af_pair_attr_projected"],
                    "lig_atm_to_sampled_lig_triplet": features["lig_af_pair_attr_projected"],
                    "sampled_residue_to_sampled_lig_triplet": features[
                        "res_trp_grid_attr_flat_decin"
                    ],
                    "sampled_lig_triplet_to_sampled_residue": features[
                        "res_trp_grid_attr_flat_decin"
                    ],
                    "residue_to_sampled_lig_triplet": features["res_trp_pair_attr_flat_decin"],
                    "sampled_lig_triplet_to_residue": features["res_trp_pair_attr_flat_decin"],
                    "sampled_lig_triplet_to_sampled_lig_triplet": features[
                        "trp_trp_grid_attr_flat_decin"
                    ],
                    "prot_atm_to_lig_atm_knn": features["knn_feat_protatm_ligatm"],
                    "lig_atm_to_prot_atm_knn": features["knn_feat_ligatm_protatm"],
                    "lig_atm_to_lig_atm_knn": features["knn_feat_ligatm_ligatm"],
                }
            )

        # Message passing over the merged heterogeneous graph
        protatm_res_idx_res = indexer["protatm_res_idx_res"]
        graph_relations = self.graph_relations_no_ligand if protein_only else self.graph_relations
        graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
        merged_edge_idx = graph_batcher.collate_idx_list(indexer)
        merged_node_reps = graph_batcher.collate_node_attr(node_reps)
        merged_edge_reps = graph_batcher.collate_edge_attr(
            graph_batcher.zero_pad_edge_attr(edge_reps, self.input_pair_dim, device)
        )
        merged_edge_reps = self.input_pair_projector(merged_edge_reps)
        assert merged_edge_idx[0].shape[0] == merged_edge_reps.shape[0]
        assert merged_edge_idx[1].shape[0] == merged_edge_reps.shape[0]

        # Inject relative-geometry features of the initial frames into the edges.
        merged_node_t, merged_node_R = self._collate_node_frames(
            graph_batcher,
            backbone_frames,
            prot_atm_coords_flat,
            protein_only,
            lig_atm_coords,
            ligand_trp_frames,
        )
        merged_node_frames = RigidTransform(merged_node_t, merged_node_R)
        merged_edge_reps = merged_edge_reps + (
            self.rel_geom_embed(
                self.rel_geom_enc(merged_node_frames, merged_edge_idx) + merged_edge_reps
            )
        )

        # No need to reassign embeddings but need to update point coordinates & frames
        offloaded_node_reps = None
        for block_id in range(self.n_blocks):
            merged_node_t, merged_node_R = self._collate_node_frames(
                graph_batcher,
                backbone_frames,
                prot_atm_coords_flat,
                protein_only,
                lig_atm_coords,
                ligand_trp_frames,
            )
            # PredictDrift iteration
            merged_node_reps, merged_edge_reps = self.ipa_modules[block_id](
                merged_node_reps,
                merged_edge_idx,
                t=merged_node_t,
                R=merged_node_R,
                x_edge=merged_edge_reps,
            )
            offloaded_node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            # Residual adapters re-injecting the original input features.
            if "lig_trp" in offloaded_node_reps:
                lig_frame_rep = offloaded_node_reps["lig_trp"]
                offloaded_node_reps["lig_trp"] = lig_frame_rep + self.res_adapters[block_id](
                    lig_frame_rep, ligand_trp_frames.R, extra_feats=lig_frame_rep_in
                )
            prot_res_rep = offloaded_node_reps["prot_res"]
            offloaded_node_reps["prot_res"] = prot_res_rep + self.res_adapters[block_id](
                prot_res_rep, backbone_frames.R, extra_feats=prot_res_rep_in
            )
            merged_node_reps = graph_batcher.collate_node_attr(offloaded_node_reps)

            # Displacement vectors in the global coordinate system
            if not protein_only:
                drift_trp = self._frame_drift(block_id, offloaded_node_reps["lig_trp"])
                # Per-molecule mean of triplet drifts, broadcast back to ligand atoms.
                drift_trp_gathered = segment_mean(
                    drift_trp,
                    indexer["gather_idx_I_molid"],
                    metadata["num_molid"],
                )[indexer["gather_idx_i_molid"]]
                drift_atm = self._atom_drift(block_id, offloaded_node_reps["lig_atm"])
                if not frozen_lig:
                    lig_atm_coords = lig_atm_coords + drift_atm + drift_trp_gathered
                # Ligand frames are refreshed even when coordinates are frozen.
                ligand_trp_frames = get_frame_matrix(
                    lig_atm_coords[lig_frame_atm_idx[0]],
                    lig_atm_coords[lig_frame_atm_idx[1]],
                    lig_atm_coords[lig_frame_atm_idx[2]],
                )

            drift_bb = self._frame_drift(block_id, offloaded_node_reps["prot_res"])
            drift_bb_gathered = drift_bb[protatm_res_idx_res]
            drift_prot_atm_int = self._atom_drift(block_id, offloaded_node_reps["prot_atm"])
            if not frozen_prot:
                prot_atm_coords_flat = (
                    prot_atm_coords_flat + drift_prot_atm_int + drift_bb_gathered
                )

            # Re-pad coordinates and refresh backbone frames for the next block.
            prot_atm_coords_padded = torch.zeros_like(features["input_protein_coords"])
            prot_atm_coords_padded[protatm_padding_mask] = prot_atm_coords_flat
            backbone_frames = get_frame_matrix(
                prot_atm_coords_padded[:, 0],
                prot_atm_coords_padded[:, 1],
                prot_atm_coords_padded[:, 2],
            )

        ret = {
            "final_embedding_prot_atom": offloaded_node_reps["prot_atm"],
            "final_embedding_prot_res": offloaded_node_reps["prot_res"],
            "final_coords_prot_atom": prot_atm_coords_flat,
            "final_coords_prot_atom_padded": prot_atm_coords_padded,
        }
        if not protein_only:
            ret["final_embedding_lig_atom"] = offloaded_node_reps["lig_atm"]
            ret["final_coords_lig_atom"] = lig_atm_coords
        else:
            ret["final_embedding_lig_atom"] = None
            ret["final_coords_lig_atom"] = None
        return ret
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_score_head(
    protein_model_cfg: DictConfig,
    score_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
    """Instantiates an EquivariantStructureDenoisingModule model for protein-ligand complex
    structure denoising.

    :param protein_model_cfg: Protein model configuration.
    :param score_cfg: Score configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model.
    """
    score_head = EquivariantStructureDenoisingModule(
        score_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=score_cfg.hidden_dim,
        n_stacks=score_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    if score_cfg.from_pretrained and state_dict is not None:
        # Keep only "score_head.*" entries, stripping the leading prefix
        # so keys match this submodule's parameter names.
        pretrained_weights = {
            key.partition(".")[2]: value
            for key, value in state_dict.items()
            if key.startswith("score_head")
        }
        try:
            score_head.load_state_dict(pretrained_weights)
            log.info("Successfully loaded pretrained score weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained score weights due to: {e}.")
    return score_head
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_confidence_head(
    protein_model_cfg: DictConfig,
    confidence_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for confidence prediction.

    :param protein_model_cfg: Protein model configuration.
    :param confidence_cfg: Confidence configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model and plDDT gram head weights.
    """
    confidence_head = EquivariantStructureDenoisingModule(
        confidence_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=confidence_cfg.hidden_dim,
        n_stacks=confidence_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    plddt_gram_head = GELUMLP(
        protein_model_cfg.pair_dim,
        8,
        n_hidden_feats=protein_model_cfg.pair_dim,
        zero_init=True,
    )
    # NOTE(review): both load attempts are assumed to sit under the
    # from_pretrained guard (state_dict.items() would otherwise crash on None).
    if confidence_cfg.from_pretrained and state_dict is not None:
        confidence_weights = {
            key.partition(".")[2]: value
            for key, value in state_dict.items()
            if key.startswith("confidence_head")
        }
        try:
            confidence_head.load_state_dict(confidence_weights)
            log.info("Successfully loaded pretrained confidence weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained confidence weights due to: {e}.")

        plddt_weights = {
            key.partition(".")[2]: value
            for key, value in state_dict.items()
            if key.startswith("plddt_gram_head")
        }
        try:
            plddt_gram_head.load_state_dict(plddt_weights)
            log.info("Successfully loaded pretrained pLDDT gram head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained pLDDT gram head weights due to: {e}.")
    return confidence_head, plddt_gram_head
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_affinity_head(
    ligand_model_cfg: DictConfig,
    affinity_cfg: DictConfig,
    task_cfg: DictConfig,
    learnable_pooling: bool = True,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for affinity prediction.

    :param ligand_model_cfg: Ligand model configuration.
    :param affinity_cfg: Affinity configuration.
    :param task_cfg: Task configuration.
    :param learnable_pooling: Whether to use learnable ligand pooling modules.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model as well as a ligand pooling module and
        projection head.
    """
    affinity_head = EquivariantStructureDenoisingModule(
        affinity_cfg.fiber_dim,
        input_dim=ligand_model_cfg.node_channels,
        input_pair_dim=ligand_model_cfg.pair_channels,
        hidden_dim=affinity_cfg.hidden_dim,
        n_stacks=affinity_cfg.n_stacks,
        n_heads=ligand_model_cfg.n_heads,
        dropout=affinity_cfg.dropout if affinity_cfg.get("dropout") else task_cfg.dropout,
    )
    pooling_mode = affinity_cfg.ligand_pooling
    if pooling_mode in ("sum", "add", "summation", "addition"):
        ligand_pooling = SumPooling(learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim)
    elif pooling_mode in ("mean", "avg", "average"):
        ligand_pooling = AveragePooling(
            learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim
        )
    else:
        raise NotImplementedError(
            f"Unsupported ligand pooling method: {affinity_cfg.ligand_pooling}"
        )
    affinity_proj_head = GELUMLP(
        affinity_cfg.fiber_dim,
        1,
        n_hidden_feats=affinity_cfg.fiber_dim,
        zero_init=True,
    )
    if affinity_cfg.from_pretrained and state_dict is not None:

        def _submodule_weights(prefix: str) -> Dict[str, Any]:
            # Select "<prefix>.*" entries and strip the leading prefix.
            return {
                key.partition(".")[2]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }

        try:
            affinity_head.load_state_dict(_submodule_weights("affinity_head"))
            log.info("Successfully loaded pretrained affinity head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained affinity head weights due to: {e}.")

        if learnable_pooling:
            try:
                ligand_pooling.load_state_dict(_submodule_weights("ligand_pooling"))
                log.info("Successfully loaded pretrained ligand pooling weights.")
            except Exception as e:
                log.warning(f"Skipping loading of pretrained ligand pooling weights due to: {e}.")

        try:
            affinity_proj_head.load_state_dict(_submodule_weights("affinity_proj_head"))
            log.info("Successfully loaded pretrained affinity projection head weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained affinity projection head weights due to: {e}."
            )
    return affinity_head, ligand_pooling, affinity_proj_head
|
||||||
2029
flowdock/models/components/flowdock.py
Normal file
2029
flowdock/models/components/flowdock.py
Normal file
File diff suppressed because it is too large
Load Diff
166
flowdock/models/components/hetero_graph.py
Normal file
166
flowdock/models/components/hetero_graph.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Relation:
    """Immutable descriptor of one edge type in a multi-relation (heterogeneous) graph.

    Field meanings follow their usage in ``MultiRelationGraphBatcher`` below.
    """

    edge_type: str  # relation name; key into edge-attribute dictionaries
    edge_rev_name: str  # indexer key holding per-edge source-node indices
    edge_frd_name: str  # indexer key holding per-edge destination-node indices
    src_node_type: str  # node type the edges originate from
    dst_node_type: str  # node type the edges point to
    num_edges: int  # number of edges of this relation in the batch
|
||||||
|
|
||||||
|
|
||||||
|
class MultiRelationGraphBatcher:
    """Collate sub-graphs of different node/edge types into a single instance.

    Nodes of all types are packed into one flat index space via per-type
    offsets (computed in :meth:`_make_offset_dict`), and edges of all
    relations are concatenated in the order of the provided relation forms.

    Returned multi-relation edge indices are stored in LongTensor of shape [2, N_edges].
    """

    def __init__(
        self,
        relation_forms: List["Relation"],
        graph_metadata: Dict[str, int],
    ):
        """Initialize the batcher.

        :param relation_forms: One ``Relation`` descriptor per edge type.
        :param graph_metadata: Must provide a ``num_<node_type>`` entry for
            every node type referenced by ``relation_forms``.
        """
        self._relation_forms = relation_forms
        self._make_offset_dict(graph_metadata)

    def _make_offset_dict(self, graph_metadata):
        """Create offset dictionaries for node and edge types."""
        self._node_chunk_sizes = {}
        self._edge_chunk_sizes = {}
        self._offsets_lower = {}
        self._offsets_upper = {}
        # Record node types in first-seen order. The previous implementation
        # accumulated them in a `set`, whose iteration order depends on string
        # hashing and is therefore not reproducible across processes; a
        # deterministic ordering keeps the packed tensor layout stable.
        node_types_in_order = {}
        for relation in self._relation_forms:
            assert (
                f"num_{relation.src_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.src_node_type}"
            # NOTE: this assertion previously reported the *src* node type in
            # its error message (copy-paste bug); it now reports the dst type.
            assert (
                f"num_{relation.dst_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.dst_node_type}"
            node_types_in_order[relation.src_node_type] = None
            node_types_in_order[relation.dst_node_type] = None
        offset = 0
        # Fix node type ordering
        self.all_node_types = list(node_types_in_order)
        for node_type in self.all_node_types:
            self._offsets_lower[node_type] = offset
            self._node_chunk_sizes[node_type] = graph_metadata[f"num_{node_type}"]
            new_offset = offset + self._node_chunk_sizes[node_type]
            self._offsets_upper[node_type] = new_offset
            offset = new_offset

    def collate_single_relation_graphs(self, indexer, node_attr_dict, edge_attr_dict):
        """Collate sub-graphs of different node/edge types into a single instance.

        :param indexer: Per-relation edge index tensors.
        :param node_attr_dict: Per-node-type attribute tensors.
        :param edge_attr_dict: Per-relation edge attribute tensors.
        :return: Dict with flat ``node_attr``, ``edge_attr``, and ``edge_index``.
        """
        return {
            "node_attr": self.collate_node_attr(node_attr_dict),
            "edge_attr": self.collate_edge_attr(edge_attr_dict),
            "edge_index": self.collate_idx_list(indexer),
        }

    def collate_idx_list(
        self,
        indexer: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Collate edge indices for all relations into the flat node index space.

        :param indexer: Maps each relation's rev/frd index names to 1D LongTensors.
        :return: LongTensor of shape [2, total_num_edges] with offset node indices.
        """
        ret_eidxs_rev, ret_eidxs_frd = [], []
        for relation in self._relation_forms:
            assert relation.edge_rev_name in indexer.keys()
            assert relation.edge_frd_name in indexer.keys()
            assert indexer[relation.edge_rev_name].dim() == 1
            assert indexer[relation.edge_frd_name].dim() == 1
            # Bounds checks: every per-type index must be valid for its chunk.
            assert torch.all(
                indexer[relation.edge_rev_name] < self._node_chunk_sizes[relation.src_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            assert torch.all(
                indexer[relation.edge_frd_name] < self._node_chunk_sizes[relation.dst_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            # Shift per-type indices by the node-type offset into flat space.
            ret_eidxs_rev.append(
                indexer[relation.edge_rev_name] + self._offsets_lower[relation.src_node_type]
            )
            ret_eidxs_frd.append(
                indexer[relation.edge_frd_name] + self._offsets_lower[relation.dst_node_type]
            )
        ret_eidxs_rev = torch.cat(ret_eidxs_rev, dim=0)
        ret_eidxs_frd = torch.cat(ret_eidxs_frd, dim=0)
        return torch.stack([ret_eidxs_rev, ret_eidxs_frd], dim=0)

    def collate_node_attr(self, node_attr_dict: Dict[str, torch.Tensor]):
        """Collate node attributes for all node types into one tensor.

        :param node_attr_dict: Maps node type to its [num_nodes, ...] tensor.
        :return: Concatenation in ``self.all_node_types`` order.
        """
        for node_type in self.all_node_types:
            assert (
                node_attr_dict[node_type].shape[0] == self._node_chunk_sizes[node_type]
            ), f"Node count mismatch: {node_type}, {node_attr_dict[node_type].shape[0]}, {self._node_chunk_sizes[node_type]}"
        return torch.cat([node_attr_dict[node_type] for node_type in self.all_node_types], dim=0)

    def collate_edge_attr(self, edge_attr_dict: Dict[str, torch.Tensor]):
        """Collate edge attributes for all relations into one tensor.

        :param edge_attr_dict: Maps relation ``edge_type`` to its attribute tensor.
        :return: Concatenation in ``self._relation_forms`` order.
        """
        return torch.cat(
            [edge_attr_dict[relation.edge_type] for relation in self._relation_forms],
            dim=0,
        )

    def zero_pad_edge_attr(
        self,
        edge_attr_dict: Dict[str, torch.Tensor],
        embedding_dim: int,
        device: torch.device,
    ):
        """Replace missing (``None``) edge attributes with zero tensors in place.

        :param edge_attr_dict: Maps relation ``edge_type`` to a tensor or ``None``.
        :param embedding_dim: Feature width of the zero padding.
        :param device: Device on which to allocate the padding.
        :return: The (mutated) ``edge_attr_dict``.
        """
        for relation in self._relation_forms:
            if edge_attr_dict[relation.edge_type] is None:
                edge_attr_dict[relation.edge_type] = torch.zeros(
                    (relation.num_edges, embedding_dim),
                    device=device,
                )
        return edge_attr_dict

    def offload_node_attr(self, cat_node_attr: torch.Tensor):
        """Split a flat node-attribute tensor back into per-node-type chunks.

        Inverse of :meth:`collate_node_attr`.
        """
        node_chunk_sizes = [self._node_chunk_sizes[node_type] for node_type in self.all_node_types]
        node_attr_split = torch.split(cat_node_attr, node_chunk_sizes)
        return {
            self.all_node_types[i]: node_attr_split[i] for i in range(len(self.all_node_types))
        }

    def offload_edge_attr(self, cat_edge_attr: torch.Tensor):
        """Split a flat edge-attribute tensor back into per-relation chunks.

        Inverse of :meth:`collate_edge_attr`.
        """
        edge_chunk_sizes = [relation.num_edges for relation in self._relation_forms]
        edge_attr_split = torch.split(cat_edge_attr, edge_chunk_sizes)
        return {
            self._relation_forms[i].edge_type: edge_attr_split[i]
            for i in range(len(self._relation_forms))
        }
|
||||||
|
|
||||||
|
|
||||||
|
def make_multi_relation_graph_batcher(
    list_of_relations: List[Tuple[str, str, str, str, str]],
    indexer,
    metadata,
):
    """Make a multi-relation graph batcher.

    Each entry of ``list_of_relations`` is a 5-tuple
    ``(edge_type, edge_rev_name, edge_frd_name, src_node_type, dst_node_type)``;
    the per-relation edge count is read from the indexer.
    """
    # Use one instantiation of the indexer to compute chunk sizes
    relation_forms = []
    for edge_type, rev_name, frd_name, src_type, dst_type in list_of_relations:
        relation_forms.append(
            Relation(
                edge_type=edge_type,
                edge_rev_name=rev_name,
                edge_frd_name=frd_name,
                src_node_type=src_type,
                dst_node_type=dst_type,
                num_edges=indexer[rev_name].shape[0],
            )
        )
    return MultiRelationGraphBatcher(relation_forms, metadata)
|
||||||
1439
flowdock/models/components/losses.py
Normal file
1439
flowdock/models/components/losses.py
Normal file
File diff suppressed because it is too large
Load Diff
364
flowdock/models/components/mht_encoder.py
Normal file
364
flowdock/models/components/mht_encoder.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, Optional, Tuple
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||||
|
from flowdock.models.components.modules import TransformerLayer
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.model_utils import GELUMLP, segment_softmax, segment_sum
|
||||||
|
|
||||||
|
MODEL_BATCH = Dict[str, Any]
|
||||||
|
STATE_DICT = Dict[str, Any]
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PathConvStack(torch.nn.Module):
    """Path integral convolution stack for ligand encoding."""

    def __init__(
        self,
        pair_channels: int,
        n_heads: int = 8,
        max_pi_length: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize PathConvStack model.

        :param pair_channels: Width of the pair representation.
        :param n_heads: Number of propagation heads.
        :param max_pi_length: Number of path-integral propagation steps.
        :param dropout: Dropout rate of the update MLP.
        """
        super().__init__()
        self.pair_channels = pair_channels
        self.max_pi_length = max_pi_length
        self.n_heads = n_heads

        self.prop_value_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
        self.triangle_pair_kernel_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
        # One slice of width n_heads per propagation step (plus the initial values).
        self.prop_update_mlp = GELUMLP(
            n_heads * (max_pi_length + 1), pair_channels, dropout=dropout
        )

    def forward(
        self,
        prop_attr: torch.Tensor,
        stereo_attr: torch.Tensor,
        indexer: Dict[str, torch.LongTensor],
        metadata: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward pass for PathConvStack model.

        :param prop_attr: Atom-frame pair attributes.
        :param stereo_attr: Stereochemistry attributes.
        :param indexer: A dictionary of indices.
        :param metadata: A dictionary of metadata.
        :return: Updated atom-frame pair attributes (residual-added).
        """
        # Triangle-pair kernel, softmax-normalized per outgoing-triangle segment.
        kernel_logits = self.triangle_pair_kernel_layer(stereo_attr)
        triangle_alpha = segment_softmax(
            kernel_logits, indexer["gather_idx_ijkl_jkl"], metadata["num_ijk"]
        )
        # Uijk,ijkl->ujkl pair representation update
        gathered_kernel = triangle_alpha[indexer["gather_idx_Uijkl_ijkl"]]
        propagated = [self.prop_value_layer(prop_attr)]
        for _ in range(self.max_pi_length):
            previous_step = propagated[-1][indexer["gather_idx_Uijkl_Uijk"]]
            propagated.append(
                segment_sum(
                    gathered_kernel * previous_step,
                    indexer["gather_idx_Uijkl_ujkl"],
                    metadata["num_Uijk"],
                )
            )
        # Concatenate all propagation steps, project back, and apply a residual.
        return self.prop_update_mlp(torch.cat(propagated, dim=-1)) + prop_attr
|
||||||
|
|
||||||
|
|
||||||
|
class PIFormer(torch.nn.Module):
    """PIFormer model for ligand encoding.

    Interleaves heterogeneous graph-transformer layers (over atom and frame
    nodes) with path-integral convolution stacks, writing the resulting
    ligand representations back into the batch dictionary.
    """

    def __init__(
        self,
        node_channels: int,
        pair_channels: int,
        n_atom_encodings: int,
        n_bond_encodings: int,
        n_atom_pos_encodings: int,
        n_stereo_encodings: int,
        heads: int,
        head_dim: int,
        max_path_length: int = 4,
        n_transformer_stacks: int = 4,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        """Initialize PIFormer model.

        :param node_channels: Width of atom/frame node representations.
        :param pair_channels: Width of pair/edge representations.
        :param n_atom_encodings: Input atom feature dimension.
        :param n_bond_encodings: Input bond feature dimension.
        :param n_atom_pos_encodings: Input atom positional feature dimension.
        :param n_stereo_encodings: Input stereochemistry feature dimension.
        :param heads: Number of attention heads per transformer layer.
        :param head_dim: Per-head dimension.
        :param max_path_length: Number of path-integral propagation steps.
        :param n_transformer_stacks: Number of transformer + path-conv blocks.
        :param hidden_dim: Hidden width of the transformer MLPs.
        :param dropout: Dropout rate.
        """
        super().__init__()
        self.node_channels = node_channels
        self.pair_channels = pair_channels
        self.max_pi_length = max_path_length
        self.n_transformer_stacks = n_transformer_stacks
        self.n_atom_encodings = n_atom_encodings
        self.n_bond_encodings = n_bond_encodings
        # NOTE(review): +4 extra atom-pair categories beyond bond encodings —
        # meaning not visible here; confirm against the featurizer.
        self.n_atom_pair_encodings = n_bond_encodings + 4
        self.n_atom_pos_encodings = n_atom_pos_encodings

        self.input_atom_layer = torch.nn.Linear(n_atom_encodings, node_channels)
        self.input_pair_layer = torch.nn.Linear(self.n_atom_pair_encodings, pair_channels)
        self.input_stereo_layer = torch.nn.Linear(n_stereo_encodings, pair_channels)
        self.input_prop_layer = GELUMLP(
            self.n_atom_pair_encodings * 3,
            pair_channels,
        )
        self.path_integral_stacks = torch.nn.ModuleList(
            [
                PathConvStack(
                    pair_channels,
                    max_pi_length=max_path_length,
                    dropout=dropout,
                )
                for _ in range(n_transformer_stacks)
            ]
        )
        self.graph_transformer_stacks = torch.nn.ModuleList(
            [
                TransformerLayer(
                    node_channels,
                    heads,
                    head_dim=head_dim,
                    edge_channels=pair_channels,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    edge_update=True,
                )
                for _ in range(n_transformer_stacks)
            ]
        )

    def forward(self, batch: MODEL_BATCH, masking_rate: float = 0.0) -> MODEL_BATCH:
        """Forward pass for PIFormer model.

        Mutates ``batch`` in place (adds ``lig_*`` features and ligand counts
        to metadata) and returns it.

        :param batch: A batch dictionary with ``features``/``indexer``/``metadata``.
        :param masking_rate: Probability of zeroing each input feature row
            (feature dropout-style masking).
        :return: The same batch dictionary, augmented with ligand representations.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        # NOTE(review): this line is a no-op self-assignment; presumably a
        # leftover from a removed transformation — confirm and remove upstream.
        features["atom_encodings"] = features["atom_encodings"]
        atom_attr = features["atom_encodings"]
        atom_pair_attr = features["atom_pair_encodings"]
        af_pair_attr = features["atom_frame_pair_encodings"]
        stereo_enc = features["stereo_chemistry_encodings"]
        # Keep detached copies of the raw input tokens for downstream use.
        batch["features"]["lig_atom_token"] = atom_attr.detach().clone()
        batch["features"]["lig_pair_token"] = atom_pair_attr.detach().clone()

        # Independent Bernoulli keep-masks per feature row.
        atom_mask = torch.rand(atom_attr.shape[0], device=atom_attr.device) > masking_rate
        stereo_mask = torch.rand(stereo_enc.shape[0], device=stereo_enc.device) > masking_rate
        atom_pair_mask = (
            torch.rand(atom_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
        )
        # NOTE(review): this mask reuses atom_pair_attr.device rather than
        # af_pair_attr.device — fine only if both live on the same device.
        af_pair_mask = (
            torch.rand(af_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
        )
        atom_attr = atom_attr * atom_mask[:, None]
        stereo_enc = stereo_enc * stereo_mask[:, None]
        atom_pair_attr = atom_pair_attr * atom_pair_mask[:, None]
        af_pair_attr = af_pair_attr * af_pair_mask[:, None]

        # Embedding blocks
        metadata["num_atom"] = metadata["num_u"]
        metadata["num_frame"] = metadata["num_ijk"]
        atom_attr = self.input_atom_layer(atom_attr)
        atom_pair_attr = self.input_pair_layer(atom_pair_attr)
        # Frame nodes start from zero embeddings.
        triangle_attr = atom_attr.new_zeros(metadata["num_frame"], self.node_channels)
        # Initialize atom-frame pair attributes. Reusing uv indices
        prop_attr = self.input_prop_layer(af_pair_attr)
        stereo_attr = self.input_stereo_layer(stereo_enc)

        # Relations: (edge_type, src index key, dst index key, src type, dst type).
        graph_relations = [
            ("atom_to_atom", "gather_idx_uv_u", "gather_idx_uv_v", "atom", "atom"),
            (
                "atom_to_frame",
                "gather_idx_Uijk_u",
                "gather_idx_Uijk_ijk",
                "atom",
                "frame",
            ),
            (
                "frame_to_atom",
                "gather_idx_Uijk_ijk",
                "gather_idx_Uijk_u",
                "frame",
                "atom",
            ),
            (
                "frame_to_frame",
                "gather_idx_ijkl_ijk",
                "gather_idx_ijkl_jkl",
                "frame",
                "frame",
            ),
        ]

        graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
        merged_edge_idx = graph_batcher.collate_idx_list(indexer)
        node_reps = {"atom": atom_attr, "frame": triangle_attr}
        # NOTE: atom_to_frame and frame_to_atom share the same prop_attr tensor.
        edge_reps = {
            "atom_to_atom": atom_pair_attr,
            "atom_to_frame": prop_attr,
            "frame_to_atom": prop_attr,
            "frame_to_frame": stereo_attr,
        }

        # Graph path integral recursion
        for block_id in range(self.n_transformer_stacks):
            # Flatten node/edge dicts, run one transformer layer, un-flatten.
            merged_node_attr = graph_batcher.collate_node_attr(node_reps)
            merged_edge_attr = graph_batcher.collate_edge_attr(edge_reps)
            _, merged_node_attr, merged_edge_attr = self.graph_transformer_stacks[block_id](
                merged_node_attr,
                merged_node_attr,
                merged_edge_idx,
                merged_edge_attr,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_attr)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_attr)
            prop_attr = edge_reps["atom_to_frame"]
            stereo_attr = edge_reps["frame_to_frame"]
            # NOTE(review): PathConvStack already adds prop_attr internally, so
            # this outer addition applies the residual twice — confirm intent.
            prop_attr = prop_attr + self.path_integral_stacks[block_id](
                prop_attr,
                stereo_attr,
                indexer,
                metadata,
            )
            edge_reps["atom_to_frame"] = prop_attr

        node_reps["sampled_frame"] = node_reps["frame"][indexer["gather_idx_I_ijk"]]

        batch["metadata"]["num_lig_atm"] = metadata["num_u"]
        batch["metadata"]["num_lig_trp"] = metadata["num_I"]

        batch["features"]["lig_atom_attr"] = node_reps["atom"]
        # Downsampled ligand frames
        batch["features"]["lig_trp_attr"] = node_reps["sampled_frame"]
        batch["features"]["lig_atom_pair_attr"] = edge_reps["atom_to_atom"]
        batch["features"]["lig_prop_attr"] = edge_reps["atom_to_frame"]
        edge_reps["sampled_atom_to_sampled_frame"] = edge_reps["atom_to_frame"][
            indexer["gather_idx_UI_Uijk"]
        ]
        batch["features"]["lig_af_pair_attr"] = edge_reps["sampled_atom_to_sampled_frame"]
        return batch
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_ligand_encoder(
    ligand_model_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
    """Instantiates a PIFormer model for ligand encoding.

    :param ligand_model_cfg: Ligand model configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Ligand encoder model.
    """
    model = PIFormer(
        ligand_model_cfg.node_channels,
        ligand_model_cfg.pair_channels,
        ligand_model_cfg.n_atom_encodings,
        ligand_model_cfg.n_bond_encodings,
        ligand_model_cfg.n_atom_pos_encodings,
        ligand_model_cfg.n_stereo_encodings,
        ligand_model_cfg.n_attention_heads,
        ligand_model_cfg.attention_head_dim,
        hidden_dim=ligand_model_cfg.hidden_dim,
        max_path_length=ligand_model_cfg.max_path_integral_length,
        n_transformer_stacks=ligand_model_cfg.n_transformer_stacks,
        dropout=task_cfg.dropout,
    )
    should_load = ligand_model_cfg.from_pretrained and state_dict is not None
    if should_load:
        try:
            # Keep only checkpoint entries belonging to the ligand encoder and
            # strip their leading module-name prefix.
            encoder_weights = {
                ".".join(key.split(".")[1:]): value
                for key, value in state_dict.items()
                if key.startswith("ligand_encoder")
            }
            model.load_state_dict(encoder_weights)
            log.info(
                "Successfully loaded pretrained ligand Molecular Heat Transformer (MHT) weights."
            )
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained ligand Molecular Heat Transformer (MHT) weights due to: {e}."
            )
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_relational_reasoning_module(
    protein_model_cfg: DictConfig,
    ligand_model_cfg: DictConfig,
    relational_reasoning_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates relational reasoning module for ligand encoding.

    :param protein_model_cfg: Protein model configuration.
    :param ligand_model_cfg: Ligand model configuration.
    :param relational_reasoning_cfg: Relational reasoning configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Relational reasoning modules for ligand encoding.
    """
    molgraph_single_projector = torch.nn.Linear(
        ligand_model_cfg.node_channels, protein_model_cfg.residue_dim, bias=False
    )
    molgraph_pair_projector = torch.nn.Linear(
        ligand_model_cfg.pair_channels, protein_model_cfg.pair_dim, bias=False
    )
    covalent_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim)

    def _maybe_load(module: torch.nn.Module, prefix: str, description: str) -> None:
        """Best-effort load of checkpoint entries whose keys start with `prefix`,
        with the leading module-name prefix stripped."""
        try:
            module.load_state_dict(
                {
                    ".".join(key.split(".")[1:]): value
                    for key, value in state_dict.items()
                    if key.startswith(prefix)
                }
            )
            log.info(f"Successfully loaded pretrained {description} weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained {description} weights due to: {e}.")

    if relational_reasoning_cfg.from_pretrained and state_dict is not None:
        _maybe_load(
            molgraph_single_projector,
            "molgraph_single_projector",
            "ligand graph single projector",
        )
        _maybe_load(
            molgraph_pair_projector,
            "molgraph_pair_projector",
            "ligand graph pair projector",
        )
        _maybe_load(covalent_embed, "covalent_embed", "ligand covalent embedding")

    return molgraph_single_projector, molgraph_pair_projector, covalent_embed
|
||||||
423
flowdock/models/components/modules.py
Normal file
423
flowdock/models/components/modules.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from beartype.typing import Optional, Tuple, Union
|
||||||
|
from openfold.model.primitives import Attention
|
||||||
|
from openfold.utils.tensor_utils import permute_final_dims
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils.model_utils import GELUMLP, segment_softmax
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttentionConv(torch.nn.Module):
    """Native Pytorch implementation.

    Multi-head attention message passing over an explicit edge list
    (source/destination index pairs), with optional per-edge attention bias.
    """

    def __init__(
        self,
        dim: Union[int, Tuple[int, int]],
        head_dim: int,
        edge_dim: Optional[int] = None,
        n_heads: int = 1,
        dropout: float = 0.0,
        edge_lin: bool = True,
        **kwargs,
    ):
        """Multi-Head Attention Convolution layer.

        :param dim: Input feature width, or a (source, target) pair of widths.
        :param head_dim: Per-head dimension.
        :param edge_dim: Edge feature width (required when ``edge_lin`` is True).
        :param n_heads: Number of attention heads.
        :param dropout: Dropout applied to the attention weights.
        :param edge_lin: Whether to project edge attributes to attention biases.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = dropout
        self.edge_dim = edge_dim
        self.edge_lin = edge_lin
        # Stash for the most recent attention weights (consumed in forward).
        self._alpha = None

        if isinstance(dim, int):
            dim = (dim, dim)

        self.lin_key = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
        self.lin_query = torch.nn.Linear(dim[1], n_heads * head_dim, bias=False)
        self.lin_value = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
        if edge_lin is True:
            self.lin_edge = torch.nn.Linear(edge_dim, n_heads, bias=False)
        else:
            # register_parameter returns None, so self.lin_edge ends up None
            # while "lin_edge" is still registered on the module.
            self.lin_edge = self.register_parameter("lin_edge", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the layer."""
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_lin:
            self.lin_edge.reset_parameters()

    def forward(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        return_attention_weights=None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of the Multi-Head Attention Convolution layer.

        :param x (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]): The input features
            as a single tensor (self-attention) or a (source, target) pair.
        :param edge_index (torch.Tensor): The edge index tensor; row 0 holds
            source indices, row 1 destination indices.
        :param edge_attr (torch.Tensor, optional): The edge attribute tensor.
        :param return_attention_weights (bool, optional): If set to `True`,
            will additionally return the tuple
            `(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. Default is `None`.
        :return: The output features or the tuple
            `(output_features, (edge_index, attention_weights))`.
        """

        H, C = self.n_heads, self.head_dim

        if isinstance(x, torch.Tensor):
            x = (x, x)

        # Project to per-head query/key/value: [..., H, C].
        query = self.lin_query(x[1]).view(*x[1].shape[:-1], H, C)
        key = self.lin_key(x[0]).view(*x[0].shape[:-1], H, C)
        value = self.lin_value(x[0]).view(*x[0].shape[:-1], H, C)

        attended_values = self.message(key, query, value, edge_attr, edge_index)
        out = self.aggregate(attended_values, edge_index[1], query.shape[0])

        # message() stashed the raw logits; read and clear the side channel.
        alpha = self._alpha
        self._alpha = None

        # Merge heads back into one feature dimension.
        out = out.contiguous().view(*out.shape[:-2], H * C)

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            return out, (edge_index, alpha)
        else:
            return out

    def message(
        self,
        key: torch.Tensor,
        query: torch.Tensor,
        value: torch.Tensor,
        edge_attr: torch.Tensor,
        index: torch.Tensor,
    ) -> torch.Tensor:
        """Add the relative positional encodings to attention scores.

        Computes per-edge scaled dot-product attention logits (optionally
        biased by projected edge attributes), softmax-normalizes them per
        destination node, and weights the gathered values.

        :param key (torch.Tensor): The key tensor.
        :param query (torch.Tensor): The query tensor.
        :param value (torch.Tensor): The value tensor.
        :param edge_attr (torch.Tensor): The edge attribute tensor.
        :param index (torch.Tensor): The edge index tensor.
        :return: The output tensor.
        """
        edge_bias = 0
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_bias = self.lin_edge(edge_attr)

        # Scaled dot-product logits per edge and head.
        _alpha_z = (query[index[1]] * key[index[0]]).sum(dim=-1) / math.sqrt(
            self.head_dim
        ) + edge_bias
        # NOTE: the stored attention weights are the pre-softmax logits.
        self._alpha = _alpha_z
        # Normalize over edges sharing the same destination node.
        alpha = segment_softmax(_alpha_z, index[1], query.shape[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Advanced indexing copies, so the in-place multiply below is safe.
        out = value[index[0]]
        out *= alpha.unsqueeze(-1)
        return out

    def aggregate(
        self, src: torch.Tensor, dst_idx: torch.Tensor, dst_size: int
    ) -> torch.Tensor:
        """Aggregate the source tensor to the destination tensor.

        Sum-scatters per-edge messages into a zero-initialized output of
        ``dst_size`` destination rows.

        :param src (torch.Tensor): The source tensor.
        :param dst_idx (torch.Tensor): The destination index tensor.
        :param dst_size (int): The number of destination rows.
        :return: The output tensor.
        """
        out = torch.zeros(
            dst_size,
            *src.shape[1:],
            dtype=src.dtype,
            device=src.device,
        ).index_add_(0, dst_idx, src)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerLayer(torch.nn.Module):
    """A single layer of a transformer model.

    Wraps :class:`MultiHeadAttentionConv` with residual connections, a
    pre-norm MLP, optional reverse-direction (bidirectional) message passing,
    and an optional edge-representation update.
    """

    def __init__(
        self,
        node_dim: int,
        n_heads: int,
        head_dim: Optional[int] = None,
        hidden_dim: Optional[int] = None,
        bidirectional: bool = False,
        edge_channels: Optional[int] = None,
        dropout: float = 0.0,
        edge_update: bool = False,
    ):
        """Initialize the transformer layer.

        :param node_dim: Node feature width.
        :param n_heads: Number of attention heads.
        :param head_dim: Per-head dimension (defaults to node_dim // n_heads).
        :param hidden_dim: Hidden width of the feed-forward MLP.
        :param bidirectional: Also attend along reversed edges.
        :param edge_channels: Edge feature width (enables edge attention bias).
        :param dropout: Dropout rate.
        :param edge_update: Also update edge attributes from attention weights.
        """
        super().__init__()
        edge_lin = edge_channels is not None
        self.edge_update = edge_update
        if head_dim is None:
            head_dim = node_dim // n_heads
        self.conv = MultiHeadAttentionConv(
            node_dim,
            head_dim,
            edge_dim=edge_channels,
            n_heads=n_heads,
            edge_lin=edge_lin,
            dropout=dropout,
        )
        self.bidirectional = bidirectional
        self.projector = torch.nn.Linear(head_dim * n_heads, node_dim, bias=False)
        self.norm = torch.nn.LayerNorm(node_dim)
        self.mlp = GELUMLP(
            node_dim,
            node_dim,
            n_hidden_feats=hidden_dim,
            dropout=dropout,
            zero_init=True,
        )
        if edge_update:
            # Maps (attention logits ++ edge attrs) to an edge-attr residual.
            self.mlpe = GELUMLP(
                n_heads + edge_channels, edge_channels, dropout=dropout, zero_init=True
            )

    def forward(
        self,
        x_s: torch.Tensor,
        x_a: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through the transformer layer.

        :param x_s (torch.Tensor): The source node features.
        :param x_a (torch.Tensor): The target node features.
        :param edge_index (torch.Tensor): The edge index tensor. :param edge_attr (torch.Tensor,
            optional): The edge attribute tensor.
        :return: The output source and target node features (and, when
            ``edge_update`` is enabled, the updated edge attributes).
        """
        # Source-to-target attention; also capture the attention logits for
        # the optional edge update below.
        out_a, (edge_index, alpha) = self.conv(
            (x_s, x_a),
            edge_index,
            edge_attr,
            return_attention_weights=True,
        )
        # Residual attention update followed by a pre-norm MLP residual.
        x_a = x_a + self.projector(out_a)
        x_a = self.mlp(self.norm(x_a)) + x_a
        if self.bidirectional:
            # Reverse pass: same conv/projector/MLP, with edges flipped.
            out_s = self.conv((x_a, x_s), (edge_index[1], edge_index[0]), edge_attr)
            x_s = x_s + self.projector(out_s)
            x_s = self.mlp(self.norm(x_s)) + x_s
        if self.edge_update:
            # NOTE: `alpha` comes from the forward-direction pass only.
            edge_attr = edge_attr + self.mlpe(torch.cat([alpha, edge_attr], dim=-1))
            return x_s, x_a, edge_attr
        else:
            return x_s, x_a
|
||||||
|
|
||||||
|
|
||||||
|
class PointSetAttention(torch.nn.Module):
    """Attention over 3D point sets with scalar and coordinate channels.

    Each node carries one scalar channel plus three coordinate channels per
    head/point. Attention logits combine a scaled scalar dot-product term with
    a squared-distance term between query and key points (softplus-weighted
    per head), optionally biased by edge features.
    """

    def __init__(
        self,
        fiber_dim: int,
        heads: int = 8,
        point_dim: int = 4,
        edge_dim: Optional[int] = None,
        edge_update: bool = False,
        dropout: float = 0.0,
    ):
        """Initialize the PointSetAttention module.

        :param fiber_dim: Width of the per-node input/output features.
        :param heads: Number of attention heads.
        :param point_dim: Number of points (channels) per head.
        :param edge_dim: Width of optional edge features used as a logit bias.
        :param edge_update: Whether ``forward`` also returns updated edge features.
        :param dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.fiber_dim = fiber_dim
        self.edge_dim = edge_dim
        self.heads = heads
        self.point_dim = point_dim
        self.dropout = dropout
        self.edge_update = edge_update
        self.distance_scaling = 10  # 1 nm

        # num attention contributions (scalar term + point term)
        num_attn_logits = 2

        self.lin_query = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        self.lin_key = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        self.lin_value = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        if edge_dim is not None:
            self.lin_edge = torch.nn.Linear(edge_dim, heads, bias=False)
            if edge_update:
                self.edge_update_mlp = GELUMLP(heads + edge_dim, edge_dim)

        # qkv projection scale for scalar attention (normal dot-product track)
        self.scalar_attn_logits_scale = (num_attn_logits * point_dim) ** -0.5

        # Per-head weights for the point (distance) attention term, parameterized
        # through softplus; initialized so that softplus(w) == 1.
        # BUGFIX: the original code created this parameter twice with identical
        # values; the redundant second initialization has been removed.
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.0)) - 1.0)
        self.point_weights = torch.nn.Parameter(point_weight_init_value)
        self.point_attn_logits_scale = ((num_attn_logits * point_dim) * (9 / 2)) ** -0.5

        # combine out - scalar + 3 coordinate channels back to fiber_dim
        self.to_out = torch.nn.Linear(heads * point_dim, fiber_dim, bias=False)

    def forward(
        self,
        x_k: torch.Tensor,
        x_q: torch.Tensor,
        edge_index: torch.LongTensor,
        point_centers_k: torch.Tensor,
        point_centers_q: torch.Tensor,
        x_edge: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the PointSetAttention module.

        :param x_k: Key/value node features; channel 0 is the scalar track,
            channels 1:4 carry xyz point offsets.
        :param x_q: Query node features with the same channel layout.
        :param edge_index: ``[2, E]`` edge indices (row 0 gathers keys, row 1 queries).
        :param point_centers_k: Coordinates the key points are expressed around.
        :param point_centers_q: Coordinates the query points are expressed around.
        :param x_edge: Optional ``[E, edge_dim]`` edge features.
        :return: Updated query features (and updated edge features when
            ``edge_update`` is enabled).
        """
        H, P = self.heads, self.point_dim

        q = self.lin_query(x_q)
        k = self.lin_key(x_k)
        v = self.lin_value(x_k)

        # channel 0 carries the scalar attention track
        scalar_q = q[..., 0, :].view(-1, H, P)
        scalar_k = k[..., 0, :].view(-1, H, P)
        scalar_v = v[..., 0, :].view(-1, H, P)

        # channels 1:4 carry per-head xyz point offsets in the local frame
        point_q_local = q[..., 1:, :].view(-1, 3, H, P)
        point_k_local = k[..., 1:, :].view(-1, 3, H, P)
        point_v_local = v[..., 1:, :].view(-1, 3, H, P)

        # shift local points into the shared frame (centers rescaled by 1 nm)
        point_q = point_q_local + point_centers_q[..., None, None] / self.distance_scaling
        point_k = point_k_local + point_centers_k[..., None, None] / self.distance_scaling
        point_v = point_v_local + point_centers_k[..., None, None] / self.distance_scaling

        # optional per-edge, per-head bias from edge features
        if self.edge_dim is not None:
            edge_bias = self.lin_edge(x_edge)
        else:
            edge_bias = 0

        attn_logits, attentions = self.compute_attention(
            scalar_k, scalar_q, point_k, point_q, edge_bias, edge_index
        )
        # scatter-sum the attention-weighted values back onto the query nodes
        res_scalar = self.aggregate(
            attentions[:, :, None] * scalar_v[edge_index[0]],
            edge_index[1],
            scalar_q.shape[0],
        )
        res_points = self.aggregate(
            attentions[:, None, :, None] * point_v[edge_index[0]],
            edge_index[1],
            point_q.shape[0],
        )
        # re-express the aggregated points in the query's local frame
        res_points_local = res_points - point_centers_q[..., None, None] / self.distance_scaling

        # [N, H, P], [N, 3, H, P] -> [N, 4, C]
        res = torch.cat([res_scalar.unsqueeze(-3), res_points_local], dim=-3).flatten(-2, -1)
        out = self.to_out(res)  # [N, 4, C]
        if self.edge_update:
            edge_out = self.edge_update_mlp(torch.cat([attn_logits, x_edge], dim=-1))
            return out, edge_out
        return out

    def compute_attention(self, scalar_k, scalar_q, point_k, point_q, edge_bias, index):
        """Compute per-edge attention logits and normalized attention weights."""
        # gather the per-node features onto the edges
        scalar_q = scalar_q[index[1]]
        scalar_k = scalar_k[index[0]]
        point_q = point_q[index[1]]
        point_k = point_k[index[0]]

        scalar_logits = (scalar_q * scalar_k).sum(dim=-1) * self.scalar_attn_logits_scale
        point_weights = F.softplus(self.point_weights).unsqueeze(0)
        point_logits = (
            torch.square(point_q - point_k).sum(dim=(-3, -1)) * self.point_attn_logits_scale
        )

        # larger q-k distance -> lower logit; edge_bias may be 0 (no edge features)
        logits = scalar_logits - 1 / 2 * point_logits * point_weights + edge_bias
        # NOTE(review): `scalar_q.shape[0]` here is the number of edges (post-gather),
        # not the number of query nodes — presumably `segment_softmax` treats it as
        # an upper bound on the segment count; confirm against its implementation.
        alpha = segment_softmax(logits, index[1], scalar_q.shape[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return logits, alpha

    @staticmethod
    def aggregate(src, dst_idx, dst_size):
        """Scatter-sum rows of `src` into a zero tensor with leading size `dst_size`."""
        out = torch.zeros(
            dst_size,
            *src.shape[1:],
            dtype=src.dtype,
            device=src.device,
        ).index_add_(0, dst_idx, src)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class BiDirectionalTriangleAttention(torch.nn.Module):
    """
    Adapted from https://github.com/aqlaboratory/openfold
    supports rectangular pair representation tensors
    """

    def __init__(self, c_in: int, c_hidden: int, no_heads: int, inf: float = 1e9):
        """Initialize the Bi-Directional Triangle Attention layer.

        :param c_in: Input channel dimension of both node tracks.
        :param c_hidden: Per-head hidden dimension.
        :param no_heads: Number of attention heads.
        :param inf: Large constant used to mask out attention logits.
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # projects the pair representation into one bias value per head
        self.linear = torch.nn.Linear(c_in, self.no_heads, bias=False)

        # one attention module per direction; a single layer norm is shared
        self.mha_1 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
        self.mha_2 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
        self.layer_norm = torch.nn.LayerNorm(self.c_in)

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        x_pair: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_lma: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend in both directions across the pair grid and return both tracks."""
        # default to an all-ones mask over the pair grid: [*, I, J, K]
        if mask is None:
            mask = x_pair.new_ones(
                x_pair.shape[:-1],
            )

        x1 = self.layer_norm(x1)  # [*, I, J, C_in]
        x2 = self.layer_norm(x2)  # [*, I, K, C_in]

        # additive mask bias: [*, I, 1, J, K]
        mask_bias = (self.inf * (mask - 1))[..., :, None, :, :]
        # pair-derived bias: [*, I, H, J, K]
        triangle_bias = permute_final_dims(self.linear(x_pair), [0, 3, 1, 2])

        # first direction: x1 queries attend over x2, with a residual update
        x1 = x1 + self.mha_1(
            q_x=x1, kv_x=x2, biases=[mask_bias, triangle_bias], use_lma=use_lma
        )

        # second direction: transpose both biases so x2 can attend over
        # the freshly-updated x1, again with a residual update
        x2 = x2 + self.mha_2(
            q_x=x2,
            kv_x=x1,
            biases=[
                mask_bias.transpose(-2, -1).contiguous(),
                triangle_bias.transpose(-2, -1).contiguous(),
            ],
            use_lma=use_lma,
        )

        return x1, x2
|
||||||
443
flowdock/models/components/noise.py
Normal file
443
flowdock/models/components/noise.py
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
import numpy as np
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.models.components.transforms import LatentCoordinateConverter
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.model_utils import segment_mean
|
||||||
|
|
||||||
|
MODEL_BATCH = Dict[str, Any]
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionSDE:
    """Diffusion SDE class.

    Adapted from: https://github.com/HannesStark/FlowSite
    """

    def __init__(self, sigma: torch.Tensor, tau_factor: float = 5.0):
        """Store the rate parameter ``lamb = 1 / sigma**2`` and the tau factor."""
        self.lamb = 1 / sigma**2
        self.tau_factor = tau_factor

    def var(self, t: torch.Tensor) -> torch.Tensor:
        """Variance of the diffusion process at time ``t``."""
        lam = self.lamb
        return (1 - torch.exp(-lam * t)) / lam

    def max_t(self) -> float:
        """Maximum diffusion time considered, ``tau_factor / lamb``."""
        return self.tau_factor / self.lamb

    def mu_factor(self, t: torch.Tensor) -> torch.Tensor:
        """Decay factor applied to the process mean at time ``t``."""
        return torch.exp(-self.lamb * t / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class HarmonicSDE:
    """Harmonic SDE class.

    Adapted from: https://github.com/HannesStark/FlowSite

    Diffusion under a harmonic (quadratic) potential defined by a coupling
    matrix ``J``, handled in the eigenbasis ``J = P diag(D) P^T``.
    NOTE(review): `__init__` diagonalizes with NumPy while the static
    `diagonalize` helper uses torch; methods below switch NumPy/torch via
    `self.use_cuda` — confirm which construction path each caller uses.
    """

    def __init__(self, J: Optional[torch.Tensor] = None, diagonalize: bool = True):
        """Initialize the Harmonic SDE class.

        :param J: Optional coupling matrix whose eigendecomposition is computed eagerly.
        :param diagonalize: If ``False``, skip the eigendecomposition entirely.
        """
        # index of the first non-trivial mode (mode 0 is the zero/translation mode)
        self.l_index = 1
        self.use_cuda = False
        if not diagonalize:
            return
        if J is not None:
            # eigenvalues `D` and eigenvectors `P` of the coupling matrix (NumPy path)
            self.D, self.P = np.linalg.eigh(J)
            self.N = self.D.size

    @staticmethod
    def diagonalize(
        N,
        ptr: torch.Tensor,
        edges: Optional[List[Tuple[int, int]]] = None,
        antiedges: Optional[List[Tuple[int, int]]] = None,
        a=1,
        b=0.3,
        lamb: Optional[torch.Tensor] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Diagonalize using the Harmonic SDE.

        Builds the coupling matrix ``J`` — a linear chain of consecutive nodes
        when `edges` is None, otherwise from the given edges (attractive, weight
        `a`) and antiedges (repulsive, weight `b`) plus the diagonal `lamb` —
        and eigendecomposes it block-wise per the segment boundaries in `ptr`.

        :param N: Total number of nodes.
        :param ptr: Segment boundary indices ``[0, n_1, n_1 + n_2, ...]``.
        :param edges: Optional explicit coupling edges.
        :param antiedges: Optional repulsive edges.
        :param a: Coupling strength for edges / the chain.
        :param b: Coupling strength for antiedges.
        :param lamb: Diagonal term added when `edges` is given.
        :param device: Device for ``J``; defaults to ``ptr.device``.
        :return: Concatenated per-segment eigenvalues (inverted, with the null
            mode zeroed, in the chain case) and the block-diagonal eigenvector matrix.
        """
        device = device or ptr.device
        J = torch.zeros((N, N), device=device)
        if edges is None:
            # default: a linear chain coupling consecutive nodes
            for i, j in zip(np.arange(N - 1), np.arange(1, N)):
                J[i, i] += a
                J[j, j] += a
                J[i, j] = J[j, i] = -a
        else:
            for i, j in edges:
                J[i, i] += a
                J[j, j] += a
                J[i, j] = J[j, i] = -a
        if antiedges is not None:
            for i, j in antiedges:
                J[i, i] -= b
                J[j, j] -= b
                J[i, j] = J[j, i] = b
        if edges is not None:
            J += torch.diag(lamb)

        Ds, Ps = [], []
        # eigendecompose each sample's block independently
        for start, end in zip(ptr[:-1], ptr[1:]):
            D, P = torch.linalg.eigh(J[start:end, start:end])
            D_ = D
            if edges is None:
                # chain case: return inverse eigenvalues, zeroing the null mode
                D_inv = 1 / D
                D_inv[0] = 0
                D_ = D_inv
            Ds.append(D_)
            Ps.append(P)
        return torch.cat(Ds), torch.block_diag(*Ps)

    def eigens(self, t):
        """Calculate the eigenvalues of `sigma_t` using the Harmonic SDE."""
        np_ = torch if self.use_cuda else np
        D = 1 / self.D * (1 - np_.exp(-t * self.D))
        t = torch.tensor(t, device="cuda").float() if self.use_cuda else t
        # zero eigenvalues (null modes) fall back to variance `t` (free diffusion)
        return np_.where(D != 0, D, t)

    def conditional(self, mask, x2):
        """Calculate the conditional distribution using the Harmonic SDE.

        Draws a sample of the unmasked coordinates given the masked ones ``x2``
        (NumPy-only path).
        """
        J_11 = self.J[~mask][:, ~mask]
        J_12 = self.J[~mask][:, mask]
        h = -J_12 @ x2
        mu = np.linalg.inv(J_11) @ h
        D, P = np.linalg.eigh(J_11)
        z = np.random.randn(*mu.shape)
        return (P / D**0.5) @ z + mu

    def A(self, t, invT=False):
        """Calculate the matrix `A` using the Harmonic SDE.

        ``A = P * sqrt(eigens(t))`` so that ``A @ A.T == Sigma(t)``; optionally
        also returns its inverse transpose.
        """
        D = self.eigens(t)
        A = self.P * (D**0.5)
        if not invT:
            return A
        AinvT = self.P / (D**0.5)
        return A, AinvT

    def Sigma_inv(self, t):
        """Calculate the inverse of the covariance matrix `Sigma` using the Harmonic SDE."""
        D = 1 / self.eigens(t)
        return (self.P * D) @ self.P.T

    def Sigma(self, t):
        """Calculate the covariance matrix `Sigma` using the Harmonic SDE."""
        D = self.eigens(t)
        return (self.P * D) @ self.P.T

    @property
    def J(self):
        """Return the matrix `J` reconstructed from its eigendecomposition."""
        return (self.P * self.D) @ self.P.T

    def rmsd(self, t):
        """Calculate the root mean square deviation using the Harmonic SDE."""
        l_index = self.l_index
        D = 1 / self.D * (1 - np.exp(-t * self.D))
        # average over the non-null modes, times 3 spatial dimensions
        return np.sqrt(3 * D[l_index:].mean())

    def sample(self, t, x=None, score=False, k=None, center=True, adj=False):
        """Sample from the Harmonic SDE at time `t`.

        :param t: Diffusion time.
        :param x: Optional starting configuration (defaults to zeros).
        :param score: If truthy, also return the score of the drawn sample.
        :param k: If given, zero out all modes past the first `k` non-null ones.
        :param center: Zero the null (translation) mode.
        :param adj: Adjust the returned score by the drift term.
        :return: The sample (and its score when requested).
        """
        l_index = self.l_index
        np_ = torch if self.use_cuda else np
        if x is None:
            if self.use_cuda:
                x = torch.zeros((self.N, 3), device="cuda").float()
            else:
                x = np.zeros((self.N, 3))
        if t == 0:
            return x
        z = (
            np.random.randn(self.N, 3)
            if not self.use_cuda
            else torch.randn(self.N, 3, device="cuda").float()
        )
        D = self.eigens(t)
        xx = self.P.T @ x  # project the configuration into the eigenbasis
        if center:
            z[0] = 0
            xx[0] = 0
        if k:
            z[k + l_index :] = 0
            xx[k + l_index :] = 0

        # OU-style transition per mode: decayed mean plus scaled Gaussian noise
        out = np_.exp(-t * self.D / 2)[:, None] * xx + np_.sqrt(D)[:, None] * z

        if score:
            score = -(1 / np_.sqrt(D))[:, None] * z
            if adj:
                score = score + self.D[:, None] * out
            return self.P @ out, self.P @ score
        return self.P @ out

    def score_norm(self, t, k=None, adj=False):
        """Calculate the score norm using the Harmonic SDE."""
        if k == 0:
            return 0
        l_index = self.l_index
        np_ = torch if self.use_cuda else np
        k = k or self.N - 1
        D = 1 / self.eigens(t)
        if adj:
            D = D * np_.exp(-self.D * t)
        return (D[l_index : k + l_index].sum() / self.N) ** 0.5

    def inject(self, t, modes):
        """Inject noise along the given modes using the Harmonic SDE."""
        z = (
            np.random.randn(self.N, 3)
            if not self.use_cuda
            else torch.randn(self.N, 3, device="cuda").float()
        )
        z[~modes] = 0  # restrict the noise to the selected modes
        A = self.A(t, invT=False)
        return A @ z

    def score(self, x0, xt, t):
        """Calculate the score of the diffusion kernel using the Harmonic SDE."""
        Sigma_inv = self.Sigma_inv(t)
        # mean of the transition kernel started at x0
        mu_t = (self.P * np.exp(-t * self.D / 2)) @ (self.P.T @ x0)
        return Sigma_inv @ (mu_t - xt)

    def project(self, X, k, center=False):
        """Project onto the first `k` nonzero modes using the Harmonic SDE."""
        l_index = self.l_index
        D = self.P.T @ X
        D[k + l_index :] = 0
        if center:
            D[0] = 0
        return self.P @ D

    def unproject(self, X, mask, k, return_Pinv=False):
        """Find the vector along the first k nonzero modes whose mask is closest to `X`"""
        l_index = self.l_index
        PP = self.P[mask, : k + l_index]
        Pinv = np.linalg.pinv(PP)
        out = self.P[:, : k + l_index] @ Pinv @ X
        if return_Pinv:
            return out, Pinv
        return out

    def energy(self, X):
        """Calculate the energy using the Harmonic SDE."""
        l_index = self.l_index
        # per-mode quadratic energy, summed over the 3 spatial dimensions
        return (self.D[:, None] * (self.P.T @ X) ** 2).sum(-1)[l_index:] / 2

    @property
    def free_energy(self):
        """Calculate the free energy using the Harmonic SDE."""
        l_index = self.l_index
        return 3 * np.log(self.D[l_index:]).sum() / 2

    def KL_H(self, t):
        """Calculate the Kullback-Leibler divergence using the Harmonic SDE."""
        l_index = self.l_index
        D = self.D[l_index:]
        return -3 * 0.5 * (np.log(1 - np.exp(-D * t)) + np.exp(-D * t)).sum(0)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_gaussian_prior(
    x0: torch.Tensor,
    latent_converter: LatentCoordinateConverter,
    sigma: float,
    x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample noise from a Gaussian prior distribution.

    :param x0: ground-truth tensor
    :param latent_converter: The latent coordinate converter
    :param sigma: standard deviation of the Gaussian noise
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
    """
    noise = torch.randn_like(x0)
    # lightly perturb the ground truth with the same Gaussian draw
    noised_x0 = x0 + noise * x0_sigma
    # split the scaled noise into its Ca / other-atom / ligand segments (this
    # also validates that the segment sizes cover dim 1 exactly), then stitch
    # the segments back together in the original order
    segment_sizes = [
        latent_converter._n_res_per_sample,
        latent_converter._n_cother_per_sample,
        latent_converter._n_ligha_per_sample,
    ]
    ca_seg, cother_seg, lig_seg = torch.split(noise * sigma, segment_sizes, dim=1)
    x_int_1 = torch.cat([ca_seg, cother_seg, lig_seg], dim=1)
    return noised_x0, x_int_1
|
||||||
|
|
||||||
|
|
||||||
|
def sample_protein_harmonic_prior(
    protein_ca_x0: torch.Tensor,
    protein_cother_x0: torch.Tensor,
    batch: MODEL_BATCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample protein noise from a harmonic prior distribution.
    Adapted from: https://github.com/bjing2016/alphaflow

    Note that this function represents non-Ca atoms as Gaussian noise
    centered around each harmonically-noised Ca atom.

    :param protein_ca_x0: ground-truth protein Ca-atom tensor
    :param protein_cother_x0: ground-truth protein other-atom tensor
    :param batch: A batch dictionary
    :return: tuple of harmonic protein Ca atom noise and Gaussian protein other atom noise
    """
    indexer = batch["indexer"]
    metadata = batch["metadata"]
    struct_ids = indexer["gather_idx_a_structid"]
    num_ca_nodes = protein_ca_x0.size(0) * protein_ca_x0.size(1)

    # per-structure segment boundaries: [0, n_1, n_1 + n_2, ...]
    boundaries = torch.cumsum(torch.bincount(struct_ids), dim=0)
    boundaries = torch.cat((torch.tensor([0], device=struct_ids.device), boundaries))

    try:
        D_inv, P = HarmonicSDE.diagonalize(
            num_ca_nodes,
            boundaries,
            a=3 / (3.8**2),
        )
    except Exception as e:
        log.error(
            f"Failed to call HarmonicSDE.diagonalize() for protein(s) {metadata['sample_ID_per_sample']} due to: {e}"
        )
        raise e

    # harmonic noise for the Ca trace, drawn in the eigenbasis
    z = torch.randn((num_ca_nodes, 3), device=protein_ca_x0.device)
    ca_noise = P @ (torch.sqrt(D_inv)[:, None] * z)
    # Gaussian noise for the remaining atoms, centered on each atom's Ca noise
    cother_noise = (
        torch.randn_like(protein_cother_x0.flatten(0, 1))
        + ca_noise[indexer["gather_idx_a_cotherid"]]
    )
    return (
        ca_noise.view(protein_ca_x0.size()).contiguous(),
        cother_noise.view(protein_cother_x0.size()).contiguous(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def sample_ligand_harmonic_prior(
    lig_x0: torch.Tensor, protein_ca_x0: torch.Tensor, batch: MODEL_BATCH, sigma: float = 1.0
) -> torch.Tensor:
    """
    Sample ligand noise from a harmonic prior distribution.
    Adapted from: https://github.com/HannesStark/FlowSite

    :param lig_x0: ground-truth ligand tensor
    :param protein_ca_x0: ground-truth protein Ca-atom tensor
    :param batch: A batch dictionary
    :param sigma: standard deviation of the harmonic noise
    :return: tensor of harmonic noise
    """
    indexer = batch["indexer"]
    metadata = batch["metadata"]
    lig_num_nodes = lig_x0.size(0) * lig_x0.size(1)
    num_molid_per_sample = max(metadata["num_molid_per_sample"])
    # NOTE: here, we distinguish each ligand chain in a complex for harmonic chain sampling
    lig_bid = indexer["gather_idx_i_molid"]
    # per-structure RMS Ca coordinate, repeated for each ligand chain in that
    # structure — sets the overall scale of the ligand diffusion below
    protein_sigma = (
        segment_mean(
            torch.square(protein_ca_x0).flatten(0, 1),
            indexer["gather_idx_a_structid"],
            metadata["num_structid"],
        ).mean(dim=-1)
        ** 0.5
    ).repeat_interleave(num_molid_per_sample)
    sde = DiffusionSDE(protein_sigma * sigma)
    # undirected bond list from the paired i/j gather indices
    edges = torch.stack(
        (
            indexer["gather_idx_ij_i"],
            indexer["gather_idx_ij_j"],
        )
    )
    edges = edges[:, edges[0] < edges[1]]  # de-duplicate edges
    # per-chain segment boundaries: [0, n_1, n_1 + n_2, ...]
    ptr = torch.cumsum(torch.bincount(lig_bid), dim=0)
    ptr = torch.cat((torch.tensor([0], device=lig_bid.device), ptr))
    try:
        # eigendecomposition of the bond-graph coupling matrix, with the
        # per-chain SDE rate on the diagonal
        D, P = HarmonicSDE.diagonalize(
            lig_num_nodes,
            ptr,
            edges=edges.T,
            lamb=sde.lamb[lig_bid],
        )
    except Exception as e:
        log.error(
            f"Failed to call HarmonicSDE.diagonalize() for ligand(s) {metadata['sample_ID_per_sample']} due to: {e}"
        )
        raise e
    # draw the harmonic sample in the eigenbasis and map back
    noise = torch.randn((lig_num_nodes, 3), device=lig_x0.device)
    prior = P @ (noise / torch.sqrt(D)[:, None])
    return prior.view(lig_x0.size()).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def sample_complex_harmonic_prior(
    x0: torch.Tensor,
    latent_converter: LatentCoordinateConverter,
    batch: MODEL_BATCH,
    x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample protein-ligand complex noise from a harmonic prior distribution.
    From: https://github.com/bjing2016/alphaflow

    :param x0: ground-truth tensor
    :param latent_converter: The latent coordinate converter
    :param batch: A batch dictionary
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian and harmonic prior noise, respectively
    """
    segment_sizes = [
        latent_converter._n_res_per_sample,
        latent_converter._n_cother_per_sample,
        latent_converter._n_ligha_per_sample,
    ]
    ca_lat, cother_lat, lig_lat = x0.split(segment_sizes, dim=1)

    # harmonic Ca noise plus Ca-centered Gaussian noise for the other protein atoms
    harmonic_ca_lat, gaussian_cother_lat = sample_protein_harmonic_prior(
        ca_lat,
        cother_lat,
        batch,
    )
    # ligand harmonic noise anchored on the noised Ca trace
    harmonic_lig_lat = sample_ligand_harmonic_prior(lig_lat, harmonic_ca_lat, batch)

    # NOTE: the following normalization steps assume that `self.latent_model == "default"`
    x1 = torch.cat(
        [
            harmonic_ca_lat / latent_converter.ca_scale,
            gaussian_cother_lat / latent_converter.other_scale,
            harmonic_lig_lat / latent_converter.other_scale,
        ],
        dim=1,
    )
    # lightly perturb the ground truth with independent Gaussian noise
    return x0 + torch.randn_like(x0) * x0_sigma, x1
|
||||||
|
|
||||||
|
|
||||||
|
def sample_esmfold_prior(
    x0: torch.Tensor, x1: torch.Tensor, sigma: float, x0_sigma: float = 1e-4
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample noise from an ESMFold prior distribution.

    A single Gaussian draw perturbs both tensors, so the outputs carry the
    same noise realization at two different scales.

    :param x0: ground-truth tensor
    :param x1: predicted tensor
    :param sigma: standard deviation of the ESMFold prior's additive Gaussian noise
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
    """
    shared_noise = torch.randn_like(x0)
    return x0 + shared_noise * x0_sigma, x1 + shared_noise * sigma
|
||||||
241
flowdock/models/components/transforms.py
Normal file
241
flowdock/models/components/transforms.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils.model_utils import segment_mean
|
||||||
|
|
||||||
|
|
||||||
|
class LatentCoordinateConverter:
    """Transform the batched feature dict to latent coordinate arrays."""

    def __init__(self, config, prot_atom37_namemap, lig_namemap):
        """Store configuration plus the protein/ligand coordinate name maps."""
        super().__init__()
        self.config = config
        self.prot_namemap = prot_atom37_namemap
        self.lig_namemap = lig_namemap
        # populated lazily by subclasses during conversion
        self.cached_noise = None
        self._last_pred_ca_trace = None

    @staticmethod
    def nested_get(dic, keys):
        """Walk `dic` through the sequence of `keys` and return the leaf value."""
        node = dic
        for key in keys:
            node = node[key]
        return node

    @staticmethod
    def nested_set(dic, keys, value):
        """Set `value` at the nested path `keys`, creating dicts along the way."""
        node = dic
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value

    def to_latent(self, batch):
        """Convert the batched feature dict to latent coordinates (no-op base)."""
        return None

    def assign_to_batch(self, batch, x_int):
        """Assign the latent coordinates to the batched feature dict (no-op base)."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultPLCoordinateConverter(LatentCoordinateConverter):
|
||||||
|
"""Minimal conversion, using internal coords for sidechains and global coords for others."""
|
||||||
|
|
||||||
|
def __init__(self, config, prot_atom37_namemap, lig_namemap):
|
||||||
|
"""Initialize the converter."""
|
||||||
|
super().__init__(config, prot_atom37_namemap, lig_namemap)
|
||||||
|
# Scale parameters in Angstrom
|
||||||
|
self.ca_scale = config.global_max_sigma
|
||||||
|
self.other_scale = config.internal_max_sigma
|
||||||
|
|
||||||
|
    def to_latent(self, batch: dict):
        """Convert the batched feature dict to latent coordinates.

        Returns a single ``[batch, n_latent, 3]`` tensor laid out as:
        Ca | apo Ca | non-Ca atoms | apo non-Ca atoms
        (and, when ligands are present, additionally:
        Ca centroid | apo Ca centroid | ligand heavy atoms).
        Also caches the per-sample segment sizes and the non-Ca atom mask
        consumed later by `assign_to_batch`.
        """
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        self._batch_size = metadata["num_structid"]
        atom37_mask = batch["features"]["res_atom_mask"].bool()
        # mask of all resolved atoms except Ca (atom37 index 1)
        self._cother_mask = atom37_mask.clone()
        self._cother_mask[:, 1] = False
        atom37_coords = self.nested_get(batch, self.prot_namemap[0])
        try:
            # the apo structure lives under the same path with an "apo_" prefix;
            # missing apo coordinates fall back to zeros
            apo_available = True
            apo_atom37_coords = self.nested_get(
                batch, self.prot_namemap[0][:-1] + ("apo_" + self.prot_namemap[0][-1],)
            )
        except KeyError:
            apo_available = False
            apo_atom37_coords = torch.zeros_like(atom37_coords)
        ca_atom_centroid_coords = segment_mean(
            # NOTE: in contrast to NeuralPLexer, we center all coordinates at the origin using the Ca atom centroids
            atom37_coords[:, 1],
            indexer["gather_idx_a_structid"],
            self._batch_size,
        )
        if apo_available:
            apo_ca_atom_centroid_coords = segment_mean(
                apo_atom37_coords[:, 1],
                indexer["gather_idx_a_structid"],
                self._batch_size,
            )
        else:
            apo_ca_atom_centroid_coords = torch.zeros_like(ca_atom_centroid_coords)
        # Ca coordinates, centered per structure: [batch, n_res, 3]
        ca_coords_glob = (
            (atom37_coords[:, 1] - ca_atom_centroid_coords[indexer["gather_idx_a_structid"]])
            .contiguous()
            .view(self._batch_size, -1, 3)
        )
        if apo_available:
            apo_ca_coords_glob = (
                (
                    apo_atom37_coords[:, 1]
                    - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"]]
                )
                .contiguous()
                .view(self._batch_size, -1, 3)
            )
        else:
            apo_ca_coords_glob = torch.zeros_like(ca_coords_glob)
        # non-Ca atoms, likewise centered and flattened: [batch, n_cother, 3]
        cother_coords_int = (
            (atom37_coords - ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None])[
                self._cother_mask
            ]
            .contiguous()
            .view(self._batch_size, -1, 3)
        )
        if apo_available:
            apo_cother_coords_int = (
                (
                    apo_atom37_coords
                    - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None]
                )[self._cother_mask]
                .contiguous()
                .view(self._batch_size, -1, 3)
            )
        else:
            apo_cother_coords_int = torch.zeros_like(cother_coords_int)
        # cache per-sample segment sizes for `assign_to_batch`
        self._n_res_per_sample = ca_coords_glob.shape[1]
        self._n_cother_per_sample = cother_coords_int.shape[1]
        if batch["misc"]["protein_only"]:
            # protein-only layout: Ca | apo Ca | non-Ca | apo non-Ca
            self._n_ligha_per_sample = 0
            x_int = torch.cat(
                [
                    ca_coords_glob / self.ca_scale,
                    apo_ca_coords_glob / self.ca_scale,
                    cother_coords_int / self.other_scale,
                    apo_cother_coords_int / self.other_scale,
                ],
                dim=1,
            )
            return x_int
        # ligand heavy atoms, centered on their structure's Ca centroid
        lig_ha_coords = self.nested_get(batch, self.lig_namemap[0])
        lig_ha_coords_int = (
            lig_ha_coords - ca_atom_centroid_coords[indexer["gather_idx_i_structid"]]
        )
        lig_ha_coords_int = lig_ha_coords_int.contiguous().view(self._batch_size, -1, 3)
        ca_atom_centroid_coords = ca_atom_centroid_coords.contiguous().view(
            self._batch_size, -1, 3
        )
        apo_ca_atom_centroid_coords = apo_ca_atom_centroid_coords.contiguous().view(
            self._batch_size, -1, 3
        )
        x_int = torch.cat(
            [
                ca_coords_glob / self.ca_scale,
                apo_ca_coords_glob / self.ca_scale,
                cother_coords_int / self.other_scale,
                apo_cother_coords_int / self.other_scale,
                ca_atom_centroid_coords / self.ca_scale,
                apo_ca_atom_centroid_coords / self.ca_scale,
                lig_ha_coords_int / self.other_scale,
            ],
            dim=1,
        )
        # NOTE: since we use the Ca atom centroids for centralization, we have only one molid per sample
        self._n_molid_per_sample = ca_atom_centroid_coords.shape[1]
        self._n_ligha_per_sample = lig_ha_coords_int.shape[1]
        return x_int
|
||||||
|
|
||||||
|
def assign_to_batch(self, batch: dict, x_lat: torch.Tensor):
|
||||||
|
"""Assign the latent coordinates to the batched feature dict."""
|
||||||
|
indexer = batch["indexer"]
|
||||||
|
new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
|
||||||
|
apo_new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
|
||||||
|
if batch["misc"]["protein_only"]:
|
||||||
|
ca_lat, apo_ca_lat, cother_lat, apo_cother_lat = torch.split(
|
||||||
|
x_lat,
|
||||||
|
[
|
||||||
|
self._n_res_per_sample,
|
||||||
|
self._n_res_per_sample,
|
||||||
|
self._n_cother_per_sample,
|
||||||
|
self._n_cother_per_sample,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
ca_lat,
|
||||||
|
apo_ca_lat,
|
||||||
|
cother_lat,
|
||||||
|
apo_cother_lat,
|
||||||
|
ca_cent_lat,
|
||||||
|
_,
|
||||||
|
lig_lat,
|
||||||
|
) = torch.split(
|
||||||
|
x_lat,
|
||||||
|
[
|
||||||
|
self._n_res_per_sample,
|
||||||
|
self._n_res_per_sample,
|
||||||
|
self._n_cother_per_sample,
|
||||||
|
self._n_cother_per_sample,
|
||||||
|
self._n_molid_per_sample,
|
||||||
|
self._n_molid_per_sample,
|
||||||
|
self._n_ligha_per_sample,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
new_ca_glob = (ca_lat * self.ca_scale).contiguous().flatten(0, 1)
|
||||||
|
apo_new_ca_glob = (apo_ca_lat * self.ca_scale).contiguous().flatten(0, 1)
|
||||||
|
new_atom37_coords[self._cother_mask] = (
|
||||||
|
(cother_lat * self.other_scale).contiguous().flatten(0, 1)
|
||||||
|
)
|
||||||
|
apo_new_atom37_coords[self._cother_mask] = (
|
||||||
|
(apo_cother_lat * self.other_scale).contiguous().flatten(0, 1)
|
||||||
|
)
|
||||||
|
new_atom37_coords = new_atom37_coords
|
||||||
|
apo_new_atom37_coords = apo_new_atom37_coords
|
||||||
|
new_atom37_coords[~self._cother_mask] = 0
|
||||||
|
apo_new_atom37_coords[~self._cother_mask] = 0
|
||||||
|
new_atom37_coords[:, 1] = new_ca_glob
|
||||||
|
apo_new_atom37_coords[:, 1] = apo_new_ca_glob
|
||||||
|
self.nested_set(batch, self.prot_namemap[1], new_atom37_coords)
|
||||||
|
self.nested_set(
|
||||||
|
batch,
|
||||||
|
self.prot_namemap[1][:-1] + ("apo_" + self.prot_namemap[1][-1],),
|
||||||
|
apo_new_atom37_coords,
|
||||||
|
)
|
||||||
|
if batch["misc"]["protein_only"]:
|
||||||
|
self.nested_set(batch, self.lig_namemap[1], None)
|
||||||
|
self.empty_cache()
|
||||||
|
return batch
|
||||||
|
new_ligha_coords_int = (lig_lat * self.other_scale).contiguous().flatten(0, 1)
|
||||||
|
new_ligha_coords_cent = (ca_cent_lat * self.ca_scale).contiguous().flatten(0, 1)
|
||||||
|
new_ligha_coords = (
|
||||||
|
new_ligha_coords_int + new_ligha_coords_cent[indexer["gather_idx_i_structid"]]
|
||||||
|
)
|
||||||
|
self.nested_set(batch, self.lig_namemap[1], new_ligha_coords)
|
||||||
|
self.empty_cache()
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def empty_cache(self):
|
||||||
|
"""Empty the cached variables."""
|
||||||
|
self._batch_size = None
|
||||||
|
self._cother_mask = None
|
||||||
|
self._n_res_per_sample = None
|
||||||
|
self._n_cother_per_sample = None
|
||||||
|
self._n_ligha_per_sample = None
|
||||||
|
self._n_molid_per_sample = None
|
||||||
943
flowdock/models/flowdock_fm_module.py
Normal file
943
flowdock/models/flowdock_fm_module.py
Normal file
@@ -0,0 +1,943 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import esm
|
||||||
|
import numpy as np
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, Literal, Optional, Union
|
||||||
|
from lightning import LightningModule
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
from torchmetrics.functional.regression import (
|
||||||
|
mean_absolute_error,
|
||||||
|
mean_squared_error,
|
||||||
|
pearson_corrcoef,
|
||||||
|
spearman_corrcoef,
|
||||||
|
)
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.models.components.losses import (
|
||||||
|
eval_auxiliary_estimation_losses,
|
||||||
|
eval_structure_prediction_losses,
|
||||||
|
)
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.data_utils import pdb_filepath_to_protein, prepare_batch
|
||||||
|
from flowdock.utils.model_utils import extract_esm_embeddings
|
||||||
|
from flowdock.utils.sampling_utils import multi_pose_sampling
|
||||||
|
from flowdock.utils.visualization_utils import (
|
||||||
|
construct_prot_lig_pairs,
|
||||||
|
write_prot_lig_pairs_to_pdb_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Type aliases shared by the Lightning module below.
MODEL_BATCH = Dict[str, Any]
MODEL_STAGE = Literal["train", "val", "test", "predict"]

# Supported loss-evaluation modes: list form for runtime membership checks,
# `Literal` form for static type annotations. Keep the two in sync.
LOSS_MODES_LIST = [
    "structure_prediction",
    "auxiliary_estimation",
    "auxiliary_estimation_without_structure_prediction",
]
LOSS_MODES = Literal[
    "structure_prediction",
    "auxiliary_estimation",
    "auxiliary_estimation_without_structure_prediction",
]

# Stages during which auxiliary (confidence/affinity) losses may be evaluated.
AUX_ESTIMATION_STAGES = ["train", "val", "test"]

# Rank-zero-only logger for multi-process (e.g., DDP) runs.
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
class FlowDockFMLitModule(LightningModule):
|
||||||
|
"""A `LightningModule` for geometric flow matching (FM) with FlowDock.
|
||||||
|
|
||||||
|
A `LightningModule` implements 8 key methods:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __init__(self):
|
||||||
|
# Define initialization code here.
|
||||||
|
|
||||||
|
def setup(self, stage):
|
||||||
|
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
|
||||||
|
# This hook is called on every process when using DDP.
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# The complete training step.
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
# The complete validation step.
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
# The complete test step.
|
||||||
|
|
||||||
|
def predict_step(self, batch, batch_idx):
|
||||||
|
# The complete predict step.
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
# Define and configure optimizers and LR schedulers.
|
||||||
|
```
|
||||||
|
|
||||||
|
Docs:
|
||||||
|
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
net: torch.nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
scheduler: torch.optim.lr_scheduler,
|
||||||
|
compile: bool,
|
||||||
|
cfg: DictConfig,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
|
):
|
||||||
|
"""Initialize a `FlowDockFMLitModule`.
|
||||||
|
|
||||||
|
:param net: The model to train.
|
||||||
|
:param optimizer: The optimizer to use for training.
|
||||||
|
:param scheduler: The learning rate scheduler to use for training.
|
||||||
|
:param compile: Whether to compile the model before training.
|
||||||
|
:param cfg: The model configuration.
|
||||||
|
:param kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# the model along with its hyperparameters
|
||||||
|
self.net = net(cfg)
|
||||||
|
|
||||||
|
# this line allows to access init params with 'self.hparams' attribute
|
||||||
|
# also ensures init params will be stored in ckpt
|
||||||
|
self.save_hyperparameters(logger=False, ignore=["net"])
|
||||||
|
|
||||||
|
# for validating input arguments
|
||||||
|
if self.hparams.cfg.task.loss_mode not in LOSS_MODES_LIST:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid loss mode: {self.hparams.cfg.task.loss_mode}. Must be one of {LOSS_MODES}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# for inspecting the model's outputs during validation and testing
|
||||||
|
(
|
||||||
|
self.training_step_outputs,
|
||||||
|
self.validation_step_outputs,
|
||||||
|
self.test_step_outputs,
|
||||||
|
self.predict_step_outputs,
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: MODEL_BATCH,
|
||||||
|
iter_id: Union[int, str] = 0,
|
||||||
|
observed_block_contacts: Optional[torch.Tensor] = None,
|
||||||
|
contact_prediction: bool = True,
|
||||||
|
infer_geometry_prior: bool = False,
|
||||||
|
score: bool = False,
|
||||||
|
affinity: bool = True,
|
||||||
|
use_template: bool = False,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
|
) -> MODEL_BATCH:
|
||||||
|
"""Perform a forward pass through the model.
|
||||||
|
|
||||||
|
:param batch: A batch dictionary.
|
||||||
|
:param iter_id: The current iteration ID.
|
||||||
|
:param observed_block_contacts: Observed block contacts.
|
||||||
|
:param contact_prediction: Whether to predict contacts.
|
||||||
|
:param infer_geometry_prior: Whether to predict using a geometry prior.
|
||||||
|
:param score: Whether to predict a denoised complex structure.
|
||||||
|
:param affinity: Whether to predict ligand binding affinity.
|
||||||
|
:param use_template: Whether to use a template protein structure.
|
||||||
|
:param kwargs: Additional keyword arguments.
|
||||||
|
:return: Batch dictionary with outputs.
|
||||||
|
"""
|
||||||
|
return self.net(
|
||||||
|
batch,
|
||||||
|
iter_id=iter_id,
|
||||||
|
observed_block_contacts=observed_block_contacts,
|
||||||
|
contact_prediction=contact_prediction,
|
||||||
|
infer_geometry_prior=infer_geometry_prior,
|
||||||
|
score=score,
|
||||||
|
affinity=affinity,
|
||||||
|
use_template=use_template,
|
||||||
|
training=self.training,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_step(
|
||||||
|
self,
|
||||||
|
batch: MODEL_BATCH,
|
||||||
|
batch_idx: int,
|
||||||
|
stage: MODEL_STAGE,
|
||||||
|
loss_mode: Optional[LOSS_MODES] = None,
|
||||||
|
) -> MODEL_BATCH:
|
||||||
|
"""Perform a single model step on a batch of data.
|
||||||
|
|
||||||
|
:param batch: A batch dictionary.
|
||||||
|
:param batch_idx: The index of the current batch.
|
||||||
|
:param stage: The current model stage (i.e., `train`, `val`, `test`, or `predict`).
|
||||||
|
:param loss_mode: The loss mode to use for training.
|
||||||
|
:return: Batch dictionary with losses.
|
||||||
|
"""
|
||||||
|
prepare_batch(batch)
|
||||||
|
predicting_aux_outputs = (
|
||||||
|
self.hparams.cfg.confidence.enabled or self.hparams.cfg.affinity.enabled
|
||||||
|
)
|
||||||
|
is_aux_loss_stage = stage in AUX_ESTIMATION_STAGES
|
||||||
|
is_aux_batch = batch_idx % self.hparams.cfg.task.aux_batch_freq == 0
|
||||||
|
struct_pred_loss_mode_requested = (
|
||||||
|
loss_mode is not None and loss_mode == "structure_prediction"
|
||||||
|
)
|
||||||
|
should_eval_aux_loss = (
|
||||||
|
predicting_aux_outputs
|
||||||
|
and is_aux_loss_stage
|
||||||
|
and is_aux_batch
|
||||||
|
and not struct_pred_loss_mode_requested
|
||||||
|
and (
|
||||||
|
not self.hparams.cfg.task.freeze_confidence
|
||||||
|
or (
|
||||||
|
not self.hparams.cfg.task.freeze_affinity
|
||||||
|
and batch["features"]["affinity"].any().item()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
eval_aux_loss_mode_requested = (
|
||||||
|
predicting_aux_outputs
|
||||||
|
and loss_mode is not None
|
||||||
|
and "auxiliary_estimation" in loss_mode
|
||||||
|
)
|
||||||
|
if should_eval_aux_loss or eval_aux_loss_mode_requested:
|
||||||
|
return eval_auxiliary_estimation_losses(
|
||||||
|
self, batch, stage, loss_mode, training=self.training
|
||||||
|
)
|
||||||
|
loss_fn = eval_structure_prediction_losses
|
||||||
|
return loss_fn(self, batch, batch_idx, self.device, stage, t_1=1.0)
|
||||||
|
|
||||||
|
def on_train_start(self):
|
||||||
|
"""Lightning hook that is called when training begins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def training_step(self, batch: MODEL_BATCH, batch_idx: int) -> torch.Tensor:
|
||||||
|
"""Perform a single training step on a batch of data from the training set.
|
||||||
|
|
||||||
|
:param batch: A batch dictionary.
|
||||||
|
:param batch_idx: The index of the current batch.
|
||||||
|
:return: A tensor of losses between model predictions and targets.
|
||||||
|
"""
|
||||||
|
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
|
||||||
|
name == self.hparams.cfg.task.overfitting_example_name
|
||||||
|
for name in batch["metadata"]["sample_ID_per_sample"]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
batch = self.model_step(batch, batch_idx, "train")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to perform training step for batch index {batch_idx} due to: {e}. Skipping example."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
|
||||||
|
training_outputs = {
|
||||||
|
"affinity_logits": batch["outputs"]["affinity_logits"],
|
||||||
|
"affinity": batch["features"]["affinity"],
|
||||||
|
}
|
||||||
|
self.training_step_outputs.append(training_outputs)
|
||||||
|
|
||||||
|
# return loss or backpropagation will fail
|
||||||
|
return batch["outputs"]["loss"]
|
||||||
|
|
||||||
|
def on_train_epoch_end(self):
|
||||||
|
"""Lightning hook that is called when a training epoch ends."""
|
||||||
|
if self.hparams.cfg.affinity.enabled and any(
|
||||||
|
"affinity_logits" in output for output in self.training_step_outputs
|
||||||
|
):
|
||||||
|
affinity_logits = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity_logits"]
|
||||||
|
for output in self.training_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity"]
|
||||||
|
for output in self.training_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity_logits = affinity_logits[~affinity.isnan()]
|
||||||
|
affinity = affinity[~affinity.isnan()]
|
||||||
|
if affinity.numel() > 1:
|
||||||
|
# NOTE: there must be at least two affinity batches to properly score the affinity predictions
|
||||||
|
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
|
||||||
|
aff_mae = mean_absolute_error(affinity_logits, affinity)
|
||||||
|
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
|
||||||
|
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
|
||||||
|
self.log(
|
||||||
|
"train_affinity/RMSE",
|
||||||
|
aff_rmse.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=False,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"train_affinity/MAE",
|
||||||
|
aff_mae.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=False,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"train_affinity/Pearson",
|
||||||
|
aff_pearson.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=False,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"train_affinity/Spearman",
|
||||||
|
aff_spearman.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=False,
|
||||||
|
)
|
||||||
|
self.training_step_outputs.clear() # free memory
|
||||||
|
|
||||||
|
def on_validation_start(self):
|
||||||
|
"""Lightning hook that is called when validation begins."""
|
||||||
|
# create a directory to store model outputs from each validation epoch
|
||||||
|
os.makedirs(
|
||||||
|
os.path.join(self.trainer.default_root_dir, "validation_epoch_outputs"), exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def validation_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
|
||||||
|
"""Perform a single validation step on a batch of data from the validation set.
|
||||||
|
|
||||||
|
:param batch: A batch dictionary.
|
||||||
|
:param batch_idx: The index of the current batch.
|
||||||
|
:param dataloader_idx: The index of the current dataloader.
|
||||||
|
"""
|
||||||
|
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
|
||||||
|
name == self.hparams.cfg.task.overfitting_example_name
|
||||||
|
for name in batch["metadata"]["sample_ID_per_sample"]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
prepare_batch(batch)
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler="VDODE",
|
||||||
|
sampler_eta=1.0,
|
||||||
|
num_steps=10,
|
||||||
|
start_time=1.0,
|
||||||
|
exact_prior=False,
|
||||||
|
return_all_states=True,
|
||||||
|
eval_input_protein=True,
|
||||||
|
)
|
||||||
|
all_frames = sampling_stats["all_frames"]
|
||||||
|
del sampling_stats["all_frames"]
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"val_sampling/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler="VDODE",
|
||||||
|
sampler_eta=1.0,
|
||||||
|
num_steps=10,
|
||||||
|
start_time=1.0,
|
||||||
|
exact_prior=False,
|
||||||
|
use_template=False,
|
||||||
|
)
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"val_sampling_notemplate/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler="VDODE",
|
||||||
|
sampler_eta=1.0,
|
||||||
|
num_steps=10,
|
||||||
|
start_time=1.0,
|
||||||
|
return_summary_stats=True,
|
||||||
|
exact_prior=True,
|
||||||
|
)
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"val_sampling_trueprior/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
batch = self.model_step(batch, batch_idx, "val")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to perform validation step for batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}. Skipping example."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# store model outputs for inspection
|
||||||
|
validation_outputs = {}
|
||||||
|
if self.hparams.cfg.task.visualize_generated_samples:
|
||||||
|
validation_outputs = {
|
||||||
|
"name": batch["metadata"]["sample_ID_per_sample"],
|
||||||
|
"batch_size": batch["metadata"]["num_structid"],
|
||||||
|
"aatype": batch["features"]["res_type"].long().cpu().numpy(),
|
||||||
|
"res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
|
||||||
|
"protein_coordinates_list": [
|
||||||
|
frame["receptor_padded"].cpu().numpy() for frame in all_frames
|
||||||
|
],
|
||||||
|
"ligand_coordinates_list": [
|
||||||
|
frame["ligands"].cpu().numpy() for frame in all_frames
|
||||||
|
],
|
||||||
|
"ligand_mol": batch["metadata"]["mol_per_sample"],
|
||||||
|
"protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"].cpu().numpy(),
|
||||||
|
"ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"].cpu().numpy(),
|
||||||
|
"gt_protein_coordinates": batch["features"]["res_atom_positions"].cpu().numpy(),
|
||||||
|
"gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
|
||||||
|
"dataloader_idx": dataloader_idx,
|
||||||
|
}
|
||||||
|
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
|
||||||
|
validation_outputs.update(
|
||||||
|
{
|
||||||
|
"affinity_logits": batch["outputs"]["affinity_logits"],
|
||||||
|
"affinity": batch["features"]["affinity"],
|
||||||
|
"dataloader_idx": dataloader_idx,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if validation_outputs:
|
||||||
|
self.validation_step_outputs.append(validation_outputs)
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self):
|
||||||
|
"Lightning hook that is called when a validation epoch ends."
|
||||||
|
if self.hparams.cfg.task.visualize_generated_samples:
|
||||||
|
for i, outputs in enumerate(self.validation_step_outputs):
|
||||||
|
for batch_index in range(outputs["batch_size"]):
|
||||||
|
prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
|
||||||
|
write_prot_lig_pairs_to_pdb_file(
|
||||||
|
prot_lig_pairs,
|
||||||
|
os.path.join(
|
||||||
|
self.trainer.default_root_dir,
|
||||||
|
"validation_epoch_outputs",
|
||||||
|
f"{outputs['name'][batch_index]}_validation_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.hparams.cfg.affinity.enabled and any(
|
||||||
|
"affinity_logits" in output for output in self.validation_step_outputs
|
||||||
|
):
|
||||||
|
affinity_logits = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity_logits"]
|
||||||
|
for output in self.validation_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity"]
|
||||||
|
for output in self.validation_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity_logits = affinity_logits[~affinity.isnan()]
|
||||||
|
affinity = affinity[~affinity.isnan()]
|
||||||
|
if affinity.numel() > 1:
|
||||||
|
# NOTE: there must be at least two affinity batches to properly score the affinity predictions
|
||||||
|
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
|
||||||
|
aff_mae = mean_absolute_error(affinity_logits, affinity)
|
||||||
|
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
|
||||||
|
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
|
||||||
|
self.log(
|
||||||
|
"val_affinity/RMSE",
|
||||||
|
aff_rmse.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"val_affinity/MAE",
|
||||||
|
aff_mae.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"val_affinity/Pearson",
|
||||||
|
aff_pearson.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"val_affinity/Spearman",
|
||||||
|
aff_spearman.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.validation_step_outputs.clear() # free memory
|
||||||
|
|
||||||
|
def on_test_start(self):
|
||||||
|
"""Lightning hook that is called when testing begins."""
|
||||||
|
# create a directory to store model outputs from each test epoch
|
||||||
|
os.makedirs(
|
||||||
|
os.path.join(self.trainer.default_root_dir, "test_epoch_outputs"), exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
|
||||||
|
"""Perform a single test step on a batch of data from the test set.
|
||||||
|
|
||||||
|
:param batch: A batch dictionary.
|
||||||
|
:param batch_idx: The index of the current batch.
|
||||||
|
:param dataloader_idx: The index of the current dataloader.
|
||||||
|
"""
|
||||||
|
if self.hparams.cfg.task.overfitting_example_name is not None and not all(
|
||||||
|
name == self.hparams.cfg.task.overfitting_example_name
|
||||||
|
for name in batch["metadata"]["sample_ID_per_sample"]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
prepare_batch(batch)
|
||||||
|
if self.hparams.cfg.task.eval_structure_prediction:
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler=self.hparams.cfg.task.sampler,
|
||||||
|
sampler_eta=self.hparams.cfg.task.sampler_eta,
|
||||||
|
num_steps=self.hparams.cfg.task.num_steps,
|
||||||
|
start_time=self.hparams.cfg.task.start_time,
|
||||||
|
exact_prior=False,
|
||||||
|
return_all_states=True,
|
||||||
|
eval_input_protein=True,
|
||||||
|
)
|
||||||
|
all_frames = sampling_stats["all_frames"]
|
||||||
|
del sampling_stats["all_frames"]
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"test_sampling/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler=self.hparams.cfg.task.sampler,
|
||||||
|
sampler_eta=self.hparams.cfg.task.sampler_eta,
|
||||||
|
num_steps=self.hparams.cfg.task.num_steps,
|
||||||
|
start_time=self.hparams.cfg.task.start_time,
|
||||||
|
exact_prior=False,
|
||||||
|
use_template=False,
|
||||||
|
)
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"test_sampling_notemplate/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
sampling_stats = self.net.sample_pl_complex_structures(
|
||||||
|
batch,
|
||||||
|
sampler=self.hparams.cfg.task.sampler,
|
||||||
|
sampler_eta=self.hparams.cfg.task.sampler_eta,
|
||||||
|
num_steps=self.hparams.cfg.task.num_steps,
|
||||||
|
start_time=self.hparams.cfg.task.start_time,
|
||||||
|
return_summary_stats=True,
|
||||||
|
exact_prior=True,
|
||||||
|
)
|
||||||
|
for metric_name in sampling_stats.keys():
|
||||||
|
log_stat = sampling_stats[metric_name].mean().detach()
|
||||||
|
batch_size = sampling_stats[metric_name].shape[0]
|
||||||
|
self.log(
|
||||||
|
f"test_sampling_trueprior/{metric_name}",
|
||||||
|
log_stat,
|
||||||
|
on_step=True,
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
batch = self.model_step(
|
||||||
|
batch, batch_idx, "test", loss_mode=self.hparams.cfg.task.loss_mode
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to perform test step for {batch['metadata']['sample_ID_per_sample']} with batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# store model outputs for inspection
|
||||||
|
test_outputs = {}
|
||||||
|
if (
|
||||||
|
self.hparams.cfg.task.visualize_generated_samples
|
||||||
|
and self.hparams.cfg.task.eval_structure_prediction
|
||||||
|
):
|
||||||
|
test_outputs.update(
|
||||||
|
{
|
||||||
|
"name": batch["metadata"]["sample_ID_per_sample"],
|
||||||
|
"batch_size": batch["metadata"]["num_structid"],
|
||||||
|
"aatype": batch["features"]["res_type"].long().cpu().numpy(),
|
||||||
|
"res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
|
||||||
|
"protein_coordinates_list": [
|
||||||
|
frame["receptor_padded"].cpu().numpy() for frame in all_frames
|
||||||
|
],
|
||||||
|
"ligand_coordinates_list": [
|
||||||
|
frame["ligands"].cpu().numpy() for frame in all_frames
|
||||||
|
],
|
||||||
|
"ligand_mol": batch["metadata"]["mol_per_sample"],
|
||||||
|
"protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"]
|
||||||
|
.cpu()
|
||||||
|
.numpy(),
|
||||||
|
"ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"]
|
||||||
|
.cpu()
|
||||||
|
.numpy(),
|
||||||
|
"gt_protein_coordinates": batch["features"]["res_atom_positions"]
|
||||||
|
.cpu()
|
||||||
|
.numpy(),
|
||||||
|
"gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
|
||||||
|
"dataloader_idx": dataloader_idx,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
|
||||||
|
test_outputs.update(
|
||||||
|
{
|
||||||
|
"affinity_logits": batch["outputs"]["affinity_logits"],
|
||||||
|
"affinity": batch["features"]["affinity"],
|
||||||
|
"dataloader_idx": dataloader_idx,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if test_outputs:
|
||||||
|
self.test_step_outputs.append(test_outputs)
|
||||||
|
|
||||||
|
def on_test_epoch_end(self):
|
||||||
|
"""Lightning hook that is called when a test epoch ends."""
|
||||||
|
if (
|
||||||
|
self.hparams.cfg.task.visualize_generated_samples
|
||||||
|
and self.hparams.cfg.task.eval_structure_prediction
|
||||||
|
):
|
||||||
|
for i, outputs in enumerate(self.test_step_outputs):
|
||||||
|
for batch_index in range(outputs["batch_size"]):
|
||||||
|
prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
|
||||||
|
write_prot_lig_pairs_to_pdb_file(
|
||||||
|
prot_lig_pairs,
|
||||||
|
os.path.join(
|
||||||
|
self.trainer.default_root_dir,
|
||||||
|
"test_epoch_outputs",
|
||||||
|
f"{outputs['name'][batch_index]}_test_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.hparams.cfg.affinity.enabled and any(
|
||||||
|
"affinity_logits" in output for output in self.test_step_outputs
|
||||||
|
):
|
||||||
|
affinity_logits = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity_logits"]
|
||||||
|
for output in self.test_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity = torch.cat(
|
||||||
|
[
|
||||||
|
output["affinity"]
|
||||||
|
for output in self.test_step_outputs
|
||||||
|
if "affinity_logits" in output
|
||||||
|
]
|
||||||
|
)
|
||||||
|
affinity_logits = affinity_logits[~affinity.isnan()]
|
||||||
|
affinity = affinity[~affinity.isnan()]
|
||||||
|
if affinity.numel() > 1:
|
||||||
|
# NOTE: there must be at least two affinity batches to properly score the affinity predictions
|
||||||
|
aff_rmse = torch.sqrt(mean_squared_error(affinity_logits, affinity))
|
||||||
|
aff_mae = mean_absolute_error(affinity_logits, affinity)
|
||||||
|
aff_pearson = pearson_corrcoef(affinity_logits, affinity)
|
||||||
|
aff_spearman = spearman_corrcoef(affinity_logits, affinity)
|
||||||
|
self.log(
|
||||||
|
"test_affinity/RMSE",
|
||||||
|
aff_rmse.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"test_affinity/MAE",
|
||||||
|
aff_mae.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"test_affinity/Pearson",
|
||||||
|
aff_pearson.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"test_affinity/Spearman",
|
||||||
|
aff_spearman.detach(),
|
||||||
|
on_epoch=True,
|
||||||
|
batch_size=len(affinity),
|
||||||
|
sync_dist=True,
|
||||||
|
)
|
||||||
|
self.test_step_outputs.clear() # free memory
|
||||||
|
|
||||||
|
def on_predict_start(self):
    """Lightning hook that is called when prediction begins.

    Prepares everything `predict_step` needs on CPU:
    the output directory, a pretrained ESM model (for residue embeddings),
    and — unless a usable template structure was supplied for a single
    complex — a pretrained ESMFold model (for apo structure prediction).
    """
    # create a directory to store model outputs from each predict epoch
    os.makedirs(
        os.path.join(self.trainer.default_root_dir, "predict_epoch_outputs"), exist_ok=True
    )

    log.info("Loading pretrained ESM model...")
    esm_model, self.esm_alphabet = esm.pretrained.load_model_and_alphabet_hub(
        self.hparams.cfg.model.cfg.protein_encoder.esm_version
    )
    # full-precision eval mode; kept on CPU so GPU memory is reserved for sampling
    self.esm_model = esm_model.eval().float()
    self.esm_batch_converter = self.esm_alphabet.get_batch_converter()
    self.esm_model.cpu()

    skip_loading_esmfold_weights = (
        # skip loading ESMFold weights if the template protein structure for a single complex input is provided
        self.hparams.cfg.task.csv_path is None
        and self.hparams.cfg.task.input_template is not None
        and os.path.exists(self.hparams.cfg.task.input_template)
    )
    if not skip_loading_esmfold_weights:
        log.info("Loading pretrained ESMFold model...")
        esmfold_model = esm.pretrained.esmfold_v1()
        self.esmfold_model = esmfold_model.eval().float()
        # chunked attention to bound peak memory during folding
        self.esmfold_model.set_chunk_size(self.hparams.cfg.esmfold_chunk_size)
        self.esmfold_model.cpu()
|
||||||
|
|
||||||
|
def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
    """Perform a single predict step on a batch of data from the predict set.

    Extracts ESM embeddings for the receptor, resolves (or predicts via
    ESMFold) an apo receptor structure, runs multi-pose sampling, and — when
    trajectory visualization is enabled — caches per-sample outputs in
    `self.predict_step_outputs` for `on_predict_epoch_end` to write out.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :param dataloader_idx: The index of the current dataloader.
    """
    # NOTE(review): every field is unpacked via `[0]`, i.e. this assumes the
    # predict dataloader uses batch_size=1 — confirm against the caller.
    rec_path = batch["rec_path"][0]
    ligand_paths = list(
        path[0] for path in batch["lig_paths"]
    )  # unpack a list of (batched) single-element string tuples
    sample_id = batch["sample_id"][0] if "sample_id" in batch else "sample"
    input_template = batch["input_template"][0] if "input_template" in batch else None

    # per-sample output subdirectory when an explicit sample ID is available
    out_path = (
        os.path.join(self.hparams.cfg.out_path, sample_id)
        if "sample_id" in batch
        else self.hparams.cfg.out_path
    )

    # generate ESM embeddings for the protein
    protein = pdb_filepath_to_protein(rec_path)
    # one masked (resolved-residue) sequence string per chain
    sequences = [
        "".join(np.array(list(chain_seq))[chain_mask])
        for (_, chain_seq, chain_mask) in protein.letter_sequences
    ]
    esm_embeddings = extract_esm_embeddings(
        self.esm_model,
        self.esm_alphabet,
        self.esm_batch_converter,
        sequences,
        device="cpu",
        esm_repr_layer=self.hparams.cfg.model.cfg.protein_encoder.esm_repr_layer,
    )
    # keys carry the chain index so identical chain sequences do not collide
    sequences_to_embeddings = {
        f"{seq}:{i}": esm_embeddings[i].cpu().numpy() for i, seq in enumerate(sequences)
    }

    # generate initial ESMFold-predicted structure for the protein if a template is not provided
    apo_rec_path = None
    if input_template and os.path.exists(input_template):
        # validate that the template's masked sequence matches the input receptor's
        apo_protein = pdb_filepath_to_protein(input_template)
        apo_chain_seq_masked = "".join(
            "".join(np.array(list(chain_seq))[chain_mask])
            for (_, chain_seq, chain_mask) in apo_protein.letter_sequences
        )
        chain_seq_masked = "".join(
            "".join(np.array(list(chain_seq))[chain_mask])
            for (_, chain_seq, chain_mask) in protein.letter_sequences
        )
        if apo_chain_seq_masked != chain_seq_masked:
            log.error(
                f"Provided template protein structure {input_template} does not match the input protein sequence within {rec_path}. Skipping example {sample_id} at batch index {batch_idx} of dataloader {dataloader_idx}."
            )
            # mismatched template: skip this example entirely
            return None
        log.info(f"Starting from provided template protein structure: {input_template}")
        apo_rec_path = input_template
    if apo_rec_path is None and self.hparams.cfg.prior_type == "esmfold":
        # ESMFold expects chains joined by ':' in a single sequence string
        esmfold_sequence = ":".join(sequences)
        apo_rec_path = rec_path.replace(".pdb", "_apo.pdb")
        with torch.no_grad():
            esmfold_pdb_output = self.esmfold_model.infer_pdb(esmfold_sequence)
        with open(apo_rec_path, "w") as f:
            f.write(esmfold_pdb_output)

    _, _, _, _, _, all_frames, batch_all, b_factors, plddt_rankings = multi_pose_sampling(
        rec_path,
        ligand_paths,
        self.hparams.cfg,
        self,
        out_path,
        separate_pdb=self.hparams.cfg.separate_pdb,
        apo_receptor_path=apo_rec_path,
        sample_id=sample_id,
        protein=protein,
        sequences_to_embeddings=sequences_to_embeddings,
        return_all_states=self.hparams.cfg.task.visualize_generated_samples,
        auxiliary_estimation_only=self.hparams.cfg.task.auxiliary_estimation_only,
    )
    # store model outputs for inspection
    if self.hparams.cfg.task.visualize_generated_samples:
        predict_outputs = {
            "name": batch_all["metadata"]["sample_ID_per_sample"],
            "batch_size": batch_all["metadata"]["num_structid"],
            "aatype": batch_all["features"]["res_type"].long().cpu().numpy(),
            "res_atom_mask": batch_all["features"]["res_atom_mask"].cpu().numpy(),
            # one entry per sampled frame/state returned by multi_pose_sampling
            "protein_coordinates_list": [
                frame["receptor_padded"].cpu().numpy() for frame in all_frames
            ],
            "ligand_coordinates_list": [
                frame["ligands"].cpu().numpy() for frame in all_frames
            ],
            "ligand_mol": batch_all["metadata"]["mol_per_sample"],
            # structure-ID scatter indices mapping atoms back to samples
            "protein_batch_indexer": batch_all["indexer"]["gather_idx_a_structid"]
            .cpu()
            .numpy(),
            "ligand_batch_indexer": batch_all["indexer"]["gather_idx_i_structid"]
            .cpu()
            .numpy(),
            "b_factors": b_factors,
            "plddt_rankings": plddt_rankings,
        }
        self.predict_step_outputs.append(predict_outputs)
|
||||||
|
|
||||||
|
def on_predict_epoch_end(self):
    """Lightning hook that is called when a predict epoch ends.

    When sample-trajectory visualization is enabled, writes one PDB file per
    (predict-step output, sample) pair; always clears the cached outputs.
    """
    if self.hparams.cfg.task.visualize_generated_samples:
        out_root = self.hparams.cfg.out_path
        for output_idx, record in enumerate(self.predict_step_outputs):
            for sample_idx in range(record["batch_size"]):
                pairs = construct_prot_lig_pairs(record, sample_idx)
                rank = (
                    record["plddt_rankings"][sample_idx]
                    if "plddt_rankings" in record
                    else None
                )
                # rank tag is 1-based; omitted entirely when no ranking exists
                rank_tag = "" if rank is None else f"_rank{rank + 1}"
                sample_name = record["name"][sample_idx]
                pdb_name = (
                    f"{sample_name}{rank_tag}"
                    f"_predict_epoch_{self.current_epoch}"
                    f"_global_step_{self.global_step}"
                    f"_output_{output_idx}_batch_{sample_idx}.pdb"
                )
                write_prot_lig_pairs_to_pdb_file(
                    pairs,
                    os.path.join(out_root, sample_name, "predict_epoch_outputs", pdb_name),
                )
    # drop cached predictions to free memory
    self.predict_step_outputs.clear()
|
||||||
|
|
||||||
|
def on_after_backward(self):
    """Skip updates in case of unstable gradients.

    Scans all parameter gradients; if any contains `nan`/`inf`, zeroes the
    gradients so the optimizer step becomes a no-op.

    Reference: https://github.com/Lightning-AI/lightning/issues/4956
    """
    # short-circuits on the first offending gradient, like the original loop+break
    found_unstable_grad = any(
        torch.isnan(param.grad).any() or torch.isinf(param.grad).any()
        for _, param in self.named_parameters()
        if param.grad is not None
    )
    if found_unstable_grad:
        log.warning(
            "Detected `inf` or `nan` values in gradients. Not updating model parameters."
        )
        self.zero_grad()
|
||||||
|
|
||||||
|
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    """Override the optimizer step to dynamically update the learning rate.

    Applies a linear learning-rate warm-up over the first 1000 global steps
    after performing the regular optimizer step.

    :param epoch: The current epoch.
    :param batch_idx: The index of the current batch.
    :param optimizer: The optimizer to use for training.
    :param optimizer_closure: The optimizer closure.
    """
    # update params
    # NOTE(review): unwraps the raw optimizer from its wrapper; this assumes
    # Lightning always hands us a wrapped optimizer here — confirm for the
    # configured strategy, as a raw optimizer has no `.optimizer` attribute.
    optimizer = optimizer.optimizer
    optimizer.step(closure=optimizer_closure)

    # warm up learning rate
    # NOTE(review): warm-up horizon of 1000 steps is hard-coded here
    if self.trainer.global_step < 1000:
        # lr_scale ramps linearly from 1/1000 to 1.0
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 1000.0)
        for pg in optimizer.param_groups:
            # NOTE: `self.hparams.optimizer.keywords["lr"]` refers to the optimizer's initial learning rate
            pg["lr"] = lr_scale * self.hparams.optimizer.keywords["lr"]
|
||||||
|
|
||||||
|
def setup(self, stage: str):
|
||||||
|
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
|
||||||
|
test, or predict.
|
||||||
|
|
||||||
|
This is a good hook when you need to build models dynamically or adjust something about
|
||||||
|
them. This hook is called on every process when using DDP.
|
||||||
|
|
||||||
|
:param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
|
||||||
|
"""
|
||||||
|
if self.hparams.compile and stage == "fit":
|
||||||
|
self.net = torch.compile(self.net)
|
||||||
|
|
||||||
|
def configure_optimizers(self) -> Dict[str, Any]:
    """Choose what optimizers and learning-rate schedulers to use in your optimization.
    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Examples:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
    """
    try:
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
    except TypeError:
        # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
        optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())

    if self.hparams.scheduler is None:
        return {"optimizer": optimizer}

    # scheduler configured per-epoch, monitoring the validation loss
    lr_scheduler_config = {
        "scheduler": self.hparams.scheduler(optimizer=optimizer),
        "monitor": "val/loss",
        "interval": "epoch",
        "frequency": 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # smoke test: ensure the module can be constructed directly
    # (all constructor arguments intentionally None)
    _ = FlowDockFMLitModule(None, None, None, None)
|
||||||
284
flowdock/sample.py
Normal file
284
flowdock/sample.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import lightning as L
|
||||||
|
import lovely_tensors as lt
|
||||||
|
import pandas as pd
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, List, Tuple
|
||||||
|
from lightning import LightningModule, Trainer
|
||||||
|
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from lightning.pytorch.strategies.strategy import Strategy
|
||||||
|
from omegaconf import DictConfig, open_dict
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
lt.monkey_patch()
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
# the setup_root above is equivalent to:
|
||||||
|
# - adding project root dir to PYTHONPATH
|
||||||
|
# (so you don't need to force user to install project as a package)
|
||||||
|
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||||
|
# - setting up PROJECT_ROOT environment variable
|
||||||
|
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||||
|
# (this way all filepaths are the same no matter where you run the code)
|
||||||
|
# - loading environment variables from ".env" in root dir
|
||||||
|
#
|
||||||
|
# you can remove it if you:
|
||||||
|
# 1. either install project as a package or move entry files to project root dir
|
||||||
|
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||||
|
#
|
||||||
|
# more info: https://github.com/ashleve/rootutils
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||||
|
from flowdock.utils import (
|
||||||
|
RankedLogger,
|
||||||
|
extras,
|
||||||
|
instantiate_loggers,
|
||||||
|
log_hyperparameters,
|
||||||
|
task_wrapper,
|
||||||
|
)
|
||||||
|
from flowdock.utils.data_utils import (
|
||||||
|
create_full_pdb_with_zero_coordinates,
|
||||||
|
create_temp_ligand_frag_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
AVAILABLE_SAMPLING_TASKS = ["batched_structure_sampling"]
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingDataset(torch.utils.data.Dataset):
    """Dataset for sampling.

    Builds an internal DataFrame with one row per sampling example. Each row
    holds `sample_id`, `rec_path`, `lig_paths`, and optionally `input_template`.
    Inputs come either from a CSV file (batch mode) or from single-complex
    fields on the config.
    """

    @staticmethod
    def _resolve_ligand_paths(input_ligand):
        """Expand one ligand spec into a list of ligand fragment file paths.

        :param input_ligand: An SDF filepath, a `|`-separated list of paths,
            or None for a `null` ligand input.
        :return: A list of ligand filepaths, or None when no ligand is given.
        """
        if input_ligand is None:
            return None  # handle `null` ligand input
        if input_ligand.endswith(".sdf"):
            return create_temp_ligand_frag_files(input_ligand)
        return list(input_ligand.split("|"))

    @staticmethod
    def _resolve_receptor(input_receptor, out_path, pdb_filename):
        """Return a receptor PDB filepath, synthesizing a dummy PDB for raw sequences.

        :param input_receptor: A PDB filepath or a raw protein sequence.
        :param out_path: Output directory for the synthesized dummy PDB.
        :param pdb_filename: Filename to use for the synthesized dummy PDB.
        :return: Filepath of a receptor PDB file.
        """
        if input_receptor.endswith(".pdb"):
            return input_receptor
        log.warning(
            "Assuming the provided receptor input is a protein sequence. Creating a dummy PDB file."
        )
        dummy_pdb_path = os.path.join(out_path, pdb_filename)
        create_full_pdb_with_zero_coordinates(input_receptor, dummy_pdb_path)
        return dummy_pdb_path

    def __init__(self, cfg: "DictConfig"):
        """Initializes the SamplingDataset.

        :param cfg: Sampling configuration; must provide `sampling_task` and
            either `csv_path` (batch mode) or
            `sample_id`/`input_receptor`/`input_ligand` (single-complex mode).
        :raises NotImplementedError: If `cfg.sampling_task` is unsupported.
        """
        if cfg.sampling_task != "batched_structure_sampling":
            raise NotImplementedError(f"Sampling task {cfg.sampling_task} is not implemented.")
        if cfg.csv_path is not None:
            # handle variable CSV inputs: one sampling example per row
            df_rows = []
            self.df = pd.read_csv(cfg.csv_path)
            for _, row in self.df.iterrows():
                assert row.input_receptor is not None, "Receptor path is required for sampling."
                lig_paths = self._resolve_ligand_paths(row.input_ligand)
                rec_path = self._resolve_receptor(
                    row.input_receptor, cfg.out_path, f"input_{row.id}.pdb"
                )
                df_row = {
                    "sample_id": row.id,
                    "rec_path": rec_path,
                    "lig_paths": lig_paths,
                }
                # BUGFIX: empty CSV cells surface as NaN (not None) via pandas;
                # previously NaN passed the `is not None` check and was stored
                # as a (bogus) template path. Skip NaN templates explicitly.
                if row.input_template is not None and not pd.isna(row.input_template):
                    df_row["input_template"] = row.input_template
                df_rows.append(df_row)
            self.df = pd.DataFrame(df_rows)
        else:
            # single-complex mode: build a one-row DataFrame from config fields
            lig_paths = self._resolve_ligand_paths(cfg.input_ligand)
            rec_path = self._resolve_receptor(cfg.input_receptor, cfg.out_path, "input.pdb")
            self.df = pd.DataFrame(
                [
                    {
                        "sample_id": cfg.sample_id,
                        "rec_path": rec_path,
                        "lig_paths": lig_paths,
                    }
                ]
            )
            if cfg.input_template is not None:
                self.df["input_template"] = cfg.input_template

    def __len__(self) -> int:
        """Returns the number of sampling examples."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Returns the sampling inputs at `idx` as a plain dict.

        NOTE: The return annotation previously claimed `Tuple[str, str]`,
        but the method has always returned a dict of row fields.
        """
        return self.df.iloc[idx].to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@task_wrapper
def sample(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Samples using given checkpoint on a datamodule predictset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path, "Please provide a checkpoint path with which to sample!"
    assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"
    assert (
        cfg.sampling_task in AVAILABLE_SAMPLING_TASKS
    ), f"Sampling task {cfg.sampling_task} is not one of the following available tasks: {AVAILABLE_SAMPLING_TASKS}."
    # NOTE(review): this requires a non-None ligand in single-complex mode even
    # though `SamplingDataset` explicitly handles a `null` ligand input — confirm
    # whether ligand-free sampling should be reachable from this entry point.
    assert (cfg.input_receptor is not None and cfg.input_ligand is not None) or (
        cfg.csv_path is not None and os.path.exists(cfg.csv_path)
    ), "Please provide either an input receptor and ligand or a CSV file with receptor and ligand sequences/filepaths."

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    # Establish model input arguments
    # (`open_dict` temporarily unlocks the OmegaConf struct for new keys)
    with open_dict(cfg):
        # NOTE: Structure trajectories will not be visualized when performing auxiliary estimation only
        cfg.model.cfg.prior_type = cfg.prior_type
        cfg.model.cfg.task.detect_covalent = cfg.detect_covalent
        cfg.model.cfg.task.use_template = cfg.use_template
        cfg.model.cfg.task.csv_path = cfg.csv_path
        cfg.model.cfg.task.input_receptor = cfg.input_receptor
        cfg.model.cfg.task.input_ligand = cfg.input_ligand
        cfg.model.cfg.task.input_template = cfg.input_template
        cfg.model.cfg.task.visualize_generated_samples = (
            cfg.visualize_sample_trajectories and not cfg.auxiliary_estimation_only
        )
        cfg.model.cfg.task.auxiliary_estimation_only = cfg.auxiliary_estimation_only
    if cfg.latent_model is not None:
        with open_dict(cfg):
            cfg.model.cfg.latent_model = cfg.latent_model
    with open_dict(cfg):
        # "auto" start time resolves to 1.0 (i.e. start from the prior)
        if cfg.start_time == "auto":
            cfg.start_time = 1.0
        else:
            cfg.start_time = float(cfg.start_time)

    log.info("Converting sampling inputs into a <SamplingDataset>")
    # batch_size must remain 1: `predict_step` unpacks each field via `[0]`
    dataloaders: List[DataLoader] = [
        DataLoader(
            SamplingDataset(cfg),
            batch_size=1,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
        )
    ]

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    model.hparams.cfg.update(cfg)  # update model config with the sampling config

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    # fall back to any strategy preconfigured on the trainer config
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        # resolve mixed-precision dtype strings (e.g. for FSDP-style strategies)
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only pass `strategy` through when one was resolved, so Lightning's default applies otherwise
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
        )
    )

    object_dict = {
        "cfg": cfg,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base="1.3", config_path="../configs", config_name="sample.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for sampling.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # run checkpointed sampling; metrics/objects returned by `sample` are unused here
    sample(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # register custom OmegaConf resolvers before Hydra composes the config
    register_custom_omegaconf_resolvers()
    main()
|
||||||
196
flowdock/train.py
Normal file
196
flowdock/train.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import lightning as L
|
||||||
|
import lovely_tensors as lt
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
||||||
|
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from lightning.pytorch.strategies.strategy import Strategy
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
lt.monkey_patch()
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
# the setup_root above is equivalent to:
|
||||||
|
# - adding project root dir to PYTHONPATH
|
||||||
|
# (so you don't need to force user to install project as a package)
|
||||||
|
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||||
|
# - setting up PROJECT_ROOT environment variable
|
||||||
|
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||||
|
# (this way all filepaths are the same no matter where you run the code)
|
||||||
|
# - loading environment variables from ".env" in root dir
|
||||||
|
#
|
||||||
|
# you can remove it if you:
|
||||||
|
# 1. either install project as a package or move entry files to project root dir
|
||||||
|
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||||
|
#
|
||||||
|
# more info: https://github.com/ashleve/rootutils
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||||
|
from flowdock.utils import (
|
||||||
|
RankedLogger,
|
||||||
|
extras,
|
||||||
|
get_metric_value,
|
||||||
|
instantiate_callbacks,
|
||||||
|
instantiate_loggers,
|
||||||
|
log_hyperparameters,
|
||||||
|
task_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="fit")

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    # fall back to any strategy preconfigured on the trainer config
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        # resolve mixed-precision dtype strings (e.g. for FSDP-style strategies)
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only pass `strategy` through when one was resolved, so Lightning's default applies otherwise
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            plugins=plugins,
        )
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        ckpt_path = None
        if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
            ckpt_path = cfg.get("ckpt_path")
        elif cfg.get("ckpt_path"):
            # a checkpoint was requested but is missing: warn and start fresh
            log.warning(
                "`ckpt_path` was given, but the path does not exist. Training with new model weights."
            )
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # test with the best checkpoint from this run; fall back to current weights
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # register custom OmegaConf resolvers before Hydra composes the config
    register_custom_omegaconf_resolvers()
    main()
|
||||||
5
flowdock/utils/__init__.py
Normal file
5
flowdock/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from flowdock.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
||||||
|
from flowdock.utils.logging_utils import log_hyperparameters
|
||||||
|
from flowdock.utils.pylogger import RankedLogger
|
||||||
|
from flowdock.utils.rich_utils import enforce_tags, print_config_tree
|
||||||
|
from flowdock.utils.utils import extras, get_metric_value, task_wrapper
|
||||||
1846
flowdock/utils/data_utils.py
Normal file
1846
flowdock/utils/data_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
77
flowdock/utils/frame_utils.py
Normal file
77
flowdock/utils/frame_utils.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
from beartype.typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class RigidTransform:
    """A rigid transform: a translation paired with a rotation, stored as tensors.

    Attributes:
        t: translation tensor of shape ``(..., 3)``.
        R: rotation tensor of shape ``(..., 3, 3)``.
    """

    def __init__(self, t: torch.Tensor, R: Optional[torch.Tensor] = None):
        """Store the translation and rotation.

        NOTE(review): when ``R`` is omitted the rotation defaults to all zeros
        (degenerate), not the identity — confirm callers rely on this.
        """
        self.t = t
        self.R = t.new_zeros(*t.shape, 3) if R is None else R

    def __getitem__(self, key):
        """Index both the translation and the rotation with ``key``."""
        return RigidTransform(self.t[key], self.R[key])

    def unsqueeze(self, dim):
        """Return a copy with a singleton dimension inserted at ``dim``."""
        return RigidTransform(self.t.unsqueeze(dim), self.R.unsqueeze(dim))

    def squeeze(self, dim):
        """Return a copy with the singleton dimension ``dim`` removed."""
        return RigidTransform(self.t.squeeze(dim), self.R.squeeze(dim))

    def concatenate(self, other, dim=0):
        """Concatenate this transform with ``other`` along ``dim``."""
        t_cat = torch.cat([self.t, other.t], dim=dim)
        R_cat = torch.cat([self.R, other.R], dim=dim)
        return RigidTransform(t_cat, R_cat)
|
||||||
|
|
||||||
|
|
||||||
|
def get_frame_matrix(
    ri: torch.Tensor, rj: torch.Tensor, rk: torch.Tensor, eps: float = 1e-4, strict: bool = False
):
    """Build a local frame at ``rj`` from three points via the regularized
    Gram-Schmidt algorithm.

    Note that this implementation allows for shearing.

    :param ri: First point, shape ``(..., 3)``.
    :param rj: Center point (frame origin), shape ``(..., 3)``.
    :param rk: Third point, shape ``(..., 3)``.
    :param eps: Regularizer added to squared norms in the non-strict branch.
    :param strict: If True, normalize exactly (no eps regularization).
    :return: A ``RigidTransform`` with origin ``rj`` and the stacked basis as rotation.
    """
    v1 = ri - rj
    v2 = rk - rj
    if strict:
        # v1 = v1 + torch.randn_like(rj).mul(eps)
        # v2 = v2 + torch.randn_like(rj).mul(eps)
        # Exact normalization; degenerate inputs yield NaNs, zeroed via nan_to_num below.
        e1 = v1 / v1.norm(dim=-1, keepdim=True)
        # Project out the e1 component, then normalize
        u2 = v2 - e1.mul(e1.mul(v2).sum(-1, keepdim=True))
        e2 = u2 / u2.norm(dim=-1, keepdim=True)
    else:
        # eps-regularized norms keep the division finite near degenerate triples
        e1 = v1 / v1.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
        u2 = v2 - e1.mul(e1.mul(v2).sum(-1, keepdim=True))
        e2 = u2 / u2.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
    e3 = torch.cross(e1, e2, dim=-1)
    # Rows - lab frame, columns - internal frame
    basis = torch.stack([e1, e2, e3], dim=-1)
    return RigidTransform(rj, torch.nan_to_num(basis, 0.0))
|
||||||
|
|
||||||
|
|
||||||
|
def cartesian_to_internal(rs: torch.Tensor, frames: RigidTransform):
    """Express lab-frame points ``rs`` in local frames (right-multiply the pose matrix).

    :param rs: Points in the lab frame, shape ``(..., 3)``.
    :param frames: Frames whose ``t`` is the origin and ``R`` the basis.
    :return: Coordinates of ``rs`` in each local frame.
    """
    centered = (rs - frames.t).unsqueeze(-2)
    return torch.matmul(centered, frames.R).squeeze(-2)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_similarity_transform(
    X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
    """Apply the batched similarity transform ``s * X @ R + T`` to points ``X``.

    From: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/ops/points_alignment.html

    :param X: Point clouds, shape ``(B, N, 3)``.
    :param R: Rotation matrices, shape ``(B, 3, 3)``.
    :param T: Translations, shape ``(B, 3)``.
    :param s: Per-batch scale factors, shape ``(B,)``.
    :return: Transformed points, shape ``(B, N, 3)``.
    """
    scaled = torch.bmm(X, R).mul(s[:, None, None])
    return scaled + T[:, None, :]
|
||||||
4
flowdock/utils/generate_wandb_id.py
Normal file
4
flowdock/utils/generate_wandb_id.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# Standalone helper script: prints a fresh WandB run ID so it can be pasted
# into experiment configs (e.g. to resume a run with a fixed identifier).
import wandb

if __name__ == "__main__":
    # `wandb.util.generate_id()` returns a short random run identifier string.
    print(f"Generated WandB run ID: {wandb.util.generate_id()}")
|
||||||
55
flowdock/utils/inspect_ode_samplers.py
Normal file
55
flowdock/utils/inspect_ode_samplers.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from beartype import beartype
|
||||||
|
from beartype.typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
def clamp_tensor(value: torch.Tensor, min: float = 1e-6, max: float = 1 - 1e-6) -> torch.Tensor:
    """Set the upper and lower bounds of a tensor via clamping.

    :param value: The tensor to clamp.
    :param min: The minimum value to clamp to. Default is `1e-6`.
    :param max: The maximum value to clamp to. Default is `1 - 1e-6`.
    :return: The clamped tensor.
    """
    # NOTE: parameter names shadow the builtins but are kept to preserve the
    # keyword-argument interface for existing callers.
    return torch.clamp(value, min=min, max=max)
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def main(
    start_time: float, num_steps: int, sampler: Literal["ODE", "VDODE"] = "VDODE", eta: float = 1.0
):
    """Inspect different ODE samplers by printing the left hand side (LHS) and right hand side.

    (RHS) of their time ratio schedules. Note that the LHS and RHS are clamped to the range
    `[1e-6, 1 - 1e-6]` by default.

    :param start_time: The start time of the ODE sampler.
    :param num_steps: The number of steps to take.
    :param sampler: The ODE sampler to use.
    :param eta: The variance diminishing factor.
    """
    assert 0 < start_time <= 1.0, "The argument `start_time` must be in the range (0, 1]."
    # Uniform time grid from `start_time` down to 0 (inclusive); consecutive
    # pairs (t, s) are the current and next timesteps.
    time_grid = torch.linspace(start_time, 0, num_steps + 1)
    for t, s in zip(time_grid[:-1], time_grid[1:]):
        if sampler == "ODE":
            # Baseline ODE
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - t) * x0_hat: {clamp_tensor((1 - t)):.3f}; RHS -> t * xt: {clamp_tensor(t):.3f}"
            )
        elif sampler == "VDODE":
            # Variance Diminishing ODE (VD-ODE)
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - ((s / t) * eta)) * x0_hat: {clamp_tensor(1 - ((s / t) * eta)):.3f}; RHS -> ((s / t) * eta) * xt: {clamp_tensor((s / t) * eta):.3f}"
            )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI wrapper around `main`; defaults mirror `main`'s keyword defaults.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--start_time", type=float, default=1.0)
    argparser.add_argument("--num_steps", type=int, default=40)
    argparser.add_argument("--sampler", type=str, choices=["ODE", "VDODE"], default="VDODE")
    argparser.add_argument("--eta", type=float, default=1.0)
    args = argparser.parse_args()
    main(args.start_time, args.num_steps, sampler=args.sampler, eta=args.eta)
|
||||||
58
flowdock/utils/instantiators.py
Normal file
58
flowdock/utils/instantiators.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import hydra
|
||||||
|
import rootutils
|
||||||
|
from beartype.typing import List
|
||||||
|
from lightning import Callback
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils import pylogger
|
||||||
|
|
||||||
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    :raises TypeError: If ``callbacks_cfg`` is present but not a ``DictConfig``.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        # FIX: use the same three-dot ellipsis as `instantiate_loggers` for
        # consistent log output.
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # Only sub-configs that declare a `_target_` key are instantiable callbacks;
    # everything else (e.g. shared options) is ignored.
    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    :raises TypeError: If ``logger_cfg`` is present but not a ``DictConfig``.
    """
    loggers: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return loggers

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # Only sub-configs declaring a `_target_` key are instantiable loggers.
    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers
|
||||||
59
flowdock/utils/logging_utils.py
Normal file
59
flowdock/utils/logging_utils.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import rootutils
|
||||||
|
from beartype.typing import Any, Dict
|
||||||
|
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils import pylogger
|
||||||
|
|
||||||
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    # resolve the Hydra/OmegaConf config into a plain Python container
    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # nothing to log when no logger is attached to the trainer
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    # optional sections: `.get` tolerates their absence in the config
    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
|
||||||
88
flowdock/utils/metric_utils.py
Normal file
88
flowdock/utils/metric_utils.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import subprocess # nosec
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from beartype import beartype
|
||||||
|
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
MODEL_BATCH = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def calculate_usalign_metrics(
    pred_pdb_filepath: str,
    reference_pdb_filepath: str,
    usalign_exec_path: str,
    flags: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Calculates US-align structural metrics between predicted and reference macromolecular
    structures.

    :param pred_pdb_filepath: Filepath to predicted macromolecular structure in PDB format.
    :param reference_pdb_filepath: Filepath to reference macromolecular structure in PDB format.
    :param usalign_exec_path: Path to US-align executable.
    :param flags: Command-line flags to pass to US-align, optional.
    :return: Dictionary containing macromolecular US-align structural metrics and metadata.
    """
    # assemble and run the US-align command, capturing its stdout report
    cmd = [usalign_exec_path, pred_pdb_filepath, reference_pdb_filepath]
    cmd += flags if flags is not None else []
    output = subprocess.check_output(cmd, text=True, stderr=subprocess.PIPE)  # nosec

    # scrape the structural metrics out of US-align's plain-text report
    metrics: Dict[str, Any] = {}
    for raw_line in output.splitlines():
        line = raw_line.strip()
        if line.startswith("Name of Structure_1:"):
            metrics["Name of Structure_1"] = line.split(": ", 1)[1]
        elif line.startswith("Name of Structure_2:"):
            metrics["Name of Structure_2"] = line.split(": ", 1)[1]
        elif line.startswith("Length of Structure_1:"):
            metrics["Length of Structure_1"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Length of Structure_2:"):
            metrics["Length of Structure_2"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Aligned length="):
            # expected shape: "Aligned length= N, RMSD= X, Seq_ID=n_identical/n_aligned= Y"
            fields = line.split("=")
            metrics["Aligned length"] = int(fields[1].split(",")[0].strip())
            metrics["RMSD"] = float(fields[2].split(",")[0].strip())
            metrics["Seq_ID"] = float(fields[4].strip())
        elif line.startswith("TM-score="):
            # two TM-score lines appear, one per normalization length
            if "normalized by length of Structure_1" in line:
                metrics["TM-score_1"] = float(line.split("=")[1].split()[0])
            elif "normalized by length of Structure_2" in line:
                metrics["TM-score_2"] = float(line.split("=")[1].split()[0])

    return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def compute_per_atom_lddt(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes per-atom local distance difference test (LDDT) between predicted and target
    coordinates.

    :param batch: Dictionary containing metadata and target coordinates.
    :param pred_coords: Predicted atomic coordinates.
    :param target_coords: Target atomic coordinates.
    :return: Tuple of lDDT and lDDT list.
    """
    # Reshape flat coordinates into (num_structures, num_atoms, 3).
    pred_coords = pred_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    target_coords = target_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    # All-pairs distance matrices for target and prediction.
    target_dist = (target_coords[:, :, None] - target_coords[:, None, :]).norm(dim=-1)
    pred_dist = (pred_coords[:, :, None] - pred_coords[:, None, :]).norm(dim=-1)
    # Only pairs within 15 (presumably Angstrom) in the target are scored.
    # NOTE(review): self-pairs (target distance 0) are included in the mask — confirm intended.
    conserved_mask = target_dist < 15.0
    lddt_list = []
    # Bin edges for |pred_dist - target_dist|; the trailing 1e9 edge is an open catch-all bin.
    thresholds = [0, 0.5, 1, 2, 4, 6, 8, 12, 1e9]
    for threshold_idx in range(8):
        distdiff = (pred_dist - target_dist).abs()
        # Pairs whose error falls strictly inside this bin.
        # NOTE(review): strict inequalities mean values exactly on a bin edge
        # (including a perfect 0 error) fall into no bin — confirm intended.
        bin_fraction = (distdiff > thresholds[threshold_idx]) & (
            distdiff < thresholds[threshold_idx + 1]
        )
        # Fraction of scored (conserved) pairs in this bin, per atom.
        lddt_list.append(
            bin_fraction.mul(conserved_mask).long().sum(dim=2) / conserved_mask.long().sum(dim=2)
        )
    lddt_list = torch.stack(lddt_list, dim=-1)
    # lDDT = mean of the cumulative fractions below the first four cutoffs
    # (0.5 / 1 / 2 / 4), computed via cumsum over the first 4 bins.
    lddt = torch.cumsum(lddt_list[:, :, :4], dim=-1).mean(dim=-1)
    return lddt, lddt_list
|
||||||
446
flowdock/utils/model_utils.py
Normal file
446
flowdock/utils/model_utils.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
from torch_scatter import scatter_max, scatter_min
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
|
||||||
|
MODEL_BATCH = Dict[str, Any]
|
||||||
|
STATE_DICT = Dict[str, Any]
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GELUMLP(torch.nn.Module):
    """Simple MLP with post-LayerNorm."""

    def __init__(
        self,
        n_in_feats: int,
        n_out_feats: int,
        n_hidden_feats: Optional[int] = None,
        dropout: float = 0.0,
        zero_init: bool = False,
    ):
        """Build the MLP.

        :param n_in_feats: Input feature width.
        :param n_out_feats: Output feature width.
        :param n_hidden_feats: Hidden width; if None, a single hidden layer of
            the input width is used.
        :param dropout: Dropout probability (applied on the input in `forward`
            and, for the deep variant, between hidden layers).
        :param zero_init: Zero the final linear weight (for residual branches).
        """
        super().__init__()
        self.dropout = dropout
        if n_hidden_feats is None:
            # Shallow variant: in -> in -> out with GELU + post-LayerNorm.
            layer_stack = [
                torch.nn.Linear(n_in_feats, n_in_feats),
                torch.nn.GELU(),
                torch.nn.LayerNorm(n_in_feats),
                torch.nn.Linear(n_in_feats, n_out_feats),
            ]
        else:
            # Deep variant: in -> hidden -> hidden -> out with dropout in between.
            layer_stack = [
                torch.nn.Linear(n_in_feats, n_hidden_feats),
                torch.nn.GELU(),
                torch.nn.Dropout(p=self.dropout),
                torch.nn.Linear(n_hidden_feats, n_hidden_feats),
                torch.nn.GELU(),
                torch.nn.LayerNorm(n_hidden_feats),
                torch.nn.Linear(n_hidden_feats, n_out_feats),
            ]
        self.layers = torch.nn.Sequential(*layer_stack)
        torch.nn.init.xavier_uniform_(self.layers[0].weight, gain=1)
        # zero init for residual branches
        if zero_init:
            self.layers[-1].weight.data.fill_(0.0)
        else:
            torch.nn.init.xavier_uniform_(self.layers[-1].weight, gain=1)

    def _zero_init(self, module):
        """Zero-initialize the weight (and bias, if any) of a linear module."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.zero_()
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x: torch.Tensor):
        """Apply input dropout followed by the layer stack."""
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SumPooling(torch.nn.Module):
    """Sum pooling layer."""

    def __init__(self, learnable: bool, hidden_dim: int = 1):
        """Set up the pooling layer; a linear map is applied post-pooling when learnable."""
        super().__init__()
        if learnable:
            self.pooled_transform = torch.nn.Linear(hidden_dim, hidden_dim)
        else:
            self.pooled_transform = torch.nn.Identity()

    def forward(self, x, dst_idx, dst_size):
        """Segment-sum `x` into `dst_size` buckets keyed by `dst_idx`, then transform."""
        pooled = segment_sum(x, dst_idx, dst_size)
        return self.pooled_transform(pooled)
|
||||||
|
|
||||||
|
|
||||||
|
class AveragePooling(torch.nn.Module):
    """Average pooling layer."""

    def __init__(self, learnable: bool, hidden_dim: int = 1):
        """Set up the pooling layer; a linear map is applied post-pooling when learnable."""
        super().__init__()
        if learnable:
            self.pooled_transform = torch.nn.Linear(hidden_dim, hidden_dim)
        else:
            self.pooled_transform = torch.nn.Identity()

    def forward(self, x, dst_idx, dst_size):
        """Segment-average `x` into `dst_size` buckets keyed by `dst_idx`."""
        totals = x.new_zeros(dst_size, *x.shape[1:]).index_add_(0, dst_idx, x)
        counts = x.new_zeros(dst_size, *x.shape[1:]).index_add_(0, dst_idx, torch.ones_like(x))
        # epsilon guards empty buckets against division by zero
        return self.pooled_transform(totals / (counts + 1e-8))
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m):
    """Apply Kaiming-uniform initialization to the weight of any Linear module."""
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_uniform_(m.weight)
|
||||||
|
|
||||||
|
|
||||||
|
def segment_sum(src, dst_idx, dst_size):
    """Computes the sum of each segment in a tensor.

    :param src: Source tensor; rows are assigned to segments via `dst_idx`.
    :param dst_idx: Long tensor mapping each row of `src` to an output segment.
    :param dst_size: Number of output segments.
    :return: Per-segment sums, shape ``(dst_size, *src.shape[1:])``.
    """
    accumulator = src.new_zeros(dst_size, *src.shape[1:])
    return accumulator.index_add_(0, dst_idx, src)
|
||||||
|
|
||||||
|
|
||||||
|
def segment_mean(src, dst_idx, dst_size):
    """Computes the mean value of each segment in a tensor.

    :param src: Source tensor; rows are assigned to segments via `dst_idx`.
    :param dst_idx: Long tensor mapping each row of `src` to an output segment.
    :param dst_size: Number of output segments.
    :return: Per-segment means; empty segments evaluate to ~0 via the epsilon guard.
    """
    totals = src.new_zeros(dst_size, *src.shape[1:]).index_add_(0, dst_idx, src)
    counts = src.new_zeros(dst_size, *src.shape[1:]).index_add_(0, dst_idx, torch.ones_like(src))
    # epsilon keeps empty segments finite instead of dividing by zero
    return totals / (counts + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def segment_argmin(scores, dst_idx, dst_size, randomize: bool = False) -> torch.Tensor:
    """Samples the index of the minimum value in each segment.

    :param scores: Per-row scores; rows are assigned to segments via `dst_idx`.
    :param dst_idx: Long tensor mapping each row to an output segment.
    :param dst_size: Number of output segments.
    :param randomize: If True, perturb scores with Gumbel noise before the argmin.
    """
    if randomize:
        # Gumbel perturbation turns the hard argmin into a stochastic sample
        gumbel = -torch.log(-torch.log(torch.rand_like(scores)))
        scores = scores + gumbel
    _, sampled_idx = scatter_min(scores, dst_idx, dim=0, dim_size=dst_size)
    return sampled_idx
|
||||||
|
|
||||||
|
|
||||||
|
def segment_logsumexp(src, dst_idx, dst_size, extra_dims=None):
    """Computes the logsumexp of each segment in a tensor.

    :param src: Source tensor; rows are assigned to segments via `dst_idx`.
    :param dst_idx: Long tensor mapping each row of `src` to an output segment.
    :param dst_size: Number of output segments.
    :param extra_dims: Optional extra dimensions to also reduce over.
    :return: Per-segment logsumexp of `src`.
    """
    # per-segment max for numerical stability (subtracted before exponentiating)
    src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
    if extra_dims is not None:
        src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
    src = src - src_max[dst_idx]
    # accumulate exp(src - max) into each segment
    out = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    ).index_add_(0, dst_idx, torch.exp(src))
    if extra_dims is not None:
        out = torch.sum(out, dim=extra_dims)
    # add the subtracted max back; epsilon guards log(0) for empty segments
    return torch.log(out + 1e-8) + src_max.view(*out.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def segment_softmax(src, dst_idx, dst_size, extra_dims=None, floor_value=None):
    """Computes the softmax of each segment in a tensor.

    :param src: Source tensor; rows are assigned to segments via `dst_idx`.
    :param dst_idx: Long tensor mapping each row of `src` to an output segment.
    :param dst_size: Number of output segments.
    :param extra_dims: Optional extra dimensions to also normalize over.
    :param floor_value: Optional minimum probability; when given, probabilities
        are clamped to this floor and then renormalized per segment.
    """
    # per-segment max for numerical stability
    src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
    if extra_dims is not None:
        src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
    src = src - src_max[dst_idx]
    exp1 = torch.exp(src)
    # per-segment normalizer (sum of exponentials)
    exp0 = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    ).index_add_(0, dst_idx, exp1)
    if extra_dims is not None:
        exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
    # broadcast each segment's normalizer back onto its rows
    exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
    exp = exp1.div(exp0 + 1e-8)
    if floor_value is not None:
        # clamp to the floor, then renormalize so each segment sums to ~1 again
        exp = exp.clamp(min=floor_value)
        exp0 = torch.zeros(
            dst_size,
            *src.shape[1:],
            dtype=src.dtype,
            device=src.device,
        ).index_add_(0, dst_idx, exp)
        if extra_dims is not None:
            exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
        exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
        exp = exp.div(exp0 + 1e-8)
    return exp
|
||||||
|
|
||||||
|
|
||||||
|
def batched_sample_onehot(logits, dim=0, max_only=False):
    """Implements the Gumbel-Max trick to sample from a one-hot distribution.

    :param logits: Unnormalized log-probabilities.
    :param dim: Dimension along which to sample.
    :param max_only: If True, take the deterministic argmax instead of sampling.
    :return: Boolean one-hot tensor shaped like `logits`.
    """
    if max_only:
        chosen = torch.argmax(logits, dim=dim, keepdim=True)
    else:
        # Gumbel-Max: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits))
        gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
        chosen = torch.argmax(logits + gumbel, dim=dim, keepdim=True)
    onehot = torch.zeros_like(logits, dtype=torch.bool)
    onehot.scatter_(dim=dim, index=chosen, value=1)
    return onehot
|
||||||
|
|
||||||
|
|
||||||
|
def topk_edge_mask_from_logits(scores, k, randomize=False):
    """Samples the top-k edges from a set of logits.

    :param scores: Edge logits of shape [B, N, N].
    :param k: Maximum out-degree per node (capped by the number of columns).
    :param randomize: If True, perturb logits with Gumbel noise first.
    :return: Boolean mask of shape [B, N, N] selecting the chosen edges.
    """
    assert len(scores.shape) == 3, "Scores should have shape [B, N, N]"
    if randomize:
        gumbel = -torch.log(-torch.log(torch.rand_like(scores)))
        scores = scores + gumbel
    # cap the requested degree by the number of available columns
    degree = min(k, scores.shape[2])
    _, chosen_idx = torch.topk(scores, degree, dim=-1, largest=True)
    mask = scores.new_zeros(scores.shape, dtype=torch.bool)
    return mask.scatter_(dim=2, index=chosen_idx, value=1).bool()
|
||||||
|
|
||||||
|
|
||||||
|
def sample_inplace_to_torch(sample):
    """Convert NumPy sample to PyTorch tensors (in place; returns the same dict)."""
    if sample is None:
        return None
    # features become float tensors, indexer entries become long tensors
    for key, caster in (("features", torch.FloatTensor), ("indexer", torch.LongTensor)):
        sample[key] = {k: caster(v) for k, v in sample[key].items()}
    if "labels" in sample.keys():
        sample["labels"] = {k: torch.FloatTensor(v) for k, v in sample["labels"].items()}
    return sample
|
||||||
|
|
||||||
|
|
||||||
|
def inplace_to_device(sample, device):
    """Move every tensor in a sample dict to `device` (in place; returns the same dict).

    :param sample: Dict with "features" and "indexer" tensor dicts, and optionally
        a "labels" tensor dict (see `inplace_to_torch`).
    :param device: Target device (e.g. "cpu", "cuda:0").
    :return: The same dict with all tensors moved to `device`.
    """
    sample["features"] = {k: v.to(device) for k, v in sample["features"].items()}
    sample["indexer"] = {k: v.to(device) for k, v in sample["indexer"].items()}
    if "labels" in sample.keys():
        # BUGFIX: "labels" is a dict of tensors (see `inplace_to_torch`), not a
        # tensor — calling `.to` on the dict itself raised AttributeError.
        sample["labels"] = {k: v.to(device) for k, v in sample["labels"].items()}
    return sample
|
||||||
|
|
||||||
|
|
||||||
|
def inplace_to_torch(sample):
    """Convert a NumPy-backed sample dict to PyTorch tensors (returns the same dict)."""
    if sample is None:
        return None
    sample["features"] = {name: torch.FloatTensor(arr) for name, arr in sample["features"].items()}
    sample["indexer"] = {name: torch.LongTensor(arr) for name, arr in sample["indexer"].items()}
    # labels are optional in a sample
    if "labels" in sample:
        sample["labels"] = {name: torch.FloatTensor(arr) for name, arr in sample["labels"].items()}
    return sample
|
||||||
|
|
||||||
|
|
||||||
|
def distance_to_gaussian_contact_logits(
    x: torch.Tensor, contact_scale: float, cutoff: Optional[float] = None
) -> torch.Tensor:
    """Convert distance to Gaussian contact logits.

    :param x: Distance tensor.
    :param contact_scale: The contact scale.
    :param cutoff: The distance cutoff; defaults to twice the contact scale.
    :return: Gaussian contact logits.
    """
    effective_cutoff = contact_scale * 2 if cutoff is None else cutoff
    # clamp keeps the log finite once x meets or exceeds the cutoff
    decay = torch.clamp(1 - (x / effective_cutoff), min=1e-9)
    return torch.log(decay)
|
||||||
|
|
||||||
|
|
||||||
|
def distogram_to_gaussian_contact_logits(
    dgram: torch.Tensor, dist_bins: torch.Tensor, contact_scale: float
) -> torch.Tensor:
    """Convert a distance histogram (distogram) matrix to a Gaussian contact map.

    :param dgram: A distogram matrix (log-weights over distance bins).
    :param dist_bins: Representative distance of each bin.
    :param contact_scale: Contact scale forwarded to the per-bin logits.
    :return: A Gaussian contact map (logsumexp over the bin dimension).
    """
    per_bin_logits = distance_to_gaussian_contact_logits(dist_bins, contact_scale)
    return torch.logsumexp(dgram + per_bin_logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def eval_true_contact_maps(
    batch: MODEL_BATCH, contact_scale: float, **kwargs: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate true contact maps.

    :param batch: A batch dictionary.
    :param contact_scale: The contact scale.
    :param kwargs: Additional keyword arguments (forwarded to
        `distance_to_gaussian_contact_logits`, e.g. `cutoff`).
    :return: Residue-ligand distance matrix and the flattened contact logits.
    """
    indexer = batch["indexer"]
    batch_size = batch["metadata"]["num_structid"]
    # labels are fixed targets, so no gradients are needed here
    with torch.no_grad():
        # Residue centroids
        # masked mean of per-residue atom positions; epsilon guards residues
        # with no resolved atoms
        res_cent_coords = (
            batch["features"]["res_atom_positions"]
            .mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
            .sum(dim=1)
            .div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
        )
        # pairwise residue-centroid to ligand-atom distances
        # (ligand coordinates gathered via `gather_idx_U_u`)
        res_lig_dist = (
            res_cent_coords.view(batch_size, -1, 3)[:, :, None]
            - batch["features"]["sdf_coordinates"][indexer["gather_idx_U_u"]].view(
                batch_size, -1, 3
            )[:, None, :]
        ).norm(dim=-1)
        res_lig_contact_logit = distance_to_gaussian_contact_logits(
            res_lig_dist, contact_scale, **kwargs
        )
    return res_lig_dist, res_lig_contact_logit.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
def sample_reslig_contact_matrix(
    batch: MODEL_BATCH, res_lig_logits: torch.Tensor, last: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Sample residue-ligand contact matrix.

    :param batch: A batch dictionary.
    :param res_lig_logits: Residue-ligand contact logits.
    :param last: The previously sampled contact matrix, or None to start fresh.
    :return: Sampled residue-ligand contact matrix of shape [B, N_res, N_frames].
    """
    metadata = batch["metadata"]
    batch_size = metadata["num_structid"]
    # FIX: removed a dead `max(metadata["num_molid_per_sample"])` expression
    # statement whose result was discarded.
    n_a_per_sample = max(metadata["num_a_per_sample"])
    n_I_per_sample = max(metadata["num_I_per_sample"])
    res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
    # Sampling from unoccupied lattice sites
    if last is None:
        last = torch.zeros_like(res_lig_logits, dtype=torch.bool)
    # Column-graph-wise masking for already sampled ligands
    # sampled_ligand_mask = torch.amax(last, dim=1, keepdim=True)
    sampled_frame_mask = torch.sum(last, dim=1, keepdim=True).contiguous()
    # push already-occupied columns far down so they cannot be re-sampled
    masked_logits = res_lig_logits - sampled_frame_mask * 1e9
    sampled_block_onehot = batched_sample_onehot(masked_logits.flatten(1, 2), dim=1).view(
        batch_size, n_a_per_sample, n_I_per_sample
    )
    new_block_contact_mat = last + sampled_block_onehot
    # Remove non-contact patches (logits at or below -16 count as no contact)
    valid_logit_mask = res_lig_logits > -16.0
    new_block_contact_mat = (new_block_contact_mat * valid_logit_mask).bool()
    return new_block_contact_mat
|
||||||
|
|
||||||
|
|
||||||
|
def merge_res_lig_logits_to_graph(
    batch: MODEL_BATCH,
    res_lig_logits: torch.Tensor,
    single_protein_batch: bool,
) -> torch.Tensor:
    """Patch merging [B, N_res, N_atm] -> [B, N_res, N_graph].

    :param batch: A batch dictionary.
    :param res_lig_logits: Residue-ligand contact logits.
    :param single_protein_batch: Whether to use single protein batch.
    :return: Merged residue-ligand logits.
    """
    assert single_protein_batch, "Only single protein batch is supported."
    metadata = batch["metadata"]
    indexer = batch["indexer"]
    batch_size = metadata["num_structid"]
    # FIX: removed a duplicate bare `max(metadata["num_molid_per_sample"])`
    # expression statement; the value is bound once below.
    n_mol_per_sample = max(metadata["num_molid_per_sample"])
    n_a_per_sample = max(metadata["num_a_per_sample"])
    n_I_per_sample = max(metadata["num_I_per_sample"])
    res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
    # logsumexp-pool frame-level logits into per-molecule (graph) logits,
    # grouping frames by their molecule id
    graph_wise_logits = segment_logsumexp(
        res_lig_logits.permute(2, 0, 1),
        indexer["gather_idx_I_molid"][:n_I_per_sample],
        n_mol_per_sample,
    ).permute(1, 2, 0)
    return graph_wise_logits
|
||||||
|
|
||||||
|
|
||||||
|
def sample_res_rowmask_from_contacts(
    batch: MODEL_BATCH,
    res_ligatm_logits: torch.Tensor,
    single_protein_batch: bool,
) -> torch.Tensor:
    """Sample residue row mask from contacts.

    :param batch: A batch dictionary.
    :param res_ligatm_logits: Residue-ligand atom contact logits.
    :param single_protein_batch: Whether to use single protein batch.
    :return: Sampled residue row mask (one-hot over residues for each ligand graph).
    """
    # [B, N_res, N_graph] -> [B, N_graph, N_res]; `.contiguous()` so the flatten
    # below is a view-compatible reshape.
    lig_wise_logits = (
        merge_res_lig_logits_to_graph(batch, res_ligatm_logits, single_protein_batch)
        .permute(0, 2, 1)
        .contiguous()
    )
    # Draw one residue per (sample, ligand-graph) row.
    sampled_res_onehot_mask = batched_sample_onehot(lig_wise_logits.flatten(0, 1), dim=1)
    return sampled_res_onehot_mask
|
||||||
|
|
||||||
|
|
||||||
|
def extract_esm_embeddings(
    esm_model: torch.nn.Module,
    esm_alphabet: torch.nn.Module,
    esm_batch_converter: torch.nn.Module,
    sequences: List[str],
    device: Union[str, torch.device],
    esm_repr_layer: int = 33,
) -> List[torch.Tensor]:
    """Extract embeddings from ESM model.

    :param esm_model: ESM model.
    :param esm_alphabet: ESM alphabet.
    :param esm_batch_converter: ESM batch converter.
    :param sequences: A list of sequences.
    :param device: Device to use.
    :param esm_repr_layer: ESM representation layer index from which to extract embeddings.
    :return: A corresponding list of embeddings.
    """
    # Disable dropout for deterministic results
    esm_model.eval()

    # Pair each sequence with a unique string label, as required by the ESM batch converter
    data = [(str(i), seq) for i, seq in enumerate(sequences)]
    _, _, batch_tokens = esm_batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    # Per-sequence token counts (non-padding), including BOS/EOS tokens
    batch_lens = (batch_tokens != esm_alphabet.padding_idx).sum(1)

    # Extract per-residue representations; contacts are never consumed here,
    # so skip the (costly) contact-map head entirely
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[esm_repr_layer], return_contacts=False)
    token_representations = results["representations"][esm_repr_layer]

    # Generate per-residue representations
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        # Slice off the BOS token (index 0) and the trailing EOS token
        sequence_representations.append(token_representations[i, 1 : tokens_len - 1])

    return sequence_representations
|
||||||
51
flowdock/utils/pylogger.py
Normal file
51
flowdock/utils/pylogger.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from beartype.typing import Mapping, Optional
|
||||||
|
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
|
||||||
|
|
||||||
|
|
||||||
|
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if not self.isEnabledFor(level):
            return
        msg, kwargs = self.process(msg, kwargs)
        # NOTE: the bare name `rank_zero_only` here resolves to the imported
        # lightning utility (whose `.rank` attribute is set by Lightning), not
        # the boolean attribute stored on `self`.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)
        if self.rank_zero_only:
            should_emit = current_rank == 0
        else:
            should_emit = rank is None or current_rank == rank
        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
|
||||||
102
flowdock/utils/rich_utils.py
Normal file
102
flowdock/utils/rich_utils.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import rich
|
||||||
|
import rich.syntax
|
||||||
|
import rich.tree
|
||||||
|
import rootutils
|
||||||
|
from beartype.typing import Sequence
|
||||||
|
from hydra.core.hydra_config import HydraConfig
|
||||||
|
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||||
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
|
from rich.prompt import Prompt
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils import pylogger
|
||||||
|
|
||||||
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)
|
||||||
|
|
||||||
|
|
||||||
|
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        # Tags are mandatory for multiruns, where interactive prompting is not possible
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip BEFORE filtering, so whitespace-only tokens (e.g. "a, ,b") are
        # dropped instead of becoming empty-string tags
        tags = [t.strip() for t in tags.split(",") if t.strip()]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

        if save_to_file:
            with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
                rich.print(cfg.tags, file=file)
|
||||||
427
flowdock/utils/sampling_utils.py
Normal file
427
flowdock/utils/sampling_utils.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import rootutils
|
||||||
|
import torch
|
||||||
|
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from lightning import LightningModule
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
from rdkit import Chem
|
||||||
|
from rdkit.Chem import AllChem
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.data.components.mol_features import (
|
||||||
|
collate_numpy_samples,
|
||||||
|
process_mol_file,
|
||||||
|
)
|
||||||
|
from flowdock.utils import RankedLogger
|
||||||
|
from flowdock.utils.data_utils import (
|
||||||
|
FDProtein,
|
||||||
|
merge_protein_and_ligands,
|
||||||
|
pdb_filepath_to_protein,
|
||||||
|
prepare_batch,
|
||||||
|
process_protein,
|
||||||
|
)
|
||||||
|
from flowdock.utils.model_utils import inplace_to_device, inplace_to_torch, segment_mean
|
||||||
|
from flowdock.utils.visualization_utils import (
|
||||||
|
write_conformer_sdf,
|
||||||
|
write_pdb_models,
|
||||||
|
write_pdb_single,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def featurize_protein_and_ligands(
    rec_path: str,
    lig_paths: List[str],
    n_lig_patches: int,
    apo_rec_path: Optional[str] = None,
    chain_id: Optional[str] = None,
    protein: Optional[FDProtein] = None,
    sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
    enforce_sanitization: bool = False,
    discard_sdf_coords: bool = False,
    **kwargs: Any,
) -> Tuple[Any, Optional[Chem.Mol]]:
    """Featurize a protein-ligand complex.

    :param rec_path: Path to the receptor file.
    :param lig_paths: List of paths to the ligand files.
    :param n_lig_patches: Number of ligand patches.
    :param apo_rec_path: Path to the apo receptor file.
    :param chain_id: Chain ID of the receptor.
    :param protein: Optional protein object.
    :param sequences_to_embeddings: Mapping of sequences to embeddings.
    :param enforce_sanitization: Whether to enforce sanitization.
    :param discard_sdf_coords: Whether to discard SDF coordinates.
    :param kwargs: Additional keyword arguments.
    :return: Featurized protein-ligand complex.
    """
    assert rec_path is not None
    # Normalize `lig_paths` so the loop below can treat it uniformly as a list
    if lig_paths is None:
        lig_paths = []
    if isinstance(lig_paths, str):
        lig_paths = [lig_paths]
    # `out_mol` accumulates all ligand molecules into one combined RDKit Mol
    out_mol = None
    lig_samples = []
    for lig_path in lig_paths:
        try:
            lig_sample, mol_ref = process_mol_file(
                lig_path,
                sanitize=True,
                return_mol=True,
                discard_coords=discard_sdf_coords,
            )
        except Exception as e:
            # Best-effort fallback: retry without RDKit sanitization unless the
            # caller explicitly requires sanitized inputs
            if enforce_sanitization:
                raise
            log.warning(
                f"RDKit sanitization failed for ligand {lig_path} due to: {e}. Loading raw attributes."
            )
            lig_sample, mol_ref = process_mol_file(
                lig_path,
                sanitize=False,
                return_mol=True,
                discard_coords=discard_sdf_coords,
            )
        lig_samples.append(lig_sample)
        if out_mol is None:
            out_mol = mol_ref
        else:
            out_mol = AllChem.CombineMols(out_mol, mol_ref)
    # Parse the receptor lazily only if a pre-parsed protein was not supplied
    protein = protein if protein is not None else pdb_filepath_to_protein(rec_path)
    # When an apo receptor is given, attach sequence embeddings to the apo copy
    # instead of the holo receptor (the apo features below carry them)
    rec_sample = process_protein(
        protein,
        chain_id=chain_id,
        sequences_to_embeddings=None if apo_rec_path is not None else sequences_to_embeddings,
        **kwargs,
    )
    if apo_rec_path is not None:
        apo_protein = pdb_filepath_to_protein(apo_rec_path)
        apo_rec_sample = process_protein(
            apo_protein,
            chain_id=chain_id,
            sequences_to_embeddings=sequences_to_embeddings,
            **kwargs,
        )
        # Merge apo features into the holo sample under "apo_"-prefixed keys
        for key in rec_sample.keys():
            for subkey, value in apo_rec_sample[key].items():
                rec_sample[key]["apo_" + subkey] = value
    merged_sample = merge_protein_and_ligands(
        lig_samples,
        rec_sample,
        n_lig_patches=n_lig_patches,
        label=None,
    )
    return merged_sample, out_mol
|
||||||
|
|
||||||
|
|
||||||
|
def multi_pose_sampling(
    receptor_path: str,
    ligand_path: str,
    cfg: DictConfig,
    lit_module: LightningModule,
    out_path: str,
    save_pdb: bool = True,
    separate_pdb: bool = True,
    chain_id: Optional[str] = None,
    apo_receptor_path: Optional[str] = None,
    sample_id: Optional[str] = None,
    protein: Optional[FDProtein] = None,
    sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
    confidence: bool = True,
    affinity: bool = True,
    return_all_states: bool = False,
    auxiliary_estimation_only: bool = False,
    **kwargs: Any,
) -> Tuple[
    Optional[Chem.Mol],
    Optional[List[float]],
    Optional[List[float]],
    Optional[List[Any]],
    Optional[List[float]],
    Optional[Any],
    Optional[Any],
    Optional[np.ndarray],
    Optional[np.ndarray],
]:
    """Sample multiple poses of a protein-ligand complex.

    :param receptor_path: Path to the receptor file.
    :param ligand_path: Path to the ligand file.
    :param cfg: Config dictionary.
    :param lit_module: LightningModule instance.
    :param out_path: Path to save the output files.
    :param save_pdb: Whether to save PDB files.
    :param separate_pdb: Whether to save separate PDB files for each pose.
    :param chain_id: Chain ID of the receptor.
    :param apo_receptor_path: Path to the optional apo receptor file.
    :param sample_id: Optional sample ID.
    :param protein: Optional protein object.
    :param sequences_to_embeddings: Mapping of sequences to embeddings.
    :param confidence: Whether to estimate confidence scores.
    :param affinity: Whether to estimate affinity scores.
    :param return_all_states: Whether to return all states.
    :param auxiliary_estimation_only: Whether to only estimate auxiliary outputs (e.g., confidence,
        affinity) for the input (generated) samples (potentially derived from external sources).
    :param kwargs: Additional keyword arguments.
    :return: Reference molecule, protein plDDTs, ligand plDDTs, ligand fragment plDDTs, estimated
        binding affinities, structure trajectories, input batch, B-factors, and structure rankings.
    """
    if return_all_states and auxiliary_estimation_only:
        # NOTE: If auxiliary estimation is solely enabled, structure trajectory sampling will be disabled
        return_all_states = False
    # Accumulators across sampling chunks
    struct_res_all, lig_res_all = [], []
    plddt_all, plddt_lig_all, plddt_ligs_all, res_plddt_all = [], [], [], []
    affinity_all, ligs_affinity_all = [], []
    frames_all = []
    chunk_size = cfg.chunk_size
    # NOTE(review): if cfg.n_samples < chunk_size this loop never runs and
    # `mol`/`b_factors` below are undefined — confirm callers guarantee
    # n_samples >= chunk_size.
    for _ in range(cfg.n_samples // chunk_size):
        # Resample anchor node frames
        np_sample, mol = featurize_protein_and_ligands(
            receptor_path,
            ligand_path,
            n_lig_patches=lit_module.hparams.cfg.mol_encoder.n_patches,
            apo_rec_path=apo_receptor_path,
            chain_id=chain_id,
            protein=protein,
            sequences_to_embeddings=sequences_to_embeddings,
            discard_sdf_coords=cfg.discard_sdf_coords and not auxiliary_estimation_only,
            **kwargs,
        )
        # Replicate the featurized sample `chunk_size` times to sample poses in parallel
        np_sample_batched = collate_numpy_samples([np_sample for _ in range(chunk_size)])
        sample = inplace_to_device(inplace_to_torch(np_sample_batched), device=lit_module.device)
        prepare_batch(sample)
        if auxiliary_estimation_only:
            # Predict auxiliary quantities using the provided input protein and ligand structures
            if "num_molid" in sample["metadata"].keys() and sample["metadata"]["num_molid"] > 0:
                sample["misc"]["protein_only"] = False
            else:
                sample["misc"]["protein_only"] = True
            # Use the input coordinates directly instead of sampling new ones
            output_struct = {
                "receptor": sample["features"]["res_atom_positions"].flatten(0, 1),
                "receptor_padded": sample["features"]["res_atom_positions"],
                "ligands": sample["features"]["sdf_coordinates"],
            }
        else:
            output_struct = lit_module.net.sample_pl_complex_structures(
                sample,
                sampler=cfg.sampler,
                sampler_eta=cfg.sampler_eta,
                num_steps=cfg.num_steps,
                return_all_states=return_all_states,
                start_time=cfg.start_time,
                exact_prior=cfg.exact_prior,
            )
        frames_all.append(output_struct.get("all_frames", None))
        if mol is not None:
            ref_mol = AllChem.Mol(mol)
            # Split batched ligand coordinates back into per-pose arrays
            out_x1 = np.split(output_struct["ligands"].cpu().numpy(), cfg.chunk_size)
        out_x2 = np.split(output_struct["receptor_padded"].cpu().numpy(), cfg.chunk_size)
        if confidence and affinity:
            assert (
                lit_module.net.confidence_cfg.enabled
            ), "Confidence estimation must be enabled in the model configuration."
            assert (
                lit_module.net.affinity_cfg.enabled
            ), "Affinity estimation must be enabled in the model configuration."
            plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
                sample,
                output_struct,
                return_avg_stats=True,
                training=False,
            )
            # NOTE(review): unlike the affinity-only branch below, `aff` is not
            # moved to CPU here — confirm this asymmetry is intentional.
            aff = sample["outputs"]["affinity_logits"]
        elif confidence:
            assert (
                lit_module.net.confidence_cfg.enabled
            ), "Confidence estimation must be enabled in the model configuration."
            plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
                sample,
                output_struct,
                return_avg_stats=True,
                training=False,
            )
        elif affinity:
            assert (
                lit_module.net.affinity_cfg.enabled
            ), "Affinity estimation must be enabled in the model configuration."
            lit_module.net.run_auxiliary_estimation(
                sample, output_struct, return_avg_stats=True, training=False
            )
            plddt, plddt_lig = None, None
            aff = sample["outputs"]["affinity_logits"].cpu()

        # Map each molecule ID to the structure (pose) index it belongs to
        mol_idx_i_structid = segment_mean(
            sample["indexer"]["gather_idx_i_structid"],
            sample["indexer"]["gather_idx_i_molid"],
            sample["metadata"]["num_molid"],
        ).long()
        for struct_idx in range(cfg.chunk_size):
            # Assemble an AlphaFold-style result dict for downstream PDB writing
            struct_res = {
                "features": {
                    "asym_id": np_sample["features"]["res_chain_id"],
                    "residue_index": np.arange(len(np_sample["features"]["res_type"])) + 1,
                    "aatype": np_sample["features"]["res_type"],
                },
                "structure_module": {
                    "final_atom_positions": out_x2[struct_idx],
                    "final_atom_mask": sample["features"]["res_atom_mask"].bool().cpu().numpy(),
                },
            }
            struct_res_all.append(struct_res)
            if mol is not None:
                lig_res_all.append(out_x1[struct_idx])
            if confidence:
                plddt_all.append(plddt[struct_idx].item())
                res_plddt_all.append(
                    sample["outputs"]["plddt"][
                        struct_idx, : sample["metadata"]["num_a_per_sample"][0]
                    ]
                    .cpu()
                    .numpy()
                )
                if plddt_lig is None:
                    plddt_lig_all.append(None)
                else:
                    plddt_lig_all.append(plddt_lig[struct_idx].item())
                if plddt_ligs is None:
                    plddt_ligs_all.append(None)
                else:
                    plddt_ligs_all.append(plddt_ligs[mol_idx_i_structid == struct_idx].tolist())
            if affinity:
                # collect the average affinity across all ligands in each complex
                ligs_aff = aff[mol_idx_i_structid == struct_idx]
                affinity_all.append(ligs_aff.mean().item())
                ligs_affinity_all.append(ligs_aff.tolist())
    if confidence and cfg.rank_outputs_by_confidence:
        # NOTE(review): `all(...)` treats a ligand plDDT of exactly 0.0 the same
        # as a missing (None) entry — confirm that is acceptable here.
        plddt_lig_predicted = all(plddt_lig_all)
        if cfg.plddt_ranking_type == "protein":
            struct_plddts = np.array(plddt_all)  # rank outputs using average protein plDDT
        elif cfg.plddt_ranking_type == "ligand":
            struct_plddts = np.array(
                plddt_lig_all if plddt_lig_predicted else plddt_all
            )  # rank outputs using average ligand plDDT if available
            if not plddt_lig_predicted:
                log.warning(
                    "Ligand plDDT not available for all samples, using protein plDDT instead"
                )
        elif cfg.plddt_ranking_type == "protein_ligand":
            struct_plddts = np.array(
                plddt_all + plddt_lig_all if plddt_lig_predicted else plddt_all
            )  # rank outputs using the sum of the average protein and ligand plDDTs if ligand plDDT is available
            if not plddt_lig_predicted:
                log.warning(
                    "Ligand plDDT not available for all samples, using protein plDDT instead"
                )
        struct_plddt_rankings = np.argsort(
            -struct_plddts
        ).argsort()  # ensure that higher plDDTs have a higher rank (e.g., `rank1`)
    receptor_plddt = np.array(res_plddt_all) if confidence else None
    # Broadcast per-residue plDDT over the atom axis to serve as PDB B-factors
    b_factors = (
        np.repeat(
            receptor_plddt[..., None],
            struct_res_all[0]["structure_module"]["final_atom_mask"].shape[-1],
            axis=-1,
        )
        if confidence
        else None
    )
    if save_pdb:
        if separate_pdb:
            for struct_id, struct_res in enumerate(struct_res_all):
                if confidence and cfg.rank_outputs_by_confidence:
                    # File name encodes rank, plDDT and (optionally) affinity
                    write_pdb_single(
                        struct_res,
                        out_path=os.path.join(
                            out_path,
                            f"prot_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
                        ),
                        b_factors=b_factors[struct_id] if confidence else None,
                    )
                else:
                    write_pdb_single(
                        struct_res,
                        out_path=os.path.join(
                            out_path,
                            f"prot_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
                        ),
                        b_factors=b_factors[struct_id] if confidence else None,
                    )
        # Always also write a single multi-model PDB with every pose
        write_pdb_models(
            struct_res_all, out_path=os.path.join(out_path, "prot_all.pdb"), b_factors=b_factors
        )
    if mol is not None:
        write_conformer_sdf(ref_mol, None, out_path=os.path.join(out_path, "lig_ref.sdf"))
        lig_res_all = np.array(lig_res_all)
        write_conformer_sdf(mol, lig_res_all, out_path=os.path.join(out_path, "lig_all.sdf"))
        for struct_id in range(len(lig_res_all)):
            if confidence and cfg.rank_outputs_by_confidence:
                write_conformer_sdf(
                    mol,
                    lig_res_all[struct_id : struct_id + 1],
                    out_path=os.path.join(
                        out_path,
                        f"lig_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
                    ),
                )
            else:
                write_conformer_sdf(
                    mol,
                    lig_res_all[struct_id : struct_id + 1],
                    out_path=os.path.join(
                        out_path,
                        f"lig_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
                    ),
                )
        if confidence:
            # Persist per-pose auxiliary estimates (plDDT + affinity) as a CSV
            aux_estimation_all_df = pd.DataFrame(
                {
                    "sample_id": [sample_id] * len(struct_res_all),
                    "rank": struct_plddt_rankings + 1 if cfg.rank_outputs_by_confidence else None,
                    "plddt_ligs": plddt_ligs_all,
                    "affinity_ligs": ligs_affinity_all,
                }
            )
            aux_estimation_all_df.to_csv(
                os.path.join(out_path, f"{sample_id if sample_id is not None else 'sample'}_auxiliary_estimation.csv"), index=False
            )
    else:
        ref_mol = None
    # Normalize disabled outputs to None for the return tuple
    if not confidence:
        plddt_all, plddt_lig_all, plddt_ligs_all = None, None, None
    if not affinity:
        affinity_all = None
    if return_all_states:
        if mol is not None:
            np_sample["metadata"]["sample_ID"] = sample_id if sample_id is not None else "sample"
            np_sample["metadata"]["mol"] = ref_mol
        batch_all = inplace_to_torch(
            collate_numpy_samples([np_sample for _ in range(cfg.n_samples)])
        )
        # Concatenate per-chunk trajectory frames along the batch dimension
        merge_frames_all = frames_all[0]
        for frames in frames_all[1:]:
            for frame_index, frame in enumerate(frames):
                for key in frame.keys():
                    merge_frames_all[frame_index][key] = torch.cat(
                        [merge_frames_all[frame_index][key], frame[key]], dim=0
                    )
        frames_all = merge_frames_all
    else:
        frames_all = None
        batch_all = None
    if not (confidence and cfg.rank_outputs_by_confidence):
        struct_plddt_rankings = None
    return (
        ref_mol,
        plddt_all,
        plddt_lig_all,
        plddt_ligs_all,
        affinity_all,
        frames_all,
        batch_all,
        b_factors,
        struct_plddt_rankings,
    )
|
||||||
153
flowdock/utils/utils.py
Normal file
153
flowdock/utils/utils.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
import warnings
|
||||||
|
from importlib.util import find_spec
|
||||||
|
|
||||||
|
import rootutils
|
||||||
|
from beartype.typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.utils import pylogger, rich_utils
|
||||||
|
|
||||||
|
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # bail out early when no `extras` config is present
    extras_cfg = cfg.get("extras")
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # optionally disable python warnings
    if extras_cfg.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # optionally prompt user for tags when none were provided in the config
    if extras_cfg.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # optionally pretty print the config tree using the Rich library
    if extras_cfg.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
||||||
|
|
||||||
|
|
||||||
|
def task_wrapper(task_func: Callable) -> Callable:
|
||||||
|
"""Optional decorator that controls the failure behavior when executing the task function.
|
||||||
|
|
||||||
|
This wrapper can be used to:
|
||||||
|
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
|
||||||
|
- save the exception to a `.log` file
|
||||||
|
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
|
||||||
|
- etc. (adjust depending on your needs)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
@utils.task_wrapper
|
||||||
|
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
...
|
||||||
|
return metric_dict, object_dict
|
||||||
|
```
|
||||||
|
|
||||||
|
:param task_func: The task function to be wrapped.
|
||||||
|
|
||||||
|
:return: The wrapped task function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
# execute the task
|
||||||
|
try:
|
||||||
|
metric_dict, object_dict = task_func(cfg=cfg)
|
||||||
|
|
||||||
|
# things to do if exception occurs
|
||||||
|
except Exception as ex:
|
||||||
|
# save exception to `.log` file
|
||||||
|
log.exception("")
|
||||||
|
|
||||||
|
# some hyperparameter combinations might be invalid or cause out-of-memory errors
|
||||||
|
# so when using hparam search plugins like Optuna, you might want to disable
|
||||||
|
# raising the below exception to avoid multirun failure
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
# things to always do after either success or exception
|
||||||
|
finally:
|
||||||
|
# display output dir path in terminal
|
||||||
|
log.info(f"Output dir: {cfg.paths.output_dir}")
|
||||||
|
|
||||||
|
# always close wandb run (even if exception occurs so multirun won't fail)
|
||||||
|
if find_spec("wandb"): # check if wandb is installed
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
if wandb.run:
|
||||||
|
log.info("Closing wandb!")
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
return metric_dict, object_dict
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric; otherwise ``None``.
    :raises KeyError: If `metric_name` is not a key of `metric_dict`.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        # NOTE: raise a lookup-specific exception instead of a bare `Exception`;
        # callers that catch `Exception` still work since `KeyError` subclasses it
        raise KeyError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    # `.item()` converts a zero-dim tensor/array metric to a plain Python float
    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
|
||||||
|
|
||||||
|
|
||||||
|
def read_strings_from_txt(path: str) -> List[str]:
    """Reads strings from a text file and returns them as a list.

    :param path: Path to the text file.
    :return: List of strings, one per line, with trailing whitespace removed.
    """
    # NOTE: iterating the file object yields one line per element
    with open(path) as txt_file:
        return [raw_line.rstrip() for raw_line in txt_file]
|
||||||
|
|
||||||
|
|
||||||
|
def fasta_to_dict(filename: str) -> Dict[str, str]:
    """Converts a FASTA file to a dictionary where the keys are the sequence IDs and the values are
    the sequences.

    :param filename: Path to the FASTA file.
    :return: Dictionary with sequence IDs as keys and sequences as values.
    :raises ValueError: If sequence data appears before the first `>` header line.
    """
    fasta_dict: Dict[str, str] = {}
    seq_id: Optional[str] = None  # ID of the record currently being accumulated
    with open(filename) as file:
        for line in file:
            line = line.rstrip()  # remove trailing whitespace
            if line.startswith(">"):  # identifier line
                seq_id = line[1:]  # remove the '>' character
                fasta_dict[seq_id] = ""
            elif line:  # sequence line (blank lines are ignored)
                if seq_id is None:
                    # NOTE: previously this raised an opaque `NameError`; fail loudly instead
                    raise ValueError(
                        f"Sequence data found before any FASTA header in {filename}"
                    )
                fasta_dict[seq_id] += line
    return fasta_dict
|
||||||
365
flowdock/utils/visualization_utils.py
Normal file
365
flowdock/utils/visualization_utils.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import rootutils
|
||||||
|
from beartype import beartype
|
||||||
|
from beartype.typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||||
|
from openfold.np.protein import Protein as OFProtein
|
||||||
|
from rdkit import Chem
|
||||||
|
from rdkit.Geometry.rdGeometry import Point3D
|
||||||
|
|
||||||
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||||
|
|
||||||
|
from flowdock.data.components import residue_constants
|
||||||
|
from flowdock.utils.data_utils import (
|
||||||
|
PDB_CHAIN_IDS,
|
||||||
|
PDB_MAX_CHAINS,
|
||||||
|
FDProtein,
|
||||||
|
create_full_prot,
|
||||||
|
get_mol_with_new_conformer_coords,
|
||||||
|
)
|
||||||
|
|
||||||
|
FeatureDict = Mapping[str, np.ndarray]
|
||||||
|
ModelOutput = Mapping[str, Any] # Is a nested dict.
|
||||||
|
PROT_LIG_PAIRS = List[Tuple[OFProtein, Tuple[Chem.Mol, ...]]]
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def _chain_end(
    atom_index: Union[int, np.int64],
    end_resname: str,
    chain_name: str,
    residue_index: Union[int, np.int64],
) -> str:
    """Returns a PDB `TER` record for the end of a chain.

    Adapted from: https://github.com/jasonkyuyim/se3_diffusion

    :param atom_index: The index of the last atom in the chain.
    :param end_resname: The residue name of the last residue in the chain.
    :param chain_name: The chain name of the last residue in the chain.
    :param residue_index: The residue index of the last residue in the chain.
    :return: A PDB `TER` record.
    """
    record_name = "TER"
    # PDB is columnar: record name, serial, residue name, chain ID, residue number
    ter_record = (
        f"{record_name:<6}{atom_index:>5} {end_resname:>3} "
        f"{chain_name:>1}{residue_index:>4}"
    )
    return ter_record
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def res_1to3(restypes: List[str], r: Union[int, np.int64]) -> str:
    """Convert a residue type from 1-letter to 3-letter code.

    :param restypes: List of residue types.
    :param r: Residue type index.
    :return: 3-letter code as a string, or `"UNK"` if the 1-letter code is unknown.
    """
    one_letter_code = restypes[r]
    return residue_constants.restype_1to3.get(one_letter_code, "UNK")
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def to_pdb(
    prot: Union[OFProtein, FDProtein],
    model: int = 1,
    add_end: bool = True,
    add_endmdl: bool = True,
) -> str:
    """Converts a `Protein` instance to a PDB string.

    Adapted from: https://github.com/jasonkyuyim/se3_diffusion

    :param prot: The protein to convert to PDB.
    :param model: The model number to use.
    :param add_end: Whether to add an `END` record.
    :param add_endmdl: Whether to add an `ENDMDL` record.
    :return: PDB string.
    """
    restypes = residue_constants.restypes + ["X"]  # NOTE: "X" denotes an unknown residue type
    atom_types = residue_constants.atom_types

    pdb_lines = []

    # unpack the per-residue arrays held by the protein object
    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(int)
    chain_index = prot.chain_index.astype(int)
    b_factors = prot.b_factors

    # indices beyond `restype_num` would fall outside the `restypes` lookup table
    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # construct a mapping from chain integer indices to chain ID strings
    chain_ids = {}
    for i in np.unique(chain_index):  # NOTE: `np.unique` gives sorted output
        if i >= PDB_MAX_CHAINS:
            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append(f"MODEL {model}")
    atom_index = 1  # PDB atom serial numbers start at 1
    last_chain_index = chain_index[0]
    # add all atom sites
    for i in range(aatype.shape[0]):
        # close the previous chain if in a multichain PDB
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(restypes, aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # NOTE: atom index increases at the `TER` symbol

        res_name_3 = res_1to3(restypes, aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            # skip atoms that are masked out (absent) for this residue
            if mask < 0.5:
                continue

            record_type = "ATOM"
            # four-character atom names start one column earlier than shorter ones
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # NOTE: `Protein` supports only C, N, O, S, this works
            charge = ""
            # NOTE: PDB is a columnar format, every space matters here!
            # NOTE(review): verify the literal space runs in this template against the
            # PDB column spec - whitespace may have been collapsed upstream
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1} "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f} "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # close the final chain
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(restypes, aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        )
    )
    if add_endmdl:
        pdb_lines.append("ENDMDL")
    if add_end:
        pdb_lines.append("END")

    # pad all lines to 80 characters
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return "\n".join(pdb_lines) + "\n"  # add terminating newline
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def construct_prot_lig_pairs(outputs: Dict[str, Any], batch_index: int) -> PROT_LIG_PAIRS:
    """Construct protein-ligand pairs from model outputs.

    Builds one (protein, ligand fragments) pair per entry of
    `outputs["protein_coordinates_list"]` for the batch element selected by
    `batch_index`. If ground-truth coordinates are present in `outputs`, one
    additional pair holding the ground-truth structure is appended last.

    :param outputs: The model outputs.
    :param batch_index: The index of the current batch.
    :return: A list of protein-ligand object pairs.
    """
    # per-row indexers mapping each protein residue / ligand atom to its batch element
    protein_batch_indexer = outputs["protein_batch_indexer"]
    ligand_batch_indexer = outputs["ligand_batch_indexer"]

    # select this batch element's residues and broadcast the per-atom mask to
    # coordinate shape (num_residues, 37, 3)
    # NOTE(review): 37 presumably corresponds to the atom37 representation - confirm
    protein_all_atom_mask = outputs["res_atom_mask"][protein_batch_indexer == batch_index]
    protein_all_atom_coordinates_mask = np.broadcast_to(
        np.expand_dims(protein_all_atom_mask, -1), (protein_all_atom_mask.shape[0], 37, 3)
    )
    protein_aatype = outputs["aatype"][protein_batch_indexer == batch_index]

    # assemble predicted structures
    prot_lig_pairs = []
    for protein_coordinates, ligand_coordinates in zip(
        outputs["protein_coordinates_list"], outputs["ligand_coordinates_list"]
    ):
        # zero out the coordinates of atoms that are masked for this residue type
        protein_all_atom_coordinates = (
            protein_coordinates[protein_batch_indexer == batch_index]
            * protein_all_atom_coordinates_mask
        )
        protein = create_full_prot(
            protein_all_atom_coordinates,
            protein_all_atom_mask,
            protein_aatype,
            b_factors=outputs["b_factors"][batch_index] if "b_factors" in outputs else None,
        )
        ligand = get_mol_with_new_conformer_coords(
            outputs["ligand_mol"][batch_index],
            ligand_coordinates[ligand_batch_indexer == batch_index],
        )
        # split a possibly multi-fragment ligand into separate molecules
        ligands = tuple(Chem.GetMolFrags(ligand, asMols=True, sanitizeFrags=False))
        prot_lig_pairs.append((protein, ligands))

    # assemble ground-truth structures
    if "gt_protein_coordinates" in outputs and "gt_ligand_coordinates" in outputs:
        protein_gt_all_atom_coordinates = (
            outputs["gt_protein_coordinates"][protein_batch_indexer == batch_index]
            * protein_all_atom_coordinates_mask
        )
        gt_protein = create_full_prot(
            protein_gt_all_atom_coordinates,
            protein_all_atom_mask,
            protein_aatype,
        )
        gt_ligand = get_mol_with_new_conformer_coords(
            outputs["ligand_mol"][batch_index],
            outputs["gt_ligand_coordinates"][ligand_batch_indexer == batch_index],
        )
        gt_ligands = tuple(Chem.GetMolFrags(gt_ligand, asMols=True, sanitizeFrags=False))
        prot_lig_pairs.append((gt_protein, gt_ligands))

    return prot_lig_pairs
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
def write_prot_lig_pairs_to_pdb_file(prot_lig_pairs: PROT_LIG_PAIRS, output_filepath: str):
    """Write a list of protein-ligand pairs to a PDB file.

    Each pair is written as one PDB `MODEL`, with the protein first and each
    ligand fragment following as a `TER`-separated chain.

    :param prot_lig_pairs: List of protein-ligand object pairs, where each ligand may consist of
        multiple ligand chains.
    :param output_filepath: Output file path.
    """
    output_dir = os.path.dirname(output_filepath)
    # NOTE: `dirname` is empty for a bare filename; `os.makedirs("")` would raise
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_filepath, "w") as f:
        for model_id, (prot, lig_mols) in enumerate(prot_lig_pairs, start=1):
            f.write(to_pdb(prot, model=model_id, add_end=False, add_endmdl=False))
            for lig_mol in lig_mols:
                # replace the block-terminating `END` with `TER` to
                # enable proper ligand chain separation
                f.write(Chem.MolToPDBBlock(lig_mol).replace("END\n", "TER\n"))
            f.write("END\n")
            f.write("ENDMDL\n")  # add `ENDMDL` line to separate models
|
||||||
|
|
||||||
|
|
||||||
|
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = False,
) -> FDProtein:
    """Assembles a protein from a prediction.

    :param features: Dictionary holding model inputs.
    :param result: Dictionary holding model outputs.
    :param b_factors: (Optional) B-factors to use for the protein.
    :param remove_leading_feature_dimension: Whether to remove the leading dimension
        of the `features` values.
    :return: A protein instance.
    """
    fold_output = result["structure_module"]

    def _squeeze(arr: np.ndarray) -> np.ndarray:
        # strip the leading (e.g., batch) dimension only when requested
        return arr[0] if remove_leading_feature_dimension else arr

    # default to a single chain when no asymmetric unit IDs are provided
    chain_index = (
        _squeeze(features["asym_id"])
        if "asym_id" in features
        else np.zeros_like(_squeeze(features["aatype"]))
    )

    # default to all-zero B-factors matching the atom mask's shape
    if b_factors is None:
        b_factors = np.zeros_like(fold_output["final_atom_mask"])

    return FDProtein(
        letter_sequences=None,
        aatype=_squeeze(features["aatype"]),
        atom_positions=fold_output["final_atom_positions"],
        atom_mask=fold_output["final_atom_mask"],
        residue_index=_squeeze(features["residue_index"]),
        chain_index=chain_index,
        b_factors=b_factors,
        atomtypes=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def write_pdb_single(
    result: ModelOutput,
    out_path: str = os.path.join("test_results", "debug.pdb"),
    model: int = 1,
    b_factors: Optional[np.ndarray] = None,
):
    """Write a single model to a PDB file.

    :param result: Model results batch; must contain a `"features"` key.
    :param out_path: Output path.
    :param model: Model ID.
    :param b_factors: Optional B-factors.
    """
    out_dir = os.path.dirname(out_path)
    # NOTE: `dirname` is empty for a bare filename; `os.makedirs("")` would raise
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    protein = from_prediction(result["features"], result, b_factors=b_factors)
    out_string = to_pdb(protein, model=model)
    with open(out_path, "w") as of:
        of.write(out_string)
|
||||||
|
|
||||||
|
|
||||||
|
def write_pdb_models(
    results,
    out_path: str = os.path.join("test_results", "debug.pdb"),
    b_factors: Optional[np.ndarray] = None,
):
    """Write multiple models to a single multi-model PDB file.

    :param results: Iterable of model result batches (each must contain a `"features"` key).
    :param out_path: Output path.
    :param b_factors: Optional B-factors, indexed per model.
    """
    out_dir = os.path.dirname(out_path)
    # NOTE: `dirname` is empty for a bare filename; `os.makedirs("")` would raise
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as of:
        for mid, result in enumerate(results):
            protein = from_prediction(
                result["features"],
                result,
                b_factors=b_factors[mid] if b_factors is not None else None,
            )
            # suppress the per-model `END` record; a single `END` belongs at file end only
            out_string = to_pdb(protein, model=mid + 1, add_end=False)
            of.write(out_string)
        of.write("END".ljust(80) + "\n")  # terminate the file with one padded `END` record
|
||||||
|
|
||||||
|
|
||||||
|
def write_conformer_sdf(
    mol: Chem.Mol,
    confs: Optional[np.ndarray] = None,
    out_path: str = os.path.join("test_results", "debug.sdf"),
):
    """Write a molecule with conformers to an SDF file.

    NOTE: when `confs` is provided, `mol` is modified in place - its existing
    conformers are replaced with those built from `confs`.

    :param mol: RDKit molecule.
    :param confs: Optional conformer coordinates of shape `(num_confs, num_atoms, 3)`.
    :param out_path: Output path.
    :return: 0 on success (kept for backwards compatibility with existing callers).
    """
    out_dir = os.path.dirname(out_path)
    # NOTE: `dirname` is empty for a bare filename; `os.makedirs("")` would raise
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if confs is None:
        w = Chem.SDWriter(out_path)
        w.write(mol)
        w.close()
        return 0

    # install the provided conformers on the molecule
    mol.RemoveAllConformers()
    num_atoms = mol.GetNumAtoms()
    for i in range(len(confs)):
        conf = Chem.Conformer(num_atoms)
        for j in range(num_atoms):
            x, y, z = confs[i, j].tolist()
            conf.SetAtomPosition(j, Point3D(x, y, z))
        mol.AddConformer(conf, assignId=True)

    w = Chem.SDWriter(out_path)
    try:
        for cid in range(len(confs)):
            w.write(mol, confId=cid)
    except Exception:
        # some molecules cannot be kekulized; restart with a *fresh* writer so
        # conformers written before the failure are not duplicated in the output
        w.close()
        w = Chem.SDWriter(out_path)
        w.SetKekulize(False)
        for cid in range(len(confs)):
            w.write(mol, confId=cid)
    finally:
        w.close()
    return 0
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user