Initial commit: FlowDock pipeline configured for WES execution
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
This commit is contained in:
6
.env.example
Normal file
6
.env.example
Normal file
@@ -0,0 +1,6 @@
|
||||
# example of file for storing private and user specific environment variables, like keys or system paths
|
||||
# rename it to ".env" (excluded from version control by default)
|
||||
# .env is loaded by train.py automatically
|
||||
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
|
||||
|
||||
PLINDER_MOUNT="$(pwd)/data/PLINDER"
|
||||
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!--
|
||||
Please include a summary of the change and which issue is fixed.
|
||||
Please also include relevant motivation and context.
|
||||
List any dependencies that are required for this change.
|
||||
List all the breaking changes introduced by this pull request.
|
||||
-->
|
||||
|
||||
Fixes #\<issue_number>
|
||||
|
||||
## Before submitting
|
||||
|
||||
- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**?
|
||||
- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together?
|
||||
- [ ] Did you list all the **breaking changes** introduced by this pull request?
|
||||
- [ ] Did you **test your PR locally** with `pytest` command?
|
||||
- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command?
|
||||
|
||||
## Did you have fun?
|
||||
|
||||
Make sure you had fun coding 🙃
|
||||
15
.github/codecov.yml
vendored
Normal file
15
.github/codecov.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
coverage:
|
||||
status:
|
||||
# measures overall project coverage
|
||||
project:
|
||||
default:
|
||||
threshold: 100% # how much decrease in coverage is needed to not consider success
|
||||
|
||||
# measures PR or single commit coverage
|
||||
patch:
|
||||
default:
|
||||
threshold: 100% # how much decrease in coverage is needed to not consider success
|
||||
|
||||
|
||||
# project: off
|
||||
# patch: off
|
||||
16
.github/dependabot.yml
vendored
Normal file
16
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "daily"
|
||||
ignore:
|
||||
- dependency-name: "pytorch-lightning"
|
||||
update-types: ["version-update:semver-patch"]
|
||||
- dependency-name: "torchmetrics"
|
||||
update-types: ["version-update:semver-patch"]
|
||||
44
.github/release-drafter.yml
vendored
Normal file
44
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name-template: "v$RESOLVED_VERSION"
|
||||
tag-template: "v$RESOLVED_VERSION"
|
||||
|
||||
categories:
|
||||
- title: "🚀 Features"
|
||||
labels:
|
||||
- "feature"
|
||||
- "enhancement"
|
||||
- title: "🐛 Bug Fixes"
|
||||
labels:
|
||||
- "fix"
|
||||
- "bugfix"
|
||||
- "bug"
|
||||
- title: "🧹 Maintenance"
|
||||
labels:
|
||||
- "maintenance"
|
||||
- "dependencies"
|
||||
- "refactoring"
|
||||
- "cosmetic"
|
||||
- "chore"
|
||||
- title: "📝️ Documentation"
|
||||
labels:
|
||||
- "documentation"
|
||||
- "docs"
|
||||
|
||||
change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
|
||||
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
|
||||
|
||||
version-resolver:
|
||||
major:
|
||||
labels:
|
||||
- "major"
|
||||
minor:
|
||||
labels:
|
||||
- "minor"
|
||||
patch:
|
||||
labels:
|
||||
- "patch"
|
||||
default: patch
|
||||
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
||||
22
.github/workflows/code-quality-main.yaml
vendored
Normal file
22
.github/workflows/code-quality-main.yaml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
# Same as `code-quality-pr.yaml` but triggered on commit to main branch
|
||||
# and runs on all files (instead of only the changed ones)
|
||||
|
||||
name: Code Quality Main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
code-quality:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
|
||||
- name: Run pre-commits
|
||||
uses: pre-commit/action@v2.0.3
|
||||
36
.github/workflows/code-quality-pr.yaml
vendored
Normal file
36
.github/workflows/code-quality-pr.yaml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
# This workflow finds which files were changed, prints them,
|
||||
# and runs `pre-commit` on those files.
|
||||
|
||||
# Inspired by the sktime library:
|
||||
# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml
|
||||
|
||||
name: Code Quality PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main, "release/*", "dev"]
|
||||
|
||||
jobs:
|
||||
code-quality:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
|
||||
- name: Find modified files
|
||||
id: file_changes
|
||||
uses: trilom/file-changes-action@v1.2.4
|
||||
with:
|
||||
output: " "
|
||||
|
||||
- name: List modified files
|
||||
run: echo '${{ steps.file_changes.outputs.files}}'
|
||||
|
||||
- name: Run pre-commits
|
||||
uses: pre-commit/action@v2.0.3
|
||||
with:
|
||||
extra_args: --files ${{ steps.file_changes.outputs.files}}
|
||||
27
.github/workflows/release-drafter.yml
vendored
Normal file
27
.github/workflows/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: Release Drafter
|
||||
|
||||
on:
|
||||
push:
|
||||
# branches to consider in the event; optional, defaults to all
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
permissions:
|
||||
# write permission is required to create a github release
|
||||
contents: write
|
||||
# write permission is required for autolabeler
|
||||
# otherwise, read permission is required at least
|
||||
pull-requests: write
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||
- uses: release-drafter/release-drafter@v5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
139
.github/workflows/test.yml
vendored
Normal file
139
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main, "release/*", "dev"]
|
||||
|
||||
jobs:
|
||||
run_tests_ubuntu:
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["ubuntu-latest"]
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
conda env create -f environment.yml
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install sh
|
||||
|
||||
- name: List dependencies
|
||||
run: |
|
||||
python -m pip list
|
||||
|
||||
- name: Run pytest
|
||||
run: |
|
||||
pytest -v
|
||||
|
||||
run_tests_macos:
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["macos-latest"]
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
conda env create -f environment.yml
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install sh
|
||||
|
||||
- name: List dependencies
|
||||
run: |
|
||||
python -m pip list
|
||||
|
||||
- name: Run pytest
|
||||
run: |
|
||||
pytest -v
|
||||
|
||||
run_tests_windows:
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest"]
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
conda env create -f environment.yml
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
|
||||
- name: List dependencies
|
||||
run: |
|
||||
python -m pip list
|
||||
|
||||
- name: Run pytest
|
||||
run: |
|
||||
pytest -v
|
||||
|
||||
# upload code coverage report
|
||||
code-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
conda env create -f environment.yml
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install pytest-cov[toml]
|
||||
pip install sh
|
||||
|
||||
- name: Run tests and collect coverage
|
||||
run: pytest --cov flowdock # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
work/
|
||||
.nextflow/
|
||||
.nextflow.log*
|
||||
*.log.*
|
||||
results/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.vscode/
|
||||
.idea/
|
||||
*.tmp
|
||||
*.swp
|
||||
150
.pre-commit-config.yaml
Normal file
150
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,150 @@
|
||||
default_language_version:
|
||||
python: python3
|
||||
|
||||
exclude: "^forks/"
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
# list of supported hooks: https://pre-commit.com/hooks.html
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-docstring-first
|
||||
- id: check-yaml
|
||||
- id: debug-statements
|
||||
- id: detect-private-key
|
||||
- id: check-executables-have-shebangs
|
||||
- id: check-toml
|
||||
- id: check-case-conflict
|
||||
- id: check-added-large-files
|
||||
args: ["--maxkb=20000"]
|
||||
|
||||
# python code formatting
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [--line-length, "99"]
|
||||
|
||||
# python import sorting
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
args: ["--profile", "black", "--filter-files"]
|
||||
|
||||
# python upgrading syntax to newer version
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py38-plus]
|
||||
|
||||
# python docstring formatting
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: eb1df347edd128b30cd3368dddc3aa65edcfac38 # Don't autoupdate until https://github.com/PyCQA/docformatter/issues/293 is fixed
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args:
|
||||
[
|
||||
--in-place,
|
||||
--wrap-summaries=99,
|
||||
--wrap-descriptions=99,
|
||||
--style=sphinx,
|
||||
--black,
|
||||
]
|
||||
|
||||
# python docstring coverage checking
|
||||
- repo: https://github.com/econchick/interrogate
|
||||
rev: 1.7.0 # or master if you're bold
|
||||
hooks:
|
||||
- id: interrogate
|
||||
args:
|
||||
[
|
||||
--verbose,
|
||||
--fail-under=80,
|
||||
--ignore-init-module,
|
||||
--ignore-init-method,
|
||||
--ignore-module,
|
||||
--ignore-nested-functions,
|
||||
-vv,
|
||||
]
|
||||
|
||||
# python check (PEP8), programming errors and code complexity
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.1.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
[
|
||||
"--extend-ignore",
|
||||
"E203,E402,E501,F401,F841,RST2,RST301",
|
||||
"--exclude",
|
||||
"logs/*,data/*",
|
||||
]
|
||||
additional_dependencies: [flake8-rst-docstrings==0.3.0]
|
||||
|
||||
# python security linter
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: "1.8.3"
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: ["-s", "B101"]
|
||||
|
||||
# yaml formatting
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v4.0.0-alpha.8
|
||||
hooks:
|
||||
- id: prettier
|
||||
types: [yaml]
|
||||
exclude: "environment.yaml"
|
||||
|
||||
# shell scripts linter
|
||||
- repo: https://github.com/shellcheck-py/shellcheck-py
|
||||
rev: v0.10.0.1
|
||||
hooks:
|
||||
- id: shellcheck
|
||||
|
||||
# md formatting
|
||||
- repo: https://github.com/executablebooks/mdformat
|
||||
rev: 0.7.22
|
||||
hooks:
|
||||
- id: mdformat
|
||||
args: ["--number"]
|
||||
additional_dependencies:
|
||||
- mdformat-gfm
|
||||
- mdformat-tables
|
||||
- mdformat_frontmatter
|
||||
# - mdformat-toc
|
||||
# - mdformat-black
|
||||
|
||||
# word spelling linter
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
args:
|
||||
- --skip=logs/**,data/**,*.ipynb,flowdock/data/components/constants.py,flowdock/data/components/process_mols.py,flowdock/data/components/residue_constants.py,flowdock/data/components/uff_parameters.csv,flowdock/data/components/chemical/*,flowdock/utils/data_utils.py
|
||||
# - --ignore-words-list=abc,def
|
||||
|
||||
# jupyter notebook cell output clearing
|
||||
- repo: https://github.com/kynan/nbstripout
|
||||
rev: 0.8.1
|
||||
hooks:
|
||||
- id: nbstripout
|
||||
|
||||
# jupyter notebook linting
|
||||
- repo: https://github.com/nbQA-dev/nbQA
|
||||
rev: 1.9.1
|
||||
hooks:
|
||||
- id: nbqa-black
|
||||
args: ["--line-length=99"]
|
||||
- id: nbqa-isort
|
||||
args: ["--profile=black"]
|
||||
- id: nbqa-flake8
|
||||
args:
|
||||
[
|
||||
"--extend-ignore=E203,E402,E501,F401,F841",
|
||||
"--exclude=logs/*,data/*",
|
||||
]
|
||||
2
.project-root
Normal file
2
.project-root
Normal file
@@ -0,0 +1,2 @@
|
||||
# this file is required for inferring the project root directory
|
||||
# do not delete
|
||||
49
Dockerfile
Normal file
49
Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
||||
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-runtime
|
||||
|
||||
LABEL authors="BioinfoMachineLearning"
|
||||
|
||||
# Install system requirements
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --reinstall ca-certificates && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
git \
|
||||
wget \
|
||||
libxml2 \
|
||||
libgl-dev \
|
||||
libgl1 \
|
||||
gcc \
|
||||
g++ \
|
||||
procps && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set working directory
|
||||
RUN mkdir -p /software/flowdock
|
||||
WORKDIR /software/flowdock
|
||||
|
||||
# Clone FlowDock repository
|
||||
RUN git clone https://github.com/BioinfoMachineLearning/FlowDock /software/flowdock
|
||||
|
||||
# Create conda environment
|
||||
RUN conda env create -f /software/flowdock/environments/flowdock_environment.yaml
|
||||
|
||||
# Install local package and ProDy
|
||||
RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \
|
||||
conda activate FlowDock && \
|
||||
pip install --no-cache-dir -e /software/flowdock && \
|
||||
pip install --no-cache-dir --no-dependencies prody==2.4.1"
|
||||
|
||||
# Create checkpoints directory
|
||||
RUN mkdir -p /software/flowdock/checkpoints
|
||||
|
||||
# Download pretrained weights
|
||||
RUN wget -q https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz && \
|
||||
tar -xzf flowdock_checkpoints.tar.gz && \
|
||||
rm flowdock_checkpoints.tar.gz
|
||||
|
||||
# Activate conda environment by default
|
||||
RUN echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
||||
echo "conda activate FlowDock" >> ~/.bashrc
|
||||
|
||||
# Default shell
|
||||
SHELL ["/bin/bash", "-l", "-c"]
|
||||
CMD ["/bin/bash"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 BioinfoMachineLearning
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
30
Makefile
Normal file
30
Makefile
Normal file
@@ -0,0 +1,30 @@
|
||||
|
||||
help: ## Show help
|
||||
@grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
clean: ## Clean autogenerated files
|
||||
rm -rf dist
|
||||
find . -type f -name "*.DS_Store" -ls -delete
|
||||
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
|
||||
find . | grep -E ".pytest_cache" | xargs rm -rf
|
||||
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
|
||||
rm -f .coverage
|
||||
|
||||
clean-logs: ## Clean logs
|
||||
rm -rf logs/**
|
||||
|
||||
format: ## Run pre-commit hooks
|
||||
pre-commit run -a
|
||||
|
||||
sync: ## Merge changes from main branch to your current branch
|
||||
git pull
|
||||
git pull origin main
|
||||
|
||||
test: ## Run not slow tests
|
||||
pytest -k "not slow"
|
||||
|
||||
test-full: ## Run all tests
|
||||
pytest
|
||||
|
||||
train: ## Train the model
|
||||
python flowdock/train.py
|
||||
471
README.md
Normal file
471
README.md
Normal file
@@ -0,0 +1,471 @@
|
||||
<div align="center">
|
||||
|
||||
# FlowDock
|
||||
|
||||
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
|
||||
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
|
||||
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
|
||||
|
||||
<!-- <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br> -->
|
||||
|
||||
[](https://arxiv.org/abs/2412.10966)
|
||||
[](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)
|
||||
[](https://doi.org/10.5281/zenodo.15066450)
|
||||
|
||||
<img src="./img/FlowDock.png" width="600">
|
||||
|
||||
</div>
|
||||
|
||||
## Description
|
||||
|
||||
This is the official codebase of the paper
|
||||
|
||||
**FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction**
|
||||
|
||||
\[[arXiv](https://arxiv.org/abs/2412.10966)\] \[[ISMB](https://academic.oup.com/bioinformatics/article/41/Supplement_1/i198/8199366)\] \[[Neurosnap](https://neurosnap.ai/service/FlowDock)\] \[[Tamarind Bio](https://app.tamarind.bio/tools/flowdock)\]
|
||||
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
</div>
|
||||
|
||||
## Contents
|
||||
|
||||
- [FlowDock](#flowdock)
|
||||
- [Description](#description)
|
||||
- [Contents](#contents)
|
||||
- [Installation](#installation)
|
||||
- [How to prepare data for `FlowDock`](#how-to-prepare-data-for-flowdock)
|
||||
- [Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)](#generating-esm2-embeddings-for-each-protein-optional-cached-input-data-available-on-sharepoint)
|
||||
- [Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)](#predicting-apo-protein-structures-using-esmfold-optional-cached-data-available-on-zenodo)
|
||||
- [How to train `FlowDock`](#how-to-train-flowdock)
|
||||
- [How to evaluate `FlowDock`](#how-to-evaluate-flowdock)
|
||||
- [How to create comparative plots of benchmarking results](#how-to-create-comparative-plots-of-benchmarking-results)
|
||||
- [How to predict new protein-ligand complex structures and their affinities using `FlowDock`](#how-to-predict-new-protein-ligand-complex-structures-and-their-affinities-using-flowdock)
|
||||
- [For developers](#for-developers)
|
||||
- [Docker](#docker)
|
||||
- [Acknowledgements](#acknowledgements)
|
||||
- [License](#license)
|
||||
- [Citing this work](#citing-this-work)
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
|
||||
Install Mamba
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh # accept all terms and install to the default location
|
||||
rm Miniforge3-$(uname)-$(uname -m).sh # (optionally) remove installer after using it
|
||||
source ~/.bashrc # alternatively, one can restart their shell session to achieve the same result
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
|
||||
```bash
|
||||
# clone project
|
||||
git clone https://github.com/BioinfoMachineLearning/FlowDock
|
||||
cd FlowDock
|
||||
|
||||
# create conda environment
|
||||
mamba env create -f environments/flowdock_environment.yaml
|
||||
conda activate FlowDock # NOTE: one still needs to use `conda` to (de)activate environments
|
||||
pip3 install -e . # install local project as package
|
||||
pip3 install prody==2.4.1 --no-dependencies # install ProDy without NumPy dependency
|
||||
```
|
||||
|
||||
Download checkpoints
|
||||
|
||||
```bash
|
||||
# pretrained NeuralPLexer weights
|
||||
cd checkpoints/
|
||||
wget https://zenodo.org/records/10373581/files/neuralplexermodels_downstream_datasets_predictions.zip
|
||||
unzip neuralplexermodels_downstream_datasets_predictions.zip
|
||||
rm neuralplexermodels_downstream_datasets_predictions.zip
|
||||
cd ../
|
||||
```
|
||||
|
||||
```bash
|
||||
# pretrained FlowDock weights
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_checkpoints.tar.gz
|
||||
tar -xzf flowdock_checkpoints.tar.gz
|
||||
rm flowdock_checkpoints.tar.gz
|
||||
```
|
||||
|
||||
Download preprocessed datasets
|
||||
|
||||
```bash
|
||||
# cached input data for training/validation/testing
|
||||
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ER1hctIBhDVFjM7YepOI6WcBXNBm4_e6EBjFEHAM1A3y5g?download=1"
|
||||
tar -xzf flowdock_data_cache.tar.gz
|
||||
rm flowdock_data_cache.tar.gz
|
||||
|
||||
# cached data for PDBBind, Binding MOAD, DockGen, and the PDB-based van der Mers (vdM) dataset
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_pdbbind_data.tar.gz
|
||||
tar -xzf flowdock_pdbbind_data.tar.gz
|
||||
rm flowdock_pdbbind_data.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_moad_data.tar.gz
|
||||
tar -xzf flowdock_moad_data.tar.gz
|
||||
rm flowdock_moad_data.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_dockgen_data.tar.gz
|
||||
tar -xzf flowdock_dockgen_data.tar.gz
|
||||
rm flowdock_dockgen_data.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_pdbsidechain_data.tar.gz
|
||||
tar -xzf flowdock_pdbsidechain_data.tar.gz
|
||||
rm flowdock_pdbsidechain_data.tar.gz
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## How to prepare data for `FlowDock`
|
||||
|
||||
<details>
|
||||
|
||||
**NOTE:** The following steps (besides downloading PDBBind and Binding MOAD's PDB files) are only necessary if one wants to fully process each of the following datasets manually.
|
||||
Otherwise, preprocessed versions of each dataset can be found on [Zenodo](https://zenodo.org/records/15066450).
|
||||
|
||||
Download data
|
||||
|
||||
```bash
|
||||
# fetch preprocessed PDBBind and Binding MOAD (as well as the optional DockGen and vdM datasets)
|
||||
cd data/
|
||||
|
||||
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1"
|
||||
wget https://zenodo.org/records/10656052/files/BindingMOAD_2020_processed.tar
|
||||
wget https://zenodo.org/records/10656052/files/DockGen.tar
|
||||
wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz
|
||||
|
||||
mv EXesf4oh6ztOusGqFcDyqP0Bvk-LdJ1DagEl8GNK-HxDtg?download=1 PDBBind.tar.gz
|
||||
|
||||
tar -xzf PDBBind.tar.gz
|
||||
tar -xf BindingMOAD_2020_processed.tar
|
||||
tar -xf DockGen.tar
|
||||
tar -xzf pdb_2021aug02.tar.gz
|
||||
|
||||
rm PDBBind.tar.gz BindingMOAD_2020_processed.tar DockGen.tar pdb_2021aug02.tar.gz
|
||||
|
||||
mkdir pdbbind/ moad/ pdbsidechain/
|
||||
mv PDBBind_processed/ pdbbind/
|
||||
mv BindingMOAD_2020_processed/ moad/
|
||||
mv pdb_2021aug02/ pdbsidechain/
|
||||
|
||||
cd ../
|
||||
```
|
||||
|
||||
Lastly, to finetune `FlowDock` using the `PLINDER` dataset, one must first prepare this data for training
|
||||
|
||||
```bash
|
||||
# fetch PLINDER data (NOTE: requires ~1 hour to download and ~750G of storage)
|
||||
export PLINDER_MOUNT="$(pwd)/data/PLINDER"
|
||||
mkdir -p "$PLINDER_MOUNT" # create the directory if it doesn't exist
|
||||
|
||||
plinder_download -y
|
||||
```
|
||||
|
||||
### Generating ESM2 embeddings for each protein (optional, cached input data available on SharePoint)
|
||||
|
||||
To generate the ESM2 embeddings for the protein inputs,
|
||||
first create all the corresponding FASTA files for each protein sequence
|
||||
|
||||
```bash
|
||||
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbbind --data_dir data/pdbbind/PDBBind_processed/ --out_file data/pdbbind/pdbbind_sequences.fasta
|
||||
python flowdock/data/components/esm_embedding_preparation.py --dataset moad --data_dir data/moad/BindingMOAD_2020_processed/pdb_protein/ --out_file data/moad/moad_sequences.fasta
|
||||
python flowdock/data/components/esm_embedding_preparation.py --dataset dockgen --data_dir data/DockGen/processed_files/ --out_file data/DockGen/dockgen_sequences.fasta
|
||||
python flowdock/data/components/esm_embedding_preparation.py --dataset pdbsidechain --data_dir data/pdbsidechain/pdb_2021aug02/pdb/ --out_file data/pdbsidechain/pdbsidechain_sequences.fasta
|
||||
```
|
||||
|
||||
Then, generate all ESM2 embeddings in batch using the ESM repository's helper script
|
||||
|
||||
```bash
|
||||
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbbind/pdbbind_sequences.fasta data/pdbbind/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/moad/moad_sequences.fasta data/moad/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/DockGen/dockgen_sequences.fasta data/DockGen/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||
python flowdock/data/components/esm_embedding_extraction.py esm2_t33_650M_UR50D data/pdbsidechain/pdbsidechain_sequences.fasta data/pdbsidechain/embeddings_output --repr_layers 33 --include per_tok --truncation_seq_length 4096 --cuda_device_index 0
|
||||
```
|
||||
|
||||
### Predicting apo protein structures using ESMFold (optional, cached data available on Zenodo)
|
||||
|
||||
To generate the apo version of each protein structure,
|
||||
first create ESMFold-ready versions of the combined FASTA files
|
||||
prepared above by the script `esm_embedding_preparation.py`
|
||||
for the PDBBind, Binding MOAD, DockGen, and PDBSidechain datasets, respectively
|
||||
|
||||
```bash
|
||||
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbbind
|
||||
python flowdock/data/components/esmfold_sequence_preparation.py dataset=moad
|
||||
python flowdock/data/components/esmfold_sequence_preparation.py dataset=dockgen
|
||||
python flowdock/data/components/esmfold_sequence_preparation.py dataset=pdbsidechain
|
||||
```
|
||||
|
||||
Then, predict each apo protein structure using ESMFold's batch
|
||||
inference script
|
||||
|
||||
```bash
|
||||
# Note: Having a CUDA-enabled device available when running this script is highly recommended
|
||||
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbbind/pdbbind_esmfold_sequences.fasta -o data/pdbbind/pdbbind_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/moad/moad_esmfold_sequences.fasta -o data/moad/moad_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/DockGen/dockgen_esmfold_sequences.fasta -o data/DockGen/dockgen_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||
python flowdock/data/components/esmfold_batch_structure_prediction.py -i data/pdbsidechain/pdbsidechain_esmfold_sequences.fasta -o data/pdbsidechain/pdbsidechain_esmfold_structures --cuda-device-index 0 --skip-existing
|
||||
```
|
||||
|
||||
Align each apo protein structure to its corresponding
|
||||
holo protein structure counterpart in PDBBind, Binding MOAD, DockGen, and PDBSidechain,
|
||||
taking ligand conformations into account during each alignment
|
||||
|
||||
```bash
|
||||
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbbind num_workers=1
|
||||
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=moad num_workers=1
|
||||
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=dockgen num_workers=1
|
||||
python flowdock/data/components/esmfold_apo_to_holo_alignment.py dataset=pdbsidechain num_workers=1
|
||||
```
|
||||
|
||||
Lastly, assess the apo-to-holo alignments in terms of statistics and structural metrics
|
||||
to enable runtime-dynamic dataset filtering using such information
|
||||
|
||||
```bash
|
||||
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbbind usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=moad usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=dockgen usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||
python flowdock/data/components/esmfold_apo_to_holo_assessment.py dataset=pdbsidechain usalign_exec_path=$MY_USALIGN_EXEC_PATH
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## How to train `FlowDock`
|
||||
|
||||
<details>
|
||||
|
||||
Train model with default configuration
|
||||
|
||||
```bash
|
||||
# train on CPU
|
||||
python flowdock/train.py trainer=cpu
|
||||
|
||||
# train on GPU
|
||||
python flowdock/train.py trainer=gpu
|
||||
```
|
||||
|
||||
Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
|
||||
|
||||
```bash
|
||||
python flowdock/train.py experiment=experiment_name.yaml
|
||||
```
|
||||
|
||||
For example, reproduce `FlowDock`'s default model training run
|
||||
|
||||
```bash
|
||||
python flowdock/train.py experiment=flowdock_fm
|
||||
```
|
||||
|
||||
**Note:** You can override any parameter from the command line like this
|
||||
|
||||
```bash
|
||||
python flowdock/train.py experiment=flowdock_fm trainer.max_epochs=20 data.batch_size=8
|
||||
```
|
||||
|
||||
For example, override parameters to finetune `FlowDock`'s pretrained weights using a new dataset such as [PLINDER](https://www.plinder.sh/)
|
||||
|
||||
```bash
|
||||
python flowdock/train.py experiment=flowdock_fm data=plinder ckpt_path=checkpoints/esmfold_prior_paper_weights.ckpt
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## How to evaluate `FlowDock`
|
||||
|
||||
<details>
|
||||
|
||||
To reproduce `FlowDock`'s evaluation results for structure prediction, please refer to its documentation in version `0.6.0-FlowDock` of the [PoseBench](https://github.com/BioinfoMachineLearning/PoseBench/tree/0.6.0-FlowDock?tab=readme-ov-file#how-to-run-inference-with-flowdock) GitHub repository.
|
||||
|
||||
To reproduce `FlowDock`'s evaluation results for binding affinity prediction using the PDBBind dataset
|
||||
|
||||
```bash
|
||||
python flowdock/eval.py data.test_datasets=[pdbbind] ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt trainer=gpu
|
||||
... # re-run two more times to gather triplicate results
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## How to create comparative plots of benchmarking results
|
||||
|
||||
<details>
|
||||
|
||||
Download baseline method predictions and results
|
||||
|
||||
```bash
|
||||
# cached predictions and evaluation metrics for reproducing structure prediction paper results
|
||||
wget https://zenodo.org/records/15066450/files/alphafold3_baseline_method_predictions.tar.gz
|
||||
tar -xzf alphafold3_baseline_method_predictions.tar.gz
|
||||
rm alphafold3_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/chai_baseline_method_predictions.tar.gz
|
||||
tar -xzf chai_baseline_method_predictions.tar.gz
|
||||
rm chai_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/diffdock_baseline_method_predictions.tar.gz
|
||||
tar -xzf diffdock_baseline_method_predictions.tar.gz
|
||||
rm diffdock_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/dynamicbind_baseline_method_predictions.tar.gz
|
||||
tar -xzf dynamicbind_baseline_method_predictions.tar.gz
|
||||
rm dynamicbind_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_baseline_method_predictions.tar.gz
|
||||
rm flowdock_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_aft_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_aft_baseline_method_predictions.tar.gz
|
||||
rm flowdock_aft_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_pft_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_pft_baseline_method_predictions.tar.gz
|
||||
rm flowdock_pft_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||
rm flowdock_esmfold_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_chai_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_chai_baseline_method_predictions.tar.gz
|
||||
rm flowdock_chai_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/flowdock_hp_baseline_method_predictions.tar.gz
|
||||
tar -xzf flowdock_hp_baseline_method_predictions.tar.gz
|
||||
rm flowdock_hp_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/neuralplexer_baseline_method_predictions.tar.gz
|
||||
tar -xzf neuralplexer_baseline_method_predictions.tar.gz
|
||||
rm neuralplexer_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/vina_p2rank_baseline_method_predictions.tar.gz
|
||||
tar -xzf vina_p2rank_baseline_method_predictions.tar.gz
|
||||
rm vina_p2rank_baseline_method_predictions.tar.gz
|
||||
|
||||
wget https://zenodo.org/records/15066450/files/rfaa_baseline_method_predictions.tar.gz
|
||||
tar -xzf rfaa_baseline_method_predictions.tar.gz
|
||||
rm rfaa_baseline_method_predictions.tar.gz
|
||||
```
|
||||
|
||||
Reproduce paper result figures
|
||||
|
||||
```bash
|
||||
jupyter notebook notebooks/casp16_binding_affinity_prediction_results_plotting.ipynb
|
||||
jupyter notebook notebooks/casp16_flowdock_vs_multicom_ligand_structure_prediction_results_plotting.ipynb
|
||||
jupyter notebook notebooks/dockgen_structure_prediction_results_plotting.ipynb
|
||||
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_chemical_similarity_analysis.ipynb
|
||||
jupyter notebook notebooks/posebusters_benchmark_structure_prediction_results_plotting.ipynb
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## How to predict new protein-ligand complex structures and their affinities using `FlowDock`
|
||||
|
||||
<details>
|
||||
|
||||
For example, generate new protein-ligand complexes for a pair of protein sequence and ligand SMILES strings such as those of the PDBBind 2020 test target `6i67`
|
||||
|
||||
```bash
|
||||
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='YNKIVHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNKLEGKVPMHKLFLEMLEAKV' input_ligand='"c1cc2c(cc1O)CCCC2"' input_template=data/pdbbind/pdbbind_holo_aligned_esmfold_structures/6i67_holo_aligned_esmfold_protein.pdb sample_id='6i67' out_path='./6i67_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||
```
|
||||
|
||||
Or, for example, generate new protein-ligand complexes for pairs of protein sequences and (multi-)ligand SMILES strings (delimited via `|`) such as those of the CASP15 target `T1152`
|
||||
|
||||
```bash
|
||||
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling input_receptor='MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIP|MYTVKPGDTMWKIAVKYQIGISEIIAANPQIKNPNLIYPGQKINIPN' input_ligand='"CC(=O)NC1C(O)OC(CO)C(OC2OC(CO)C(OC3OC(CO)C(O)C(O)C3NC(C)=O)C(O)C2NC(C)=O)C1O"' input_template=data/test_cases/predicted_structures/T1152.pdb sample_id='T1152' out_path='./T1152_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=true auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||
```
|
||||
|
||||
If you do not already have a template protein structure available for your target of interest, set `input_template=null` to instead have the sampling script predict the ESMFold structure of your provided `input_receptor` sequence before running the sampling pipeline. For more information regarding the input arguments available for sampling, please refer to the config at `configs/sample.yaml`.
|
||||
|
||||
**NOTE:** To optimize prediction runtimes, a `csv_path` can be specified instead of the `input_receptor`, `input_ligand`, and `input_template` CLI arguments to perform *batched* prediction for a collection of protein-ligand sequence pairs, each represented as a CSV row containing column values for `id`, `input_receptor`, `input_ligand`, and `input_template`. Additionally, disabling `visualize_sample_trajectories` may reduce storage requirements when predicting a large batch of inputs.
|
||||
|
||||
For instance, one can perform batched prediction as follows:
|
||||
|
||||
```bash
|
||||
python flowdock/sample.py ckpt_path=checkpoints/esmfold_prior_paper_weights-EMA.ckpt model.cfg.prior_type=esmfold sampling_task=batched_structure_sampling csv_path='./data/test_cases/prediction_inputs/flowdock_batched_inputs.csv' out_path='./T1152_batch_sampled_structures/' n_samples=5 chunk_size=5 num_steps=40 sampler=VDODE sampler_eta=1.0 start_time='1.0' use_template=true separate_pdb=true visualize_sample_trajectories=false auxiliary_estimation_only=false esmfold_chunk_size=null trainer=gpu
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## For developers
|
||||
|
||||
<details>
|
||||
|
||||
Set up `pre-commit` (one time only) for automatic code linting and formatting upon each `git commit`
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Manually reformat all files in the project, as desired
|
||||
|
||||
```bash
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
Update dependencies in a `*_environment.yaml` file
|
||||
|
||||
```bash
|
||||
mamba env export > env.yaml # e.g., run this after installing new dependencies locally
|
||||
diff environments/flowdock_environment.yaml env.yaml # note the differences and copy accepted changes back into e.g., `environments/flowdock_environment.yaml`
|
||||
rm env.yaml # clean up temporary environment file
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Docker
|
||||
|
||||
<details>
|
||||
|
||||
Given that this tool has a number of dependencies, it may be easier to run it in a Docker container.
|
||||
|
||||
Pull from [Docker Hub](https://hub.docker.com/repository/docker/cford38/flowdock): `docker pull cford38/flowdock:latest`
|
||||
|
||||
Alternatively, build the Docker image locally:
|
||||
|
||||
```bash
|
||||
docker build --platform linux/amd64 -t flowdock .
|
||||
```
|
||||
|
||||
Then, run the Docker container (and mount your local `checkpoints/` directory)
|
||||
|
||||
```bash
|
||||
docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it flowdock /bin/bash
|
||||
|
||||
# docker run --gpus all -v ./checkpoints:/software/flowdock/checkpoints --rm --name flowdock -it cford38/flowdock:latest /bin/bash
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
`FlowDock` builds upon the source code and data from the following projects:
|
||||
|
||||
- [DiffDock](https://github.com/gcorso/DiffDock)
|
||||
- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
|
||||
- [NeuralPLexer](https://github.com/zrqiao/NeuralPLexer)
|
||||
|
||||
We thank all their contributors and maintainers!
|
||||
|
||||
## License
|
||||
|
||||
This project is covered under the **MIT License**.
|
||||
|
||||
## Citing this work
|
||||
|
||||
If you use the code or data associated with this package or otherwise find this work useful, please cite:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{morehead2025flowdock,
|
||||
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
|
||||
author={Alex Morehead and Jianlin Cheng},
|
||||
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
|
||||
year=2025,
|
||||
}
|
||||
```
|
||||
6
citation.bib
Normal file
6
citation.bib
Normal file
@@ -0,0 +1,6 @@
|
||||
@inproceedings{morehead2025flowdock,
|
||||
title={FlowDock: Geometric Flow Matching for Generative Protein-Ligand Docking and Affinity Prediction},
|
||||
author={Alex Morehead and Jianlin Cheng},
|
||||
booktitle={Intelligent Systems for Molecular Biology (ISMB)},
|
||||
year=2025,
|
||||
}
|
||||
1
configs/__init__.py
Normal file
1
configs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# this file is needed here to include configs when building project as a package
|
||||
21
configs/callbacks/default.yaml
Normal file
21
configs/callbacks/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
defaults:
|
||||
- ema
|
||||
- last_model_checkpoint
|
||||
- learning_rate_monitor
|
||||
- model_checkpoint
|
||||
- model_summary
|
||||
- rich_progress_bar
|
||||
- _self_
|
||||
|
||||
last_model_checkpoint:
|
||||
dirpath: ${paths.output_dir}/checkpoints
|
||||
filename: "last"
|
||||
monitor: null
|
||||
verbose: True
|
||||
auto_insert_metric_name: False
|
||||
every_n_epochs: 1
|
||||
save_on_train_epoch_end: True
|
||||
enable_version_counter: False
|
||||
|
||||
model_summary:
|
||||
max_depth: -1
|
||||
15
configs/callbacks/early_stopping.yaml
Normal file
15
configs/callbacks/early_stopping.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
||||
|
||||
early_stopping:
|
||||
_target_: lightning.pytorch.callbacks.EarlyStopping
|
||||
monitor: ??? # quantity to be monitored, must be specified !!!
|
||||
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
||||
patience: 3 # number of checks with no improvement after which training will be stopped
|
||||
verbose: False # verbosity mode
|
||||
mode: "min" # "max" means higher metric value is better, can be also "min"
|
||||
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
||||
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
||||
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
||||
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
||||
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
||||
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
||||
10
configs/callbacks/ema.yaml
Normal file
10
configs/callbacks/ema.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
|
||||
|
||||
# Maintains an exponential moving average (EMA) of model weights.
|
||||
# Look at the above link for more detailed information regarding the original implementation.
|
||||
ema:
|
||||
_target_: flowdock.models.components.callbacks.ema.EMA
|
||||
decay: 0.999
|
||||
validate_original_weights: false
|
||||
every_n_steps: 4
|
||||
cpu_offload: false
|
||||
21
configs/callbacks/last_model_checkpoint.yaml
Normal file
21
configs/callbacks/last_model_checkpoint.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
||||
|
||||
last_model_checkpoint:
|
||||
# NOTE: this is a direct copy of `model_checkpoint`,
|
||||
# which is necessary to work around the
|
||||
# key-duplication limitations of YAML config files
|
||||
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
|
||||
dirpath: null # directory to save the model file
|
||||
filename: null # checkpoint filename
|
||||
monitor: null # name of the logged metric which determines when model is improving
|
||||
verbose: False # verbosity mode
|
||||
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
||||
save_top_k: 1 # save k best models (determined by above metric)
|
||||
mode: "min" # "max" means higher metric value is better, can be also "min"
|
||||
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
||||
save_weights_only: False # if True, then only the model’s weights will be saved
|
||||
every_n_train_steps: null # number of training steps between checkpoints
|
||||
train_time_interval: null # checkpoints are monitored at the specified time interval
|
||||
every_n_epochs: null # number of epochs between checkpoints
|
||||
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
||||
enable_version_counter: True # enables versioning for checkpoint names
|
||||
7
configs/callbacks/learning_rate_monitor.yaml
Normal file
7
configs/callbacks/learning_rate_monitor.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html
|
||||
|
||||
learning_rate_monitor:
|
||||
_target_: lightning.pytorch.callbacks.LearningRateMonitor
|
||||
logging_interval: null # set to `epoch` or `step` to log learning rate of all optimizers at the same interval, or set to `null` to log at individual interval according to the interval key of each scheduler
|
||||
log_momentum: false # whether to also log the momentum values of the optimizer, if the optimizer has the `momentum` or `betas` attribute
|
||||
log_weight_decay: false # whether to also log the weight decay values of the optimizer, if the optimizer has the `weight_decay` attribute
|
||||
18
configs/callbacks/model_checkpoint.yaml
Normal file
18
configs/callbacks/model_checkpoint.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
||||
|
||||
model_checkpoint:
|
||||
_target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
|
||||
dirpath: null # directory to save the model file
|
||||
filename: "best" # checkpoint filename
|
||||
monitor: val_sampling/ligand_hit_score_2A_epoch # name of the logged metric which determines when model is improving
|
||||
verbose: True # verbosity mode
|
||||
save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
||||
save_top_k: 1 # save k best models (determined by above metric)
|
||||
mode: "max" # "max" means higher metric value is better, can be also "min"
|
||||
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
||||
save_weights_only: False # if True, then only the model’s weights will be saved
|
||||
every_n_train_steps: null # number of training steps between checkpoints
|
||||
train_time_interval: null # checkpoints are monitored at the specified time interval
|
||||
every_n_epochs: null # number of epochs between checkpoints
|
||||
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
||||
enable_version_counter: False # enables versioning for checkpoint names
|
||||
5
configs/callbacks/model_summary.yaml
Normal file
5
configs/callbacks/model_summary.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
||||
|
||||
model_summary:
|
||||
_target_: lightning.pytorch.callbacks.RichModelSummary
|
||||
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
||||
0
configs/callbacks/none.yaml
Normal file
0
configs/callbacks/none.yaml
Normal file
4
configs/callbacks/rich_progress_bar.yaml
Normal file
4
configs/callbacks/rich_progress_bar.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
||||
|
||||
rich_progress_bar:
|
||||
_target_: lightning.pytorch.callbacks.RichProgressBar
|
||||
35
configs/debug/default.yaml
Normal file
35
configs/debug/default.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
# @package _global_
|
||||
|
||||
# default debugging setup, runs 1 full epoch
|
||||
# other debugging configs can inherit from this one
|
||||
|
||||
# overwrite task name so debugging logs are stored in separate folder
|
||||
task_name: "debug"
|
||||
|
||||
# disable callbacks and loggers during debugging
|
||||
callbacks: null
|
||||
logger: null
|
||||
|
||||
extras:
|
||||
ignore_warnings: False
|
||||
enforce_tags: False
|
||||
|
||||
# sets level of all command line loggers to 'DEBUG'
|
||||
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
||||
hydra:
|
||||
job_logging:
|
||||
root:
|
||||
level: DEBUG
|
||||
|
||||
# use this to also set hydra loggers to 'DEBUG'
|
||||
# verbose: True
|
||||
|
||||
trainer:
|
||||
max_epochs: 1
|
||||
accelerator: cpu # debuggers don't like gpus
|
||||
devices: 1 # debuggers don't like multiprocessing
|
||||
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
||||
|
||||
data:
|
||||
num_workers: 0 # debuggers don't like multiprocessing
|
||||
pin_memory: False # disable gpu memory pin
|
||||
9
configs/debug/fdr.yaml
Normal file
9
configs/debug/fdr.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
# @package _global_
|
||||
|
||||
# runs 1 train, 1 validation and 1 test step
|
||||
|
||||
defaults:
|
||||
- default
|
||||
|
||||
trainer:
|
||||
fast_dev_run: true
|
||||
12
configs/debug/limit.yaml
Normal file
12
configs/debug/limit.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# @package _global_
|
||||
|
||||
# uses only 1% of the training data and 5% of validation/test data
|
||||
|
||||
defaults:
|
||||
- default
|
||||
|
||||
trainer:
|
||||
max_epochs: 3
|
||||
limit_train_batches: 0.01
|
||||
limit_val_batches: 0.05
|
||||
limit_test_batches: 0.05
|
||||
13
configs/debug/overfit.yaml
Normal file
13
configs/debug/overfit.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
# @package _global_
|
||||
|
||||
# overfits to 3 batches
|
||||
|
||||
defaults:
|
||||
- default
|
||||
|
||||
trainer:
|
||||
max_epochs: 20
|
||||
overfit_batches: 3
|
||||
|
||||
# model ckpt and early stopping need to be disabled during overfitting
|
||||
callbacks: null
|
||||
12
configs/debug/profiler.yaml
Normal file
12
configs/debug/profiler.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# @package _global_
|
||||
|
||||
# runs with execution time profiling
|
||||
|
||||
defaults:
|
||||
- default
|
||||
|
||||
trainer:
|
||||
max_epochs: 1
|
||||
profiler: "simple"
|
||||
# profiler: "advanced"
|
||||
# profiler: "pytorch"
|
||||
2
configs/environment/default.yaml
Normal file
2
configs/environment/default.yaml
Normal file
@@ -0,0 +1,2 @@
|
||||
defaults:
|
||||
- _self_
|
||||
1
configs/environment/lightning.yaml
Normal file
1
configs/environment/lightning.yaml
Normal file
@@ -0,0 +1 @@
|
||||
_target_: lightning.fabric.plugins.environments.LightningEnvironment
|
||||
3
configs/environment/slurm.yaml
Normal file
3
configs/environment/slurm.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
_target_: lightning.fabric.plugins.environments.SLURMEnvironment
|
||||
auto_requeue: true
|
||||
requeue_signal: null
|
||||
49
configs/eval.yaml
Normal file
49
configs/eval.yaml
Normal file
@@ -0,0 +1,49 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- data: combined # choose datamodule with `test_dataloader()` for evaluation
|
||||
- model: flowdock_fm
|
||||
- logger: null
|
||||
- strategy: default
|
||||
- trainer: default
|
||||
- paths: default
|
||||
- extras: default
|
||||
- hydra: default
|
||||
- environment: default
|
||||
- _self_
|
||||
|
||||
task_name: "eval"
|
||||
|
||||
tags: ["eval", "combined", "flowdock_fm"]
|
||||
|
||||
# passing checkpoint path is necessary for evaluation
|
||||
ckpt_path: ???
|
||||
|
||||
# seed for random number generators in pytorch, numpy and python.random
|
||||
seed: null
|
||||
|
||||
# model arguments
|
||||
model:
|
||||
cfg:
|
||||
mol_encoder:
|
||||
from_pretrained: false
|
||||
protein_encoder:
|
||||
from_pretrained: false
|
||||
relational_reasoning:
|
||||
from_pretrained: false
|
||||
contact_predictor:
|
||||
from_pretrained: false
|
||||
score_head:
|
||||
from_pretrained: false
|
||||
confidence:
|
||||
from_pretrained: false
|
||||
affinity:
|
||||
from_pretrained: false
|
||||
task:
|
||||
freeze_mol_encoder: true
|
||||
freeze_protein_encoder: false
|
||||
freeze_relational_reasoning: false
|
||||
freeze_contact_predictor: false
|
||||
freeze_score_head: false
|
||||
freeze_confidence: true
|
||||
freeze_affinity: false
|
||||
35
configs/experiment/flowdock_fm.yaml
Normal file
35
configs/experiment/flowdock_fm.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
# @package _global_
|
||||
|
||||
# to execute this experiment run:
|
||||
# python train.py experiment=flowdock_fm
|
||||
|
||||
defaults:
|
||||
- override /data: combined
|
||||
- override /model: flowdock_fm
|
||||
- override /callbacks: default
|
||||
- override /trainer: default
|
||||
|
||||
# all parameters below will be merged with parameters from default configurations set above
|
||||
# this allows you to overwrite only specified parameters
|
||||
|
||||
tags: ["flowdock_fm", "combined_dataset"]
|
||||
|
||||
seed: 496
|
||||
|
||||
trainer:
|
||||
max_epochs: 300
|
||||
check_val_every_n_epoch: 5 # NOTE: we increase this since validation steps involve full model sampling and evaluation
|
||||
reload_dataloaders_every_n_epochs: 1
|
||||
|
||||
model:
|
||||
optimizer:
|
||||
lr: 2e-4
|
||||
compile: false
|
||||
|
||||
data:
|
||||
batch_size: 16
|
||||
|
||||
logger:
|
||||
wandb:
|
||||
tags: ${tags}
|
||||
group: "FlowDock-FM"
|
||||
8
configs/extras/default.yaml
Normal file
8
configs/extras/default.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
# disable python warnings if they annoy you
|
||||
ignore_warnings: False
|
||||
|
||||
# ask user for tags if none are provided in the config
|
||||
enforce_tags: True
|
||||
|
||||
# pretty print config tree at the start of the run using Rich library
|
||||
print_config: True
|
||||
50
configs/hparams_search/combined_optuna.yaml
Normal file
50
configs/hparams_search/combined_optuna.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# @package _global_
|
||||
|
||||
# example hyperparameter optimization of some experiment with Optuna:
|
||||
# python train.py -m hparams_search=mnist_optuna experiment=example
|
||||
|
||||
defaults:
|
||||
- override /hydra/sweeper: optuna
|
||||
|
||||
# choose metric which will be optimized by Optuna
|
||||
# make sure this is the correct name of some metric logged in lightning module!
|
||||
optimized_metric: "val/loss"
|
||||
|
||||
# here we define Optuna hyperparameter search
|
||||
# it optimizes for value returned from function with @hydra.main decorator
|
||||
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
||||
hydra:
|
||||
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
||||
|
||||
sweeper:
|
||||
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
||||
|
||||
# storage URL to persist optimization results
|
||||
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
||||
storage: null
|
||||
|
||||
# name of the study to persist optimization results
|
||||
study_name: null
|
||||
|
||||
# number of parallel workers
|
||||
n_jobs: 1
|
||||
|
||||
# 'minimize' or 'maximize' the objective
|
||||
direction: minimize
|
||||
|
||||
# total number of runs that will be executed
|
||||
n_trials: 20
|
||||
|
||||
# choose Optuna hyperparameter sampler
|
||||
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
||||
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
||||
sampler:
|
||||
_target_: optuna.samplers.TPESampler
|
||||
seed: 1234
|
||||
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
||||
|
||||
# define hyperparameter search space
|
||||
params:
|
||||
model.optimizer.lr: interval(0.0001, 0.1)
|
||||
data.batch_size: choice(2, 4, 8, 16)
|
||||
model.net.hidden_dim: choice(64, 128, 256)
|
||||
19
configs/hydra/default.yaml
Normal file
19
configs/hydra/default.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
# https://hydra.cc/docs/configure_hydra/intro/
|
||||
|
||||
# enable color logging
|
||||
defaults:
|
||||
- override hydra_logging: colorlog
|
||||
- override job_logging: colorlog
|
||||
|
||||
# output directory, generated dynamically on each run
|
||||
run:
|
||||
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
||||
sweep:
|
||||
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
||||
subdir: ${hydra.job.num}
|
||||
|
||||
job_logging:
|
||||
handlers:
|
||||
file:
|
||||
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
||||
filename: ${hydra.runtime.output_dir}/${task_name}.log
|
||||
0
configs/local/.gitkeep
Normal file
0
configs/local/.gitkeep
Normal file
28
configs/logger/aim.yaml
Normal file
28
configs/logger/aim.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
# https://aimstack.io/
|
||||
|
||||
# example usage in lightning module:
|
||||
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
|
||||
|
||||
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
|
||||
# `aim up`
|
||||
|
||||
aim:
|
||||
_target_: aim.pytorch_lightning.AimLogger
|
||||
repo: ${paths.root_dir} # .aim folder will be created here
|
||||
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
|
||||
|
||||
# aim allows to group runs under experiment name
|
||||
experiment: null # any string, set to "default" if not specified
|
||||
|
||||
train_metric_prefix: "train/"
|
||||
val_metric_prefix: "val/"
|
||||
test_metric_prefix: "test/"
|
||||
|
||||
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
|
||||
system_tracking_interval: 10 # set to null to disable system metrics tracking
|
||||
|
||||
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
|
||||
log_system_params: true
|
||||
|
||||
# enable/disable tracking console logs (default value is true)
|
||||
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
|
||||
12
configs/logger/comet.yaml
Normal file
12
configs/logger/comet.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# https://www.comet.ml
|
||||
|
||||
comet:
|
||||
_target_: lightning.pytorch.loggers.comet.CometLogger
|
||||
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
||||
save_dir: "${paths.output_dir}"
|
||||
project_name: "FlowDock_FM"
|
||||
rest_api_key: null
|
||||
# experiment_name: ""
|
||||
experiment_key: null # set to resume experiment
|
||||
offline: False
|
||||
prefix: ""
|
||||
7
configs/logger/csv.yaml
Normal file
7
configs/logger/csv.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
# csv logger built in lightning
|
||||
|
||||
csv:
|
||||
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
|
||||
save_dir: "${paths.output_dir}"
|
||||
name: "csv/"
|
||||
prefix: ""
|
||||
9
configs/logger/many_loggers.yaml
Normal file
9
configs/logger/many_loggers.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
# train with many loggers at once
|
||||
|
||||
defaults:
|
||||
# - comet
|
||||
- csv
|
||||
# - mlflow
|
||||
# - neptune
|
||||
- tensorboard
|
||||
- wandb
|
||||
12
configs/logger/mlflow.yaml
Normal file
12
configs/logger/mlflow.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# https://mlflow.org
|
||||
|
||||
mlflow:
|
||||
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
|
||||
# experiment_name: ""
|
||||
# run_name: ""
|
||||
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
||||
tags: null
|
||||
# save_dir: "./mlruns"
|
||||
prefix: ""
|
||||
artifact_location: null
|
||||
# run_id: ""
|
||||
9
configs/logger/neptune.yaml
Normal file
9
configs/logger/neptune.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
# https://neptune.ai
|
||||
|
||||
neptune:
|
||||
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
|
||||
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
||||
project: username/FlowDock_FM
|
||||
# name: ""
|
||||
log_model_checkpoints: True
|
||||
prefix: ""
|
||||
10
configs/logger/tensorboard.yaml
Normal file
10
configs/logger/tensorboard.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# https://www.tensorflow.org/tensorboard/
|
||||
|
||||
tensorboard:
|
||||
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
|
||||
save_dir: "${paths.output_dir}/tensorboard/"
|
||||
name: null
|
||||
log_graph: False
|
||||
default_hp_metric: True
|
||||
prefix: ""
|
||||
# version: ""
|
||||
16
configs/logger/wandb.yaml
Normal file
16
configs/logger/wandb.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
# https://wandb.ai
|
||||
|
||||
wandb:
|
||||
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
||||
# name: "" # name of the run (normally generated by wandb)
|
||||
save_dir: "${paths.output_dir}"
|
||||
offline: False
|
||||
id: null # pass correct id to resume experiment!
|
||||
anonymous: null # enable anonymous logging
|
||||
project: "FlowDock_FM"
|
||||
log_model: False # upload lightning ckpts
|
||||
prefix: "" # a string to put at the beginning of metric keys
|
||||
entity: "bml-lab" # set to name of your wandb team
|
||||
group: ""
|
||||
tags: []
|
||||
job_type: ""
|
||||
148
configs/model/flowdock_fm.yaml
Normal file
148
configs/model/flowdock_fm.yaml
Normal file
@@ -0,0 +1,148 @@
|
||||
_target_: flowdock.models.flowdock_fm_module.FlowDockFMLitModule
|
||||
|
||||
net:
|
||||
_target_: flowdock.models.components.flowdock.FlowDock
|
||||
_partial_: true
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.Adam
|
||||
_partial_: true
|
||||
lr: 2e-4
|
||||
weight_decay: 0.0
|
||||
|
||||
scheduler:
|
||||
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
|
||||
_partial_: true
|
||||
T_0: ${int_divide:${trainer.max_epochs},15}
|
||||
T_mult: 2
|
||||
eta_min: 1e-8
|
||||
verbose: true
|
||||
|
||||
# compile model for faster training with pytorch 2.0
|
||||
compile: false
|
||||
|
||||
# model arguments
|
||||
cfg:
|
||||
mol_encoder:
|
||||
node_channels: 512
|
||||
pair_channels: 64
|
||||
n_atom_encodings: 23
|
||||
n_bond_encodings: 4
|
||||
n_atom_pos_encodings: 6
|
||||
n_stereo_encodings: 14
|
||||
n_attention_heads: 8
|
||||
attention_head_dim: 8
|
||||
hidden_dim: 2048
|
||||
max_path_integral_length: 6
|
||||
n_transformer_stacks: 8
|
||||
n_heads: 8
|
||||
n_patches: ${data.n_lig_patches}
|
||||
checkpoint_file: ${oc.env:PROJECT_ROOT}/checkpoints/neuralplexermodels_downstream_datasets_predictions/models/complex_structure_prediction.ckpt
|
||||
megamolbart: null
|
||||
from_pretrained: true
|
||||
|
||||
protein_encoder:
|
||||
use_esm_embedding: true
|
||||
esm_version: esm2_t33_650M_UR50D
|
||||
esm_repr_layer: 33
|
||||
residue_dim: 512
|
||||
plm_embed_dim: 1280
|
||||
n_aa_types: 21
|
||||
atom_padding_dim: 37
|
||||
n_atom_types: 4 # [C, N, O, S]
|
||||
n_patches: ${data.n_protein_patches}
|
||||
n_attention_heads: 8
|
||||
scalar_dim: 16
|
||||
point_dim: 4
|
||||
pair_dim: 64
|
||||
n_heads: 8
|
||||
head_dim: 8
|
||||
max_residue_degree: 32
|
||||
n_encoder_stacks: 2
|
||||
from_pretrained: true
|
||||
|
||||
relational_reasoning:
|
||||
from_pretrained: true
|
||||
|
||||
contact_predictor:
|
||||
n_stacks: 4
|
||||
dropout: 0.01
|
||||
from_pretrained: true
|
||||
|
||||
score_head:
|
||||
fiber_dim: 64
|
||||
hidden_dim: 512
|
||||
n_stacks: 4
|
||||
max_atom_degree: 8
|
||||
from_pretrained: true
|
||||
|
||||
confidence:
|
||||
enabled: true # whether the confidence prediction head is to be used e.g., during inference
|
||||
fiber_dim: ${..score_head.fiber_dim}
|
||||
hidden_dim: ${..score_head.hidden_dim}
|
||||
n_stacks: ${..score_head.n_stacks}
|
||||
from_pretrained: true
|
||||
|
||||
affinity:
|
||||
enabled: true # whether the affinity prediction head is to be used e.g., during inference
|
||||
fiber_dim: ${..score_head.fiber_dim}
|
||||
hidden_dim: ${..score_head.hidden_dim}
|
||||
n_stacks: ${..score_head.n_stacks}
|
||||
ligand_pooling: sum # NOTE: must be a value in (`sum`, `mean`)
|
||||
dropout: 0.01
|
||||
from_pretrained: false
|
||||
|
||||
latent_model: default
|
||||
prior_type: esmfold # NOTE: must be a value in (`gaussian`, `harmonic`, `esmfold`)
|
||||
|
||||
task:
|
||||
pretrained: null
|
||||
ligands: true
|
||||
epoch_frac: ${data.epoch_frac}
|
||||
label_name: null
|
||||
sequence_crop_size: 1600
|
||||
edge_crop_size: ${data.edge_crop_size} # NOTE: for dynamic batching via `max_n_edges`
|
||||
max_masking_rate: 0.0
|
||||
n_modes: 8
|
||||
dropout: 0.01
|
||||
# pretraining: true
|
||||
freeze_mol_encoder: true
|
||||
freeze_protein_encoder: false
|
||||
freeze_relational_reasoning: false
|
||||
freeze_contact_predictor: true
|
||||
freeze_score_head: false
|
||||
freeze_confidence: true
|
||||
freeze_affinity: false
|
||||
use_template: true
|
||||
use_plddt: false
|
||||
block_contact_decoding_scheme: "beam"
|
||||
frozen_ligand_backbone: false
|
||||
frozen_protein_backbone: false
|
||||
single_protein_batch: true
|
||||
contact_loss_weight: 0.2
|
||||
global_score_loss_weight: 0.2
|
||||
ligand_score_loss_weight: 0.1
|
||||
clash_loss_weight: 10.0
|
||||
local_distgeom_loss_weight: 10.0
|
||||
drmsd_loss_weight: 2.0
|
||||
distogram_loss_weight: 0.05
|
||||
plddt_loss_weight: 1.0
|
||||
affinity_loss_weight: 0.1
|
||||
aux_batch_freq: 10 # NOTE: e.g., `10` means that auxiliary estimation losses will be calculated every 10th batch
|
||||
global_max_sigma: 5.0
|
||||
internal_max_sigma: 2.0
|
||||
detect_covalent: true
|
||||
# runtime configs
|
||||
float32_matmul_precision: highest
|
||||
# sampling configs
|
||||
constrained_inpainting: false
|
||||
visualize_generated_samples: true
|
||||
# testing configs
|
||||
loss_mode: auxiliary_estimation # NOTE: must be one of (`structure_prediction`, `auxiliary_estimation`, `auxiliary_estimation_without_structure_prediction`)
|
||||
num_steps: 20
|
||||
sampler: VDODE # NOTE: must be one of (`ODE`, `VDODE`)
|
||||
sampler_eta: 1.0 # NOTE: this corresponds to the variance diminishing factor for the `VDODE` sampler, which offers a trade-off between exploration (1.0) and exploitation (> 1.0)
|
||||
start_time: 1.0
|
||||
eval_structure_prediction: false # whether to evaluate structure prediction performance (`true`) or instead only binding affinity performance (`false`) when running a test epoch
|
||||
# overfitting configs
|
||||
overfitting_example_name: ${data.overfitting_example_name}
|
||||
21
configs/paths/default.yaml
Normal file
21
configs/paths/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
# path to root directory
|
||||
# this requires PROJECT_ROOT environment variable to exist
|
||||
# you can replace it with "." if you want the root to be the current working directory
|
||||
root_dir: ${oc.env:PROJECT_ROOT}
|
||||
|
||||
# path to data directory
|
||||
data_dir: ${paths.root_dir}/data/
|
||||
|
||||
# path to logging directory
|
||||
log_dir: ${paths.root_dir}/logs/
|
||||
|
||||
# path to output directory, created dynamically by hydra
|
||||
# path generation pattern is specified in `configs/hydra/default.yaml`
|
||||
# use it to store all files generated during the run, like ckpts and metrics
|
||||
output_dir: ${hydra:runtime.output_dir}
|
||||
|
||||
# path to working directory
|
||||
work_dir: ${hydra:runtime.cwd}
|
||||
|
||||
# path to the directory containing the model checkpoints
|
||||
ckpt_dir: ${paths.root_dir}/checkpoints/
|
||||
78
configs/sample.yaml
Normal file
78
configs/sample.yaml
Normal file
@@ -0,0 +1,78 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- data: combined # NOTE: this will not be referenced during sampling
|
||||
- model: flowdock_fm
|
||||
- logger: null
|
||||
- strategy: default
|
||||
- trainer: default
|
||||
- paths: default
|
||||
- extras: default
|
||||
- hydra: default
|
||||
- environment: default
|
||||
- _self_
|
||||
|
||||
task_name: "sample"
|
||||
|
||||
tags: ["sample", "combined", "flowdock_fm"]
|
||||
|
||||
# passing checkpoint path is necessary for sampling
|
||||
ckpt_path: ???
|
||||
|
||||
# seed for random number generators in pytorch, numpy and python.random
|
||||
seed: null
|
||||
|
||||
# sampling arguments
|
||||
sampling_task: batched_structure_sampling # NOTE: must be one of (`batched_structure_sampling`)
|
||||
sample_id: null # optional identifier for the sampling run
|
||||
input_receptor: null # NOTE: must be either a protein sequence string (with chains separated by `|`) or a path to a PDB file (from which protein chain sequences will be parsed)
|
||||
input_ligand: null # NOTE: must be either a ligand SMILES string (with chains/fragments separated by `|`) or a path to a ligand SDF file (from which ligand SMILES will be parsed)
|
||||
input_template: null # path to a protein PDB file to use as a starting protein template for sampling (with an ESMFold prior model)
|
||||
out_path: ??? # path to which to save the output PDB and SDF files
|
||||
n_samples: 5 # number of structures to sample
|
||||
chunk_size: 5 # number of structures to concurrently sample within each batch segment - NOTE: `n_samples` should be evenly divisible by `chunk_size` to produce the expected number of outputs
|
||||
num_steps: 40 # number of sampling steps to perform
|
||||
latent_model: null # if provided, the type of latent model to use
|
||||
sampler: VDODE # sampling algorithm to use - NOTE: must be one of (`ODE`, `VDODE`)
|
||||
sampler_eta: 1.0 # the variance diminishing factor for the `VDODE` sampler - NOTE: offers a trade-off between exploration (1.0) and exploitation (> 1.0)
|
||||
start_time: "1.0" # time at which to start sampling
|
||||
max_chain_encoding_k: -1 # maximum number of chains to encode in the chain encoding
|
||||
exact_prior: false # whether to use the "ground-truth" binding site for sampling, if available
|
||||
prior_type: esmfold # the type of prior to use for sampling - NOTE: must be one of (`gaussian`, `harmonic`, `esmfold`)
|
||||
discard_ligand: false # whether to discard a given input ligand during sampling
|
||||
discard_sdf_coords: true # whether to discard the input ligand's 3D structure during sampling, if available
|
||||
detect_covalent: false # whether to detect covalent bonds between the input receptor and ligand
|
||||
use_template: true # whether to use the input protein template for sampling if one is provided
|
||||
separate_pdb: true # whether to save separate PDB files for each sampled structure instead of simply a single PDB file
|
||||
rank_outputs_by_confidence: true # whether to rank the sampled structures by estimated confidence
|
||||
plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`)
|
||||
visualize_sample_trajectories: false # whether to visualize the generated samples' trajectories
|
||||
auxiliary_estimation_only: false # whether to only estimate auxiliary outputs (e.g., confidence, affinity) for the input (generated) samples (potentially derived from external sources)
|
||||
csv_path: null # if provided, the CSV file (with columns `id`, `input_receptor`, `input_ligand`, and `input_template`) from which to parse input receptors and ligands for sampling, overriding the `input_receptor` and `input_ligand` arguments in the process and ignoring the `input_template` for now
|
||||
esmfold_chunk_size: null # chunks axial attention computation to reduce memory usage from O(L^2) to O(L); equivalent to running a for loop over chunks of each dimension; lower values will result in lower memory usage at the cost of speed; recommended values: 128, 64, 32
|
||||
|
||||
# model arguments
|
||||
model:
|
||||
cfg:
|
||||
mol_encoder:
|
||||
from_pretrained: false
|
||||
protein_encoder:
|
||||
from_pretrained: false
|
||||
relational_reasoning:
|
||||
from_pretrained: false
|
||||
contact_predictor:
|
||||
from_pretrained: false
|
||||
score_head:
|
||||
from_pretrained: false
|
||||
confidence:
|
||||
from_pretrained: false
|
||||
affinity:
|
||||
from_pretrained: false
|
||||
task:
|
||||
freeze_mol_encoder: true
|
||||
freeze_protein_encoder: false
|
||||
freeze_relational_reasoning: false
|
||||
freeze_contact_predictor: false
|
||||
freeze_score_head: false
|
||||
freeze_confidence: true
|
||||
freeze_affinity: false
|
||||
4
configs/strategy/ddp.yaml
Normal file
4
configs/strategy/ddp.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||
static_graph: false
|
||||
gradient_as_bucket_view: false
|
||||
find_unused_parameters: true
|
||||
5
configs/strategy/ddp_spawn.yaml
Normal file
5
configs/strategy/ddp_spawn.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||
static_graph: false
|
||||
gradient_as_bucket_view: false
|
||||
find_unused_parameters: true
|
||||
start_method: spawn
|
||||
5
configs/strategy/deepspeed.yaml
Normal file
5
configs/strategy/deepspeed.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
|
||||
stage: 2
|
||||
offload_optimizer: false
|
||||
allgather_bucket_size: 200_000_000
|
||||
reduce_bucket_size: 200_000_000
|
||||
2
configs/strategy/default.yaml
Normal file
2
configs/strategy/default.yaml
Normal file
@@ -0,0 +1,2 @@
|
||||
defaults:
|
||||
- _self_
|
||||
12
configs/strategy/fsdp.yaml
Normal file
12
configs/strategy/fsdp.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
_target_: lightning.pytorch.strategies.FSDPStrategy
|
||||
sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
|
||||
cpu_offload: null
|
||||
activation_checkpointing: null
|
||||
mixed_precision:
|
||||
_target_: torch.distributed.fsdp.MixedPrecision
|
||||
param_dtype: null
|
||||
reduce_dtype: null
|
||||
buffer_dtype: null
|
||||
keep_low_precision_grads: false
|
||||
cast_forward_inputs: false
|
||||
cast_root_forward_inputs: true
|
||||
4
configs/strategy/optimized_ddp.yaml
Normal file
4
configs/strategy/optimized_ddp.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
_target_: lightning.pytorch.strategies.DDPStrategy
|
||||
static_graph: true
|
||||
gradient_as_bucket_view: true
|
||||
find_unused_parameters: false
|
||||
51
configs/train.yaml
Normal file
51
configs/train.yaml
Normal file
@@ -0,0 +1,51 @@
|
||||
# @package _global_
|
||||
|
||||
# specify here default configuration
|
||||
# order of defaults determines the order in which configs override each other
|
||||
defaults:
|
||||
- _self_
|
||||
- data: combined
|
||||
- model: flowdock_fm
|
||||
- callbacks: default
|
||||
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
||||
- strategy: default
|
||||
- trainer: default
|
||||
- paths: default
|
||||
- extras: default
|
||||
- hydra: default
|
||||
- environment: default
|
||||
|
||||
# experiment configs allow for version control of specific hyperparameters
|
||||
# e.g. best hyperparameters for given model and datamodule
|
||||
- experiment: null
|
||||
|
||||
# config for hyperparameter optimization
|
||||
- hparams_search: null
|
||||
|
||||
# optional local config for machine/user specific settings
|
||||
# it's optional since it doesn't need to exist and is excluded from version control
|
||||
- optional local: default
|
||||
|
||||
# debugging config (enable through command line, e.g. `python train.py debug=default`)
|
||||
- debug: null
|
||||
|
||||
# task name, determines output directory path
|
||||
task_name: "train"
|
||||
|
||||
# tags to help you identify your experiments
|
||||
# you can overwrite this in experiment configs
|
||||
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
|
||||
tags: ["train", "combined", "flowdock_fm"]
|
||||
|
||||
# set False to skip model training
|
||||
train: True
|
||||
|
||||
# evaluate on test set, using best model weights achieved during training
|
||||
# lightning chooses best weights based on the metric specified in checkpoint callback
|
||||
test: False
|
||||
|
||||
# simply provide checkpoint path to resume training
|
||||
ckpt_path: null
|
||||
|
||||
# seed for random number generators in pytorch, numpy and python.random
|
||||
seed: null
|
||||
5
configs/trainer/cpu.yaml
Normal file
5
configs/trainer/cpu.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
accelerator: cpu
|
||||
devices: 1
|
||||
9
configs/trainer/ddp.yaml
Normal file
9
configs/trainer/ddp.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
strategy: ddp
|
||||
|
||||
accelerator: gpu
|
||||
devices: 4
|
||||
num_nodes: 1
|
||||
sync_batchnorm: True
|
||||
7
configs/trainer/ddp_sim.yaml
Normal file
7
configs/trainer/ddp_sim.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
# simulate DDP on CPU, useful for debugging
|
||||
accelerator: cpu
|
||||
devices: 2
|
||||
strategy: ddp_spawn
|
||||
9
configs/trainer/ddp_spawn.yaml
Normal file
9
configs/trainer/ddp_spawn.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
strategy: ddp_spawn
|
||||
|
||||
accelerator: gpu
|
||||
devices: 4
|
||||
num_nodes: 1
|
||||
sync_batchnorm: True
|
||||
29
configs/trainer/default.yaml
Normal file
29
configs/trainer/default.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
_target_: lightning.pytorch.trainer.Trainer
|
||||
|
||||
default_root_dir: ${paths.output_dir}
|
||||
|
||||
min_epochs: 1 # prevents early stopping
|
||||
max_epochs: 10
|
||||
|
||||
accelerator: cpu
|
||||
devices: 1
|
||||
|
||||
# mixed precision for extra speed-up
|
||||
# precision: 16
|
||||
|
||||
# perform a validation loop every N training epochs
|
||||
check_val_every_n_epoch: 1
|
||||
|
||||
# set True to ensure deterministic results
|
||||
# makes training slower but gives more reproducibility than just setting seeds
|
||||
deterministic: False
|
||||
|
||||
# determine the frequency of how often to reload the dataloaders
|
||||
reload_dataloaders_every_n_epochs: 1
|
||||
|
||||
# if `gradient_clip_val` is not `null`, gradients will be norm-clipped during training
|
||||
gradient_clip_algorithm: norm
|
||||
gradient_clip_val: 1.0
|
||||
|
||||
# if `num_sanity_val_steps` is > 0, then specifically that many validation steps will be run during the first call to `trainer.fit`
|
||||
num_sanity_val_steps: 0
|
||||
5
configs/trainer/gpu.yaml
Normal file
5
configs/trainer/gpu.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
accelerator: gpu
|
||||
devices: 1
|
||||
5
configs/trainer/mps.yaml
Normal file
5
configs/trainer/mps.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
accelerator: mps
|
||||
devices: 1
|
||||
540
environments/flowdock_environment.yaml
Normal file
540
environments/flowdock_environment.yaml
Normal file
@@ -0,0 +1,540 @@
|
||||
name: FlowDock
|
||||
channels:
|
||||
- pytorch
|
||||
- pyg
|
||||
- senyan.dev
|
||||
- nvidia
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=3_kmp_llvm
|
||||
- aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
|
||||
- aiohttp=3.11.13=py310h89163eb_0
|
||||
- aiosignal=1.3.2=pyhd8ed1ab_0
|
||||
- alsa-lib=1.2.13=hb9d3cd8_0
|
||||
- ambertools=24.8=cuda_None_nompi_py310h834fefc_101
|
||||
- annotated-types=0.7.0=pyhd8ed1ab_1
|
||||
- anyio=4.8.0=pyhd8ed1ab_0
|
||||
- aom=3.9.1=hac33072_0
|
||||
- argon2-cffi=23.1.0=pyhd8ed1ab_1
|
||||
- argon2-cffi-bindings=21.2.0=py310ha75aee5_5
|
||||
- arpack=3.9.1=nompi_hf03ea27_102
|
||||
- arrow=1.3.0=pyhd8ed1ab_1
|
||||
- asttokens=3.0.0=pyhd8ed1ab_1
|
||||
- async-lru=2.0.4=pyhd8ed1ab_1
|
||||
- async-timeout=5.0.1=pyhd8ed1ab_1
|
||||
- attr=2.5.1=h166bdaf_1
|
||||
- attrs=25.3.0=pyh71513ae_0
|
||||
- babel=2.17.0=pyhd8ed1ab_0
|
||||
- beautifulsoup4=4.13.3=pyha770c72_0
|
||||
- blas=2.116=mkl
|
||||
- blas-devel=3.9.0=16_linux64_mkl
|
||||
- bleach=6.2.0=pyh29332c3_4
|
||||
- bleach-with-css=6.2.0=h82add2a_4
|
||||
- blosc=1.21.6=he440d0b_1
|
||||
- brotli=1.1.0=hb9d3cd8_2
|
||||
- brotli-bin=1.1.0=hb9d3cd8_2
|
||||
- brotli-python=1.1.0=py310hf71b8c6_2
|
||||
- bson=0.5.10=pyhd8ed1ab_0
|
||||
- bzip2=1.0.8=h4bc722e_7
|
||||
- c-ares=1.34.4=hb9d3cd8_0
|
||||
- c-blosc2=2.15.2=h3122c55_1
|
||||
- ca-certificates=2025.1.31=hbcca054_0
|
||||
- cached-property=1.5.2=hd8ed1ab_1
|
||||
- cached_property=1.5.2=pyha770c72_1
|
||||
- cachetools=5.5.2=pyhd8ed1ab_0
|
||||
- cairo=1.18.4=h3394656_0
|
||||
- certifi=2025.1.31=pyhd8ed1ab_0
|
||||
- cffi=1.17.1=py310h8deb56e_0
|
||||
- chardet=5.2.0=pyhd8ed1ab_3
|
||||
- charset-normalizer=3.4.1=pyhd8ed1ab_0
|
||||
- colorama=0.4.6=pyhd8ed1ab_1
|
||||
- comm=0.2.2=pyhd8ed1ab_1
|
||||
- contourpy=1.3.1=py310h3788b33_0
|
||||
- cpython=3.10.16=py310hd8ed1ab_1
|
||||
- cuda-cudart=11.8.89=0
|
||||
- cuda-cupti=11.8.87=0
|
||||
- cuda-libraries=11.8.0=0
|
||||
- cuda-nvrtc=11.8.89=0
|
||||
- cuda-nvtx=11.8.86=0
|
||||
- cuda-runtime=11.8.0=0
|
||||
- cuda-version=11.8=h70ddcb2_3
|
||||
- cudatoolkit=11.8.0=h4ba93d1_13
|
||||
- cudatoolkit-dev=11.8.0=h1fa729e_6
|
||||
- cycler=0.12.1=pyhd8ed1ab_1
|
||||
- cyrus-sasl=2.1.27=h54b06d7_7
|
||||
- dav1d=1.2.1=hd590300_0
|
||||
- dbus=1.13.6=h5008d03_3
|
||||
- debugpy=1.8.13=py310hf71b8c6_0
|
||||
- decorator=5.2.1=pyhd8ed1ab_0
|
||||
- defusedxml=0.7.1=pyhd8ed1ab_0
|
||||
- deprecated=1.2.18=pyhd8ed1ab_0
|
||||
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
||||
- expat=2.6.4=h5888daf_0
|
||||
- ffmpeg=7.1.1=gpl_h24e5c1d_701
|
||||
- fftw=3.3.10=nompi_hf1063bd_110
|
||||
- filelock=3.18.0=pyhd8ed1ab_0
|
||||
- flexcache=0.3=pyhd8ed1ab_1
|
||||
- flexparser=0.4=pyhd8ed1ab_1
|
||||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
||||
- font-ttf-inconsolata=3.000=h77eed37_0
|
||||
- font-ttf-source-code-pro=2.038=h77eed37_0
|
||||
- font-ttf-ubuntu=0.83=h77eed37_3
|
||||
- fontconfig=2.15.0=h7e30c49_1
|
||||
- fonts-conda-ecosystem=1=0
|
||||
- fonts-conda-forge=1=0
|
||||
- fonttools=4.56.0=py310h89163eb_0
|
||||
- fqdn=1.5.1=pyhd8ed1ab_1
|
||||
- freetype=2.13.3=h48d6fc4_0
|
||||
- freetype-py=2.3.0=pyhd8ed1ab_0
|
||||
- fribidi=1.0.10=h36c2ea0_0
|
||||
- frozenlist=1.5.0=py310h89163eb_1
|
||||
- fsspec=2025.3.0=pyhd8ed1ab_0
|
||||
- gdk-pixbuf=2.42.12=hb9ae30d_0
|
||||
- gettext=0.23.1=h5888daf_0
|
||||
- gettext-tools=0.23.1=h5888daf_0
|
||||
- gmp=6.3.0=hac33072_2
|
||||
- gmpy2=2.1.5=py310he8512ff_3
|
||||
- graphite2=1.3.13=h59595ed_1003
|
||||
- greenlet=3.1.1=py310hf71b8c6_1
|
||||
- h11=0.14.0=pyhd8ed1ab_1
|
||||
- h2=4.2.0=pyhd8ed1ab_0
|
||||
- harfbuzz=10.4.0=h76408a6_0
|
||||
- hdf4=4.2.15=h2a13503_7
|
||||
- hdf5=1.14.4=nompi_h2d575fe_105
|
||||
- hpack=4.1.0=pyhd8ed1ab_0
|
||||
- httpcore=1.0.7=pyh29332c3_1
|
||||
- httpx=0.28.1=pyhd8ed1ab_0
|
||||
- hyperframe=6.1.0=pyhd8ed1ab_0
|
||||
- icu=75.1=he02047a_0
|
||||
- idna=3.10=pyhd8ed1ab_1
|
||||
- importlib-metadata=8.6.1=pyha770c72_0
|
||||
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
||||
- ipykernel=6.29.5=pyh3099207_0
|
||||
- ipython=8.34.0=pyh907856f_0
|
||||
- isoduration=20.11.0=pyhd8ed1ab_1
|
||||
- jack=1.9.22=h7c63dc7_2
|
||||
- jedi=0.19.2=pyhd8ed1ab_1
|
||||
- jinja2=3.1.6=pyhd8ed1ab_0
|
||||
- joblib=1.4.2=pyhd8ed1ab_1
|
||||
- json5=0.10.0=pyhd8ed1ab_1
|
||||
- jsonpointer=3.0.0=py310hff52083_1
|
||||
- jsonschema=4.23.0=pyhd8ed1ab_1
|
||||
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
|
||||
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
|
||||
- jupyter-lsp=2.2.5=pyhd8ed1ab_1
|
||||
- jupyter_client=8.6.3=pyhd8ed1ab_1
|
||||
- jupyter_core=5.7.2=pyh31011fe_1
|
||||
- jupyter_events=0.12.0=pyh29332c3_0
|
||||
- jupyter_server=2.15.0=pyhd8ed1ab_0
|
||||
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
|
||||
- jupyterlab=4.3.6=pyhd8ed1ab_0
|
||||
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
|
||||
- jupyterlab_server=2.27.3=pyhd8ed1ab_1
|
||||
- jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
|
||||
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
||||
- keyutils=1.6.1=h166bdaf_0
|
||||
- kiwisolver=1.4.7=py310h3788b33_0
|
||||
- krb5=1.21.3=h659f571_0
|
||||
- lame=3.100=h166bdaf_1003
|
||||
- lcms2=2.17=h717163a_0
|
||||
- ld_impl_linux-64=2.43=h712a8e2_4
|
||||
- lerc=4.0.0=h27087fc_0
|
||||
- level-zero=1.21.2=h84d6215_0
|
||||
- libabseil=20250127.0=cxx17_hbbce691_0
|
||||
- libaec=1.1.3=h59595ed_0
|
||||
- libasprintf=0.23.1=h8e693c7_0
|
||||
- libasprintf-devel=0.23.1=h8e693c7_0
|
||||
- libass=0.17.3=hba53ac1_1
|
||||
- libblas=3.9.0=16_linux64_mkl
|
||||
- libboost=1.86.0=h6c02f8c_3
|
||||
- libboost-python=1.86.0=py310ha2bacc8_3
|
||||
- libbrotlicommon=1.1.0=hb9d3cd8_2
|
||||
- libbrotlidec=1.1.0=hb9d3cd8_2
|
||||
- libbrotlienc=1.1.0=hb9d3cd8_2
|
||||
- libcap=2.75=h39aace5_0
|
||||
- libcblas=3.9.0=16_linux64_mkl
|
||||
- libcublas=11.11.3.6=0
|
||||
- libcufft=10.9.0.58=0
|
||||
- libcufile=1.9.1.3=0
|
||||
- libcurand=10.3.5.147=0
|
||||
- libcurl=8.12.1=h332b0f4_0
|
||||
- libcusolver=11.4.1.48=0
|
||||
- libcusparse=11.7.5.86=0
|
||||
- libdb=6.2.32=h9c3ff4c_0
|
||||
- libdeflate=1.23=h4ddbbb0_0
|
||||
- libdrm=2.4.124=hb9d3cd8_0
|
||||
- libedit=3.1.20250104=pl5321h7949ede_0
|
||||
- libegl=1.7.0=ha4b6fd6_2
|
||||
- libev=4.33=hd590300_2
|
||||
- libexpat=2.6.4=h5888daf_0
|
||||
- libffi=3.4.6=h2dba641_0
|
||||
- libflac=1.4.3=h59595ed_0
|
||||
- libgcc=14.2.0=h767d61c_2
|
||||
- libgcc-ng=14.2.0=h69a702a_2
|
||||
- libgcrypt-lib=1.11.0=hb9d3cd8_2
|
||||
- libgettextpo=0.23.1=h5888daf_0
|
||||
- libgettextpo-devel=0.23.1=h5888daf_0
|
||||
- libgfortran=14.2.0=h69a702a_2
|
||||
- libgfortran-ng=14.2.0=h69a702a_2
|
||||
- libgfortran5=14.2.0=hf1ad2bd_2
|
||||
- libgl=1.7.0=ha4b6fd6_2
|
||||
- libglib=2.82.2=h2ff4ddf_1
|
||||
- libglvnd=1.7.0=ha4b6fd6_2
|
||||
- libglx=1.7.0=ha4b6fd6_2
|
||||
- libgomp=14.2.0=h767d61c_2
|
||||
- libgpg-error=1.51=hbd13f7d_1
|
||||
- libhwloc=2.11.2=default_h0d58e46_1001
|
||||
- libiconv=1.18=h4ce23a2_1
|
||||
- libjpeg-turbo=3.0.0=hd590300_1
|
||||
- liblapack=3.9.0=16_linux64_mkl
|
||||
- liblapacke=3.9.0=16_linux64_mkl
|
||||
- liblzma=5.6.4=hb9d3cd8_0
|
||||
- libnetcdf=4.9.2=nompi_h5ddbaa4_116
|
||||
- libnghttp2=1.64.0=h161d5f1_0
|
||||
- libnpp=11.8.0.86=0
|
||||
- libnsl=2.0.1=hd590300_0
|
||||
- libntlm=1.8=hb9d3cd8_0
|
||||
- libnvjpeg=11.9.0.86=0
|
||||
- libogg=1.3.5=h4ab18f5_0
|
||||
- libopenvino=2025.0.0=hdc3f47d_3
|
||||
- libopenvino-auto-batch-plugin=2025.0.0=h4d9b6c2_3
|
||||
- libopenvino-auto-plugin=2025.0.0=h4d9b6c2_3
|
||||
- libopenvino-hetero-plugin=2025.0.0=h981d57b_3
|
||||
- libopenvino-intel-cpu-plugin=2025.0.0=hdc3f47d_3
|
||||
- libopenvino-intel-gpu-plugin=2025.0.0=hdc3f47d_3
|
||||
- libopenvino-intel-npu-plugin=2025.0.0=hdc3f47d_3
|
||||
- libopenvino-ir-frontend=2025.0.0=h981d57b_3
|
||||
- libopenvino-onnx-frontend=2025.0.0=h0e684df_3
|
||||
- libopenvino-paddle-frontend=2025.0.0=h0e684df_3
|
||||
- libopenvino-pytorch-frontend=2025.0.0=h5888daf_3
|
||||
- libopenvino-tensorflow-frontend=2025.0.0=h684f15b_3
|
||||
- libopenvino-tensorflow-lite-frontend=2025.0.0=h5888daf_3
|
||||
- libopus=1.3.1=h7f98852_1
|
||||
- libpciaccess=0.18=hd590300_0
|
||||
- libpng=1.6.47=h943b412_0
|
||||
- libpq=17.4=h27ae623_0
|
||||
- libprotobuf=5.29.3=h501fc15_0
|
||||
- librdkit=2024.09.6=h84b0b3c_0
|
||||
- librsvg=2.58.4=h49af25d_2
|
||||
- libsndfile=1.2.2=hc60ed4a_1
|
||||
- libsodium=1.0.20=h4ab18f5_0
|
||||
- libsqlite=3.49.1=hee588c1_1
|
||||
- libssh2=1.11.1=hf672d98_0
|
||||
- libstdcxx=14.2.0=h8f9b012_2
|
||||
- libstdcxx-ng=14.2.0=h4852527_2
|
||||
- libsystemd0=257.4=h4e0b6ca_1
|
||||
- libtiff=4.7.0=hd9ff511_3
|
||||
- libudev1=257.4=hbe16f8c_1
|
||||
- libunwind=1.6.2=h9c3ff4c_0
|
||||
- liburing=2.9=h84d6215_0
|
||||
- libusb=1.0.27=hb9d3cd8_101
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libva=2.22.0=h4f16b4b_2
|
||||
- libvorbis=1.3.7=h9c3ff4c_0
|
||||
- libvpx=1.14.1=hac33072_0
|
||||
- libwebp-base=1.5.0=h851e524_0
|
||||
- libxcb=1.17.0=h8a09558_0
|
||||
- libxcrypt=4.4.36=hd590300_1
|
||||
- libxkbcommon=1.8.1=hc4a0caf_0
|
||||
- libxml2=2.13.6=h8d12d68_0
|
||||
- libxslt=1.1.39=h76b75d6_0
|
||||
- libzip=1.11.2=h6991a6a_0
|
||||
- libzlib=1.3.1=hb9d3cd8_2
|
||||
- llvm-openmp=15.0.7=h0cdce71_0
|
||||
- lxml=5.3.1=py310h6ee67d5_0
|
||||
- lz4-c=1.10.0=h5888daf_1
|
||||
- markupsafe=3.0.2=py310h89163eb_1
|
||||
- matplotlib-base=3.10.1=py310h68603db_0
|
||||
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
||||
- mda-xdrlib=0.2.0=pyhd8ed1ab_1
|
||||
- mdtraj=1.10.3=py310h4cdbd58_0
|
||||
- mendeleev=0.20.1=pymin39_ha308f57_3
|
||||
- mistune=3.1.2=pyhd8ed1ab_0
|
||||
- mkl=2022.1.0=h84fe81f_915
|
||||
- mkl-devel=2022.1.0=ha770c72_916
|
||||
- mkl-include=2022.1.0=h84fe81f_915
|
||||
- mpc=1.3.1=h24ddda3_1
|
||||
- mpfr=4.2.1=h90cbb55_3
|
||||
- mpg123=1.32.9=hc50e24c_0
|
||||
- mpmath=1.3.0=pyhd8ed1ab_1
|
||||
- multidict=6.1.0=py310h89163eb_2
|
||||
- munkres=1.1.4=pyh9f0ad1d_0
|
||||
- nbclient=0.10.2=pyhd8ed1ab_0
|
||||
- nbconvert-core=7.16.6=pyh29332c3_0
|
||||
- nbformat=5.10.4=pyhd8ed1ab_1
|
||||
- ncurses=6.5=h2d0b736_3
|
||||
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
||||
- netcdf-fortran=4.6.1=nompi_ha5d1325_108
|
||||
- networkx=3.4.2=pyh267e887_2
|
||||
- notebook=7.3.3=pyhd8ed1ab_0
|
||||
- notebook-shim=0.2.4=pyhd8ed1ab_1
|
||||
- numexpr=2.7.3=py310hb5077e9_1
|
||||
- ocl-icd=2.3.2=hb9d3cd8_2
|
||||
- ocl-icd-system=1.0.0=1
|
||||
- opencl-headers=2024.10.24=h5888daf_0
|
||||
- openff-amber-ff-ports=0.0.4=pyhca7485f_0
|
||||
- openff-forcefields=2024.09.0=pyhff2d567_0
|
||||
- openff-interchange=0.4.2=pyhd8ed1ab_2
|
||||
- openff-interchange-base=0.4.2=pyhd8ed1ab_2
|
||||
- openff-toolkit=0.16.8=pyhd8ed1ab_2
|
||||
- openff-toolkit-base=0.16.8=pyhd8ed1ab_2
|
||||
- openff-units=0.3.0=pyhd8ed1ab_1
|
||||
- openff-utilities=0.1.15=pyhd8ed1ab_0
|
||||
- openh264=2.6.0=hc22cd8d_0
|
||||
- openjpeg=2.5.3=h5fbd93e_0
|
||||
- openldap=2.6.9=he970967_0
|
||||
- openmm=8.2.0=py310h30bdd6a_2
|
||||
- openmmforcefields=0.14.2=pyhd8ed1ab_0
|
||||
- openssl=3.4.1=h7b32b05_0
|
||||
- overrides=7.7.0=pyhd8ed1ab_1
|
||||
- packaging=24.2=pyhd8ed1ab_2
|
||||
- panedr=0.8.0=pyhd8ed1ab_1
|
||||
- pango=1.56.2=h861ebed_0
|
||||
- parmed=4.3.0=py310h78e4988_1
|
||||
- parso=0.8.4=pyhd8ed1ab_1
|
||||
- pcre2=10.44=hba22ea6_2
|
||||
- pdbfixer=1.11=pyhd8ed1ab_0
|
||||
- perl=5.32.1=7_hd590300_perl5
|
||||
- pexpect=4.9.0=pyhd8ed1ab_1
|
||||
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
||||
- pillow=11.1.0=py310h7e6dc6c_0
|
||||
- pint=0.24.4=pyhd8ed1ab_1
|
||||
- pip=25.0.1=pyh8b19718_0
|
||||
- pixman=0.44.2=h29eaf8c_0
|
||||
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
|
||||
- platformdirs=4.3.6=pyhd8ed1ab_1
|
||||
- prometheus_client=0.21.1=pyhd8ed1ab_0
|
||||
- prompt-toolkit=3.0.50=pyha770c72_0
|
||||
- psutil=7.0.0=py310ha75aee5_0
|
||||
- pthread-stubs=0.4=hb9d3cd8_1002
|
||||
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
||||
- pugixml=1.15=h3f63f65_0
|
||||
- pulseaudio-client=17.0=hb77b528_0
|
||||
- pure_eval=0.2.3=pyhd8ed1ab_1
|
||||
- py-cpuinfo=9.0.0=pyhd8ed1ab_1
|
||||
- pycairo=1.27.0=py310h25ff670_0
|
||||
- pycparser=2.22=pyh29332c3_1
|
||||
- pydantic=2.10.6=pyh3cfb1c2_0
|
||||
- pydantic-core=2.27.2=py310h505e2c1_0
|
||||
- pyedr=0.8.0=pyhd8ed1ab_1
|
||||
- pyfiglet=0.8.post1=py_0
|
||||
- pyg=2.5.2=py310_torch_2.2.0_cu118
|
||||
- pygments=2.19.1=pyhd8ed1ab_0
|
||||
- pyparsing=3.2.1=pyhd8ed1ab_0
|
||||
- pysocks=1.7.1=pyha55dd90_7
|
||||
- pytables=3.10.1=py310h431dcdc_4
|
||||
- python=3.10.16=he725a3c_1_cpython
|
||||
- python-constraint=1.4.0=pyhff2d567_1
|
||||
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
||||
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
|
||||
- python-tzdata=2025.1=pyhd8ed1ab_0
|
||||
- python_abi=3.10=5_cp310
|
||||
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
|
||||
- pytorch-cuda=11.8=h7e8668a_6
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
|
||||
- pytz=2025.1=pyhd8ed1ab_0
|
||||
- pyyaml=6.0.2=py310h89163eb_2
|
||||
- pyzmq=26.3.0=py310h71f11fc_0
|
||||
- qhull=2020.2=h434a139_5
|
||||
- rdkit=2024.09.6=py310hcd13295_0
|
||||
- readline=8.2=h8c095d6_2
|
||||
- referencing=0.36.2=pyh29332c3_0
|
||||
- reportlab=4.3.1=py310ha75aee5_0
|
||||
- requests=2.32.3=pyhd8ed1ab_1
|
||||
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
|
||||
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
||||
- rlpycairo=0.2.0=pyhd8ed1ab_0
|
||||
- rpds-py=0.23.1=py310hc1293b2_0
|
||||
- scikit-learn=1.6.1=py310h27f47ee_0
|
||||
- scipy=1.15.2=py310h1d65ade_0
|
||||
- sdl2=2.32.50=h9b8e6db_1
|
||||
- sdl3=3.2.8=h3083f51_0
|
||||
- send2trash=1.8.3=pyh0d859eb_1
|
||||
- setuptools=75.8.2=pyhff2d567_0
|
||||
- six=1.17.0=pyhd8ed1ab_0
|
||||
- smirnoff99frosst=1.1.0=pyh44b312d_0
|
||||
- snappy=1.2.1=h8bd8927_1
|
||||
- sniffio=1.3.1=pyhd8ed1ab_1
|
||||
- sqlalchemy=2.0.39=py310ha75aee5_1
|
||||
- stack_data=0.6.3=pyhd8ed1ab_1
|
||||
- svt-av1=3.0.1=h5888daf_0
|
||||
- sympy=1.13.3=pyh2585a3b_105
|
||||
- sysroot_linux-64=2.17=h0157908_18
|
||||
- tbb=2021.13.0=hceb3a55_1
|
||||
- terminado=0.18.1=pyh0d859eb_0
|
||||
- threadpoolctl=3.6.0=pyhecae5ae_0
|
||||
- tinycss2=1.4.0=pyhd8ed1ab_0
|
||||
- tinydb=4.8.2=pyhd8ed1ab_1
|
||||
- tk=8.6.13=noxft_h4845f30_101
|
||||
- tomli=2.2.1=pyhd8ed1ab_1
|
||||
- torchaudio=2.2.1=py310_cu118
|
||||
- torchtriton=2.2.0=py310
|
||||
- torchvision=0.17.1=py310_cu118
|
||||
- tornado=6.4.2=py310ha75aee5_0
|
||||
- tqdm=4.67.1=pyhd8ed1ab_1
|
||||
- traitlets=5.14.3=pyhd8ed1ab_1
|
||||
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
|
||||
- typing-extensions=4.12.2=hd8ed1ab_1
|
||||
- typing_extensions=4.12.2=pyha770c72_1
|
||||
- typing_utils=0.1.0=pyhd8ed1ab_1
|
||||
- tzdata=2025a=h78e105d_0
|
||||
- unicodedata2=16.0.0=py310ha75aee5_0
|
||||
- uri-template=1.3.0=pyhd8ed1ab_1
|
||||
- urllib3=2.3.0=pyhd8ed1ab_0
|
||||
- validators=0.34.0=pyhd8ed1ab_1
|
||||
- wayland=1.23.1=h3e06ad9_0
|
||||
- wayland-protocols=1.41=hd8ed1ab_0
|
||||
- wcwidth=0.2.13=pyhd8ed1ab_1
|
||||
- webcolors=24.11.1=pyhd8ed1ab_0
|
||||
- webencodings=0.5.1=pyhd8ed1ab_3
|
||||
- websocket-client=1.8.0=pyhd8ed1ab_1
|
||||
- wheel=0.45.1=pyhd8ed1ab_1
|
||||
- wrapt=1.17.2=py310ha75aee5_0
|
||||
- x264=1!164.3095=h166bdaf_2
|
||||
- x265=3.5=h924138e_3
|
||||
- xkeyboard-config=2.43=hb9d3cd8_0
|
||||
- xmltodict=0.14.2=pyhd8ed1ab_1
|
||||
- xorg-libice=1.1.2=hb9d3cd8_0
|
||||
- xorg-libsm=1.2.6=he73a12e_0
|
||||
- xorg-libx11=1.8.12=h4f16b4b_0
|
||||
- xorg-libxau=1.0.12=hb9d3cd8_0
|
||||
- xorg-libxcursor=1.2.3=hb9d3cd8_0
|
||||
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
|
||||
- xorg-libxext=1.3.6=hb9d3cd8_0
|
||||
- xorg-libxfixes=6.0.1=hb9d3cd8_0
|
||||
- xorg-libxrender=0.9.12=hb9d3cd8_0
|
||||
- xorg-libxscrnsaver=1.2.4=hb9d3cd8_0
|
||||
- xorg-libxt=1.3.1=hb9d3cd8_0
|
||||
- yaml=0.2.5=h7f98852_2
|
||||
- yarl=1.18.3=py310h89163eb_1
|
||||
- zeromq=4.3.5=h3b0a872_7
|
||||
- zipp=3.21.0=pyhd8ed1ab_1
|
||||
- zlib=1.3.1=hb9d3cd8_2
|
||||
- zlib-ng=2.2.4=h7955e40_0
|
||||
- zstandard=0.23.0=py310ha75aee5_1
|
||||
- zstd=1.5.7=hb8e6e7a_1
|
||||
- pip:
|
||||
- absl-py==2.1.0
|
||||
- alembic==1.15.1
|
||||
- amberutils==21.0
|
||||
- antlr4-python3-runtime==4.9.3
|
||||
- autopage==0.5.2
|
||||
- beartype==0.20.0
|
||||
- biopandas==0.5.1
|
||||
- biopython==1.79
|
||||
- biotite==1.1.0
|
||||
- biotraj==1.2.2
|
||||
- cfgv==3.4.0
|
||||
- cftime==1.6.4.post1
|
||||
- click==8.1.8
|
||||
- cliff==4.9.1
|
||||
- cloudpathlib==0.21.0
|
||||
- cmaes==0.11.1
|
||||
- cmd2==2.5.11
|
||||
- colorlog==6.9.0
|
||||
- distlib==0.3.9
|
||||
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
|
||||
- dm-tree==0.1.9
|
||||
- docker-pycreds==0.4.0
|
||||
- duckdb==1.2.1
|
||||
- edgembar==3.0
|
||||
- einops==0.8.1
|
||||
- eval-type-backport==0.2.2
|
||||
- executing==2.2.0
|
||||
- fair-esm==2.0.0
|
||||
- fairscale==0.4.13
|
||||
- fastcore==1.7.29
|
||||
- future==1.0.0
|
||||
- fvcore==0.1.5.post20221221
|
||||
- gcsfs==2025.3.0
|
||||
- gemmi==0.7.0
|
||||
- gitdb==4.0.12
|
||||
- gitpython==3.1.44
|
||||
- google-api-core==2.24.2
|
||||
- google-auth==2.38.0
|
||||
- google-auth-oauthlib==1.2.1
|
||||
- google-cloud-core==2.4.3
|
||||
- google-cloud-storage==3.1.0
|
||||
- google-crc32c==1.6.0
|
||||
- google-resumable-media==2.7.2
|
||||
- googleapis-common-protos==1.69.1
|
||||
- hydra-colorlog==1.2.0
|
||||
- hydra-core==1.3.2
|
||||
- hydra-optuna-sweeper==1.2.0
|
||||
- identify==2.6.9
|
||||
- iniconfig==2.0.0
|
||||
- iopath==0.1.10
|
||||
- ipython-genutils==0.2.0
|
||||
- ipywidgets==7.8.5
|
||||
- jupyterlab-widgets==1.1.11
|
||||
- lightning==2.5.0.post0
|
||||
- lightning-utilities==0.14.1
|
||||
- looseversion==1.1.2
|
||||
- lovely-numpy==0.2.13
|
||||
- lovely-tensors==0.1.18
|
||||
- mako==1.3.9
|
||||
- markdown-it-py==3.0.0
|
||||
- mdurl==0.1.2
|
||||
- ml-collections==1.0.0
|
||||
- mmcif==0.91.0
|
||||
- mmpbsa-py==16.0
|
||||
- mmtf-python==1.1.3
|
||||
- mols2grid==2.0.0
|
||||
- msgpack==1.1.0
|
||||
- msgpack-numpy==0.4.8
|
||||
- narwhals==1.30.0
|
||||
- netcdf4==1.7.2
|
||||
- nodeenv==1.9.1
|
||||
- numpy==1.26.4
|
||||
- oauthlib==3.2.2
|
||||
- omegaconf==2.3.0
|
||||
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
|
||||
- optuna==2.10.1
|
||||
- packmol-memgen==2025.1.29
|
||||
- pandas==2.2.3
|
||||
- pandocfilters==1.5.1
|
||||
- pbr==6.1.1
|
||||
- pdb4amber==22.0
|
||||
- plinder==0.2.24
|
||||
- plotly==6.0.0
|
||||
- pluggy==1.5.0
|
||||
- portalocker==3.1.1
|
||||
- posebusters==0.2.13
|
||||
- git+https://git@github.com/zrqiao/power_spherical.git@290b1630c5f84e3bb0d61711046edcf6e47200d4
|
||||
- pre-commit==4.1.0
|
||||
- prettytable==3.15.1
|
||||
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
|
||||
- propcache==0.3.0
|
||||
- proto-plus==1.26.1
|
||||
- protobuf==5.29.3
|
||||
- pyarrow==19.0.1
|
||||
- pyasn1==0.6.1
|
||||
- pyasn1-modules==0.4.1
|
||||
- pymsmt==22.0
|
||||
- pyperclip==1.9.0
|
||||
- pytest==8.3.5
|
||||
- python-dotenv==1.0.1
|
||||
- python-json-logger==3.3.0
|
||||
- pytorch-lightning==2.5.0.post0
|
||||
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
|
||||
- pytraj==2.0.6
|
||||
- requests-oauthlib==2.0.0
|
||||
- rich==13.9.4
|
||||
- rootutils==1.0.7
|
||||
- rsa==4.9
|
||||
- sander==22.0
|
||||
- seaborn==0.13.2
|
||||
- sentry-sdk==2.22.0
|
||||
- setproctitle==1.3.5
|
||||
- smmap==5.0.2
|
||||
- soupsieve==2.6
|
||||
- stevedore==5.4.1
|
||||
- tabulate==0.9.0
|
||||
- termcolor==2.5.0
|
||||
- torchmetrics==1.6.3
|
||||
- virtualenv==20.29.3
|
||||
- wandb==0.19.8
|
||||
- widgetsnbextension==3.6.10
|
||||
- yacs==0.1.8
|
||||
58
environments/flowdock_environment_docker.yaml
Normal file
58
environments/flowdock_environment_docker.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
name: flowdock
|
||||
channels:
|
||||
- pyg
|
||||
- pytorch
|
||||
- nvidia
|
||||
- defaults
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- mendeleev=0.20.1=pymin39_ha308f57_3
|
||||
- networkx=3.4.2=pyh267e887_2
|
||||
- python=3.10.16=he725a3c_1_cpython
|
||||
- pytorch=2.2.1=py3.10_cuda11.8_cudnn8.7.0_0
|
||||
- pytorch-cuda=11.8=h7e8668a_6
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pytorch-scatter=2.1.2=py310_torch_2.2.0_cu118
|
||||
- rdkit=2024.09.6=py310hcd13295_0
|
||||
- scikit-learn=1.6.1=py310h27f47ee_0
|
||||
- scipy=1.15.2=py310h1d65ade_0
|
||||
- torchaudio=2.2.1=py310_cu118
|
||||
- torchtriton=2.2.0=py310
|
||||
- torchvision=0.17.1=py310_cu118
|
||||
- tqdm=4.67.1=pyhd8ed1ab_1
|
||||
- pip:
|
||||
- beartype==0.20.0
|
||||
- biopandas==0.5.1
|
||||
- biopython==1.79
|
||||
- biotite==1.1.0
|
||||
- git+https://github.com/NVIDIA/dllogger.git@0540a43971f4a8a16693a9de9de73c1072020769
|
||||
- dm-tree==0.1.9
|
||||
- einops==0.8.1
|
||||
- fair-esm==2.0.0
|
||||
- fairscale==0.4.13
|
||||
- gemmi==0.7.0
|
||||
- hydra-colorlog==1.2.0
|
||||
- hydra-core==1.3.2
|
||||
- hydra-optuna-sweeper==1.2.0
|
||||
- lightning==2.5.0.post0
|
||||
- lightning-utilities==0.14.1
|
||||
- lovely-numpy==0.2.13
|
||||
- lovely-tensors==0.1.18
|
||||
- ml-collections==1.0.0
|
||||
- msgpack==1.1.0
|
||||
- msgpack-numpy==0.4.8
|
||||
- numpy==1.26.4
|
||||
- omegaconf==2.3.0
|
||||
- git+https://github.com/amorehead/openfold.git@fe1275099639bf7e617e09ef24d6af778647dd64
|
||||
- pandas==2.2.3
|
||||
- plinder==0.2.24
|
||||
- plotly==6.0.0
|
||||
- posebusters==0.2.13
|
||||
# - prody==2.4.1 # NOTE: we must `pip` install Prody to skip its NumPy dependency
|
||||
- pytorch-lightning==2.5.0.post0
|
||||
- git+https://github.com/facebookresearch/pytorch3d.git@3da7703c5ac10039645966deddffe8db52eab8c5
|
||||
- rich==13.9.4
|
||||
- rootutils==1.0.7
|
||||
- seaborn==0.13.2
|
||||
- torchmetrics==1.6.3
|
||||
- wandb==0.19.8
|
||||
120
flowdock/__init__.py
Normal file
120
flowdock/__init__.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
from beartype.typing import Any
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Maps lowercase method identifiers to their stylized, human-readable titles.
# Used below to build per-method fork directory paths and display names.
METHOD_TITLE_MAPPING = {
    "diffdock": "DiffDock",
    "flowdock": "FlowDock",
    "neuralplexer": "NeuralPLexer",
}

# Methods whose fork directories follow the standardized layout.
# NOTE(review): presumably "standardized" refers to a shared directory convention
# among baseline forks — confirm against the forks/ tree.
STANDARDIZED_DIR_METHODS = ["diffdock"]
|
||||
|
||||
|
||||
def resolve_omegaconf_variable(variable_path: str) -> Any:
    """Resolve a dotted OmegaConf variable path to the Python object it names.

    Everything before the final dot is first treated as an importable module
    path (e.g. ``"torch.float16"``). If that import fails, the last component
    of the module path is instead treated as an attribute of its parent module
    (e.g. ``"math.pi.real"`` resolves as ``getattr(math.pi, "real")``).

    :param variable_path: Dotted path to resolve, e.g. ``"torch.float16"``.
    :return: The attribute the path points to.
    :raises ValueError: If ``variable_path`` contains no dot separator.
    :raises ImportError: If no candidate module can be imported.
    :raises AttributeError: If the final attribute lookup fails.
    """
    # split the final attribute name off the (candidate) module path;
    # unpacking raises a clear ValueError for a dotless path instead of
    # the opaque IndexError the previous parts[1] access produced
    module_name, attribute_name = variable_path.rsplit(".", 1)

    try:
        # common case: the prefix is itself an importable module
        module = importlib.import_module(module_name)
        attribute = getattr(module, attribute_name)
    except (ImportError, AttributeError):
        # fallback: the prefix's last component is an attribute of its parent
        # module rather than a submodule (e.g. "math.pi.real")
        parent_name, _, inner_name = module_name.rpartition(".")
        module = importlib.import_module(parent_name)
        attribute = getattr(getattr(module, inner_name), attribute_name)

    return attribute
|
||||
|
||||
|
||||
def resolve_dataset_path_dirname(dataset: str) -> str:
    """Resolve the on-disk directory name corresponding to a dataset.

    :param dataset: Name of the dataset.
    :return: Directory name for the dataset path.
    """
    # only `dockgen` uses a cased directory name; every other dataset's
    # directory matches its name verbatim
    special_dirnames = {"dockgen": "DockGen"}
    return special_dirnames.get(dataset, dataset)
|
||||
|
||||
|
||||
def resolve_method_input_csv_path(method: str, dataset: str) -> str:
    """Resolve the input CSV path for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :return: The input CSV path for the given method.
    :raises ValueError: If the method is not a recognized one.
    """
    # guard clause: bail out early on unknown methods
    if method not in STANDARDIZED_DIR_METHODS and method not in ["flowdock", "neuralplexer"]:
        raise ValueError(f"Invalid method: {method}")

    # fork directories use the stylized method title when one is defined
    method_dirname = METHOD_TITLE_MAPPING.get(method, method)
    csv_filename = f"{method}_{dataset}_inputs.csv"
    return os.path.join("forks", method_dirname, "inference", csv_filename)
|
||||
|
||||
|
||||
def resolve_method_title(method: str) -> str:
    """Resolve the human-readable title for a given method.

    :param method: The method name.
    :return: The method title for the given method.
    """
    # fall back to the raw method name when no curated title exists
    title = METHOD_TITLE_MAPPING.get(method, method)
    return title
|
||||
|
||||
|
||||
def resolve_method_output_dir(
    method: str,
    dataset: str,
    repeat_index: int,
) -> str:
    """Resolve the output directory for a given method.

    :param method: The method name.
    :param dataset: The dataset name.
    :param repeat_index: The repeat index for the method.
    :return: The output directory for the given method.
    :raises ValueError: If the method is not a recognized one.
    """
    # guard clause: bail out early on unknown methods
    if method not in STANDARDIZED_DIR_METHODS and method not in ["flowdock", "neuralplexer"]:
        raise ValueError(f"Invalid method: {method}")

    # flowdock/neuralplexer pluralize their output directory name ("outputs")
    plural_suffix = "s" if method in ["flowdock", "neuralplexer"] else ""
    output_dirname = f"{method}_{dataset}_output{plural_suffix}_{repeat_index}"
    return os.path.join(
        "forks",
        METHOD_TITLE_MAPPING.get(method, method),
        "inference",
        output_dirname,
    )
|
||||
|
||||
|
||||
def register_custom_omegaconf_resolvers():
    """Register this module's custom OmegaConf resolvers.

    Makes the following interpolations available inside Hydra/OmegaConf configs:
    ``resolve_variable``, ``resolve_dataset_path_dirname``,
    ``resolve_method_input_csv_path``, ``resolve_method_title``,
    ``resolve_method_output_dir``, and ``int_divide``.
    """
    # the module-level helpers can be passed directly — the previous
    # one-argument lambda wrappers added no behavior
    OmegaConf.register_new_resolver("resolve_variable", resolve_omegaconf_variable)
    OmegaConf.register_new_resolver(
        "resolve_dataset_path_dirname", resolve_dataset_path_dirname
    )
    OmegaConf.register_new_resolver(
        "resolve_method_input_csv_path", resolve_method_input_csv_path
    )
    OmegaConf.register_new_resolver("resolve_method_title", resolve_method_title)
    OmegaConf.register_new_resolver("resolve_method_output_dir", resolve_method_output_dir)
    # integer floor division for config arithmetic (inputs may arrive as strings)
    OmegaConf.register_new_resolver(
        "int_divide", lambda dividend, divisor: int(dividend) // int(divisor)
    )
|
||||
165
flowdock/eval.py
Normal file
165
flowdock/eval.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
|
||||
import hydra
|
||||
import lightning as L
|
||||
import lovely_tensors as lt
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, List, Tuple
|
||||
from lightning import LightningDataModule, LightningModule, Trainer
|
||||
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from lightning.pytorch.strategies.strategy import Strategy
|
||||
from omegaconf import DictConfig, open_dict
|
||||
|
||||
# patch tensor `repr`s via lovely-tensors for concise, readable logging output
lt.monkey_patch()

# locate the project root (marked by a ".project-root" file) and add it to PYTHONPATH
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# the setup_root above is equivalent to:
|
||||
# - adding project root dir to PYTHONPATH
|
||||
# (so you don't need to force user to install project as a package)
|
||||
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||
# - setting up PROJECT_ROOT environment variable
|
||||
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||
# (this way all filepaths are the same no matter where you run the code)
|
||||
# - loading environment variables from ".env" in root dir
|
||||
#
|
||||
# you can remove it if you:
|
||||
# 1. either install project as a package or move entry files to project root dir
|
||||
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||
#
|
||||
# more info: https://github.com/ashleve/rootutils
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||
from flowdock.utils import (
|
||||
RankedLogger,
|
||||
extras,
|
||||
instantiate_loggers,
|
||||
log_hyperparameters,
|
||||
task_wrapper,
|
||||
)
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    # a valid checkpoint is mandatory for evaluation — fail fast before any instantiation
    assert cfg.ckpt_path, "Please provide a checkpoint path to evaluate!"
    assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="test")

    # Establish model input arguments
    # (`open_dict` temporarily lifts struct-mode so `start_time` can be rewritten in place;
    # "auto" maps to 1.0, anything else is coerced to float)
    with open_dict(cfg):
        if cfg.model.cfg.task.start_time == "auto":
            cfg.model.cfg.task.start_time = 1.0
        else:
            cfg.model.cfg.task.start_time = float(cfg.model.cfg.task.start_time)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    # optionally instantiate a cluster-environment plugin (e.g. for WES/SLURM execution);
    # `plugins` stays None when the environment config has no `_target_`
    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    # a strategy configured under `cfg.strategy` overrides any trainer-level default
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        # resolve string dtype paths (e.g. "torch.float16") from the config into real
        # dtype objects on the instantiated strategy's mixed-precision settings
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only pass `strategy` when one was resolved, so the trainer keeps its own default otherwise
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
        )
    )

    # bundle everything for hyperparameter logging and for the caller
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # metrics accumulated by the trainer during `.test(...)`
    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
||||
|
||||
|
||||
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # returned metrics/objects are unused here; Hydra manages process exit
    evaluate(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # register resolvers before invoking the Hydra entry point, so config
    # interpolations that reference them can be resolved during composition
    register_custom_omegaconf_resolvers()
    main()
|
||||
0
flowdock/models/__init__.py
Normal file
0
flowdock/models/__init__.py
Normal file
0
flowdock/models/components/__init__.py
Normal file
0
flowdock/models/components/__init__.py
Normal file
452
flowdock/models/components/callbacks/ema.py
Normal file
452
flowdock/models/components/callbacks/ema.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# Adapted from https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
import lightning.pytorch as pl
|
||||
import torch
|
||||
from lightning.pytorch import Callback
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||
|
||||
|
||||
class EMA(Callback):
|
||||
"""Implements Exponential Moving Averaging (EMA).
|
||||
|
||||
When training a model, this callback will maintain moving averages of the trained parameters.
|
||||
When evaluating, we use the moving averages copy of the trained parameters.
|
||||
When saving, we save an additional set of parameters with the prefix `ema`.
|
||||
|
||||
Args:
|
||||
decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
|
||||
validate_original_weights: Validate the original weights, as apposed to the EMA weights.
|
||||
every_n_steps: Apply EMA every N steps.
|
||||
cpu_offload: Offload weights to CPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decay: float,
|
||||
validate_original_weights: bool = False,
|
||||
every_n_steps: int = 1,
|
||||
cpu_offload: bool = False,
|
||||
):
|
||||
if not (0 <= decay <= 1):
|
||||
raise MisconfigurationException("EMA decay value must be between 0 and 1")
|
||||
self.decay = decay
|
||||
self.validate_original_weights = validate_original_weights
|
||||
self.every_n_steps = every_n_steps
|
||||
self.cpu_offload = cpu_offload
|
||||
|
||||
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Add the EMA optimizer to the trainer."""
|
||||
device = pl_module.device if not self.cpu_offload else torch.device("cpu")
|
||||
trainer.optimizers = [
|
||||
EMAOptimizer(
|
||||
optim,
|
||||
device=device,
|
||||
decay=self.decay,
|
||||
every_n_steps=self.every_n_steps,
|
||||
current_step=trainer.global_step,
|
||||
)
|
||||
for optim in trainer.optimizers
|
||||
if not isinstance(optim, EMAOptimizer)
|
||||
]
|
||||
|
||||
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Swap the model weights with the EMA weights."""
|
||||
if self._should_validate_ema_weights(trainer):
|
||||
self.swap_model_weights(trainer)
|
||||
|
||||
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Swap the model weights back to the original weights."""
|
||||
if self._should_validate_ema_weights(trainer):
|
||||
self.swap_model_weights(trainer)
|
||||
|
||||
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Swap the model weights with the EMA weights."""
|
||||
if self._should_validate_ema_weights(trainer):
|
||||
self.swap_model_weights(trainer)
|
||||
|
||||
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Swap the model weights back to the original weights."""
|
||||
if self._should_validate_ema_weights(trainer):
|
||||
self.swap_model_weights(trainer)
|
||||
|
||||
def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
|
||||
"""Check if the EMA weights should be validated."""
|
||||
return not self.validate_original_weights and self._ema_initialized(trainer)
|
||||
|
||||
def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
|
||||
"""Check if the EMA weights have been initialized."""
|
||||
return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)
|
||||
|
||||
def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
|
||||
"""Swaps the model weights with the EMA weights."""
|
||||
for optimizer in trainer.optimizers:
|
||||
assert isinstance(optimizer, EMAOptimizer)
|
||||
optimizer.switch_main_parameter_weights(saving_ema_model)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_ema_model(self, trainer: "pl.Trainer"):
|
||||
"""Saves an EMA copy of the model + EMA optimizer states for resume."""
|
||||
self.swap_model_weights(trainer, saving_ema_model=True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.swap_model_weights(trainer, saving_ema_model=False)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_original_optimizer_state(self, trainer: "pl.Trainer"):
|
||||
"""Save the original optimizer state."""
|
||||
for optimizer in trainer.optimizers:
|
||||
assert isinstance(optimizer, EMAOptimizer)
|
||||
optimizer.save_original_optimizer_state = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for optimizer in trainer.optimizers:
|
||||
optimizer.save_original_optimizer_state = False
|
||||
|
||||
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Load the EMA state from the checkpoint if it exists.

        When resuming from a regular checkpoint that has an ``-EMA`` sibling file,
        the optimizer states from that sibling (which carry the EMA copies) are
        spliced into ``checkpoint`` before restoration. When resuming directly
        from an ``-EMA`` checkpoint, the loaded EMA weights are treated as the
        main weights and a fresh EMA copy is started.

        Raises:
            MisconfigurationException: if the ``-EMA`` sibling checkpoint is missing.
        """
        checkpoint_callback = trainer.checkpoint_callback

        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
        # NOTE(review): `connector` is unused below; presumably kept for parity
        # with the upstream NeMo implementation — confirm before removing.
        connector = trainer._checkpoint_connector
        # Replace connector._ckpt_path with below to avoid calling into lightning's protected API
        ckpt_path = trainer.ckpt_path

        # Only act when resuming from a path produced by an EMA-aware checkpoint
        # callback (duck-typed via the class name containing "EMA").
        if (
            ckpt_path
            and checkpoint_callback is not None
            and "EMA" in type(checkpoint_callback).__name__
        ):
            ext = checkpoint_callback.FILE_EXTENSION
            if ckpt_path.endswith(f"-EMA{ext}"):
                # Resuming directly from the EMA checkpoint: nothing to splice in.
                rank_zero_info(
                    "loading EMA based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            # NOTE(review): str.replace swaps every occurrence of the extension,
            # not just the trailing one — fine for typical '.ckpt' paths.
            ema_path = ckpt_path.replace(ext, f"-EMA{ext}")
            if os.path.exists(ema_path):
                # Map to CPU so restoring does not allocate GPU memory just to read states.
                ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu"))

                checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"]
                # Drop the loaded dict eagerly; it can be large.
                del ema_state_dict
                rank_zero_info("EMA state has been restored.")
            else:
                raise MisconfigurationException(
                    "Unable to find the associated EMA weights when re-loading, "
                    f"training will start with new EMA weights. Expected them to be at: {ema_path}",
                )
|
||||
|
||||
|
||||
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
    """In-place EMA update: ``ema = decay * ema + (1 - decay) * current``.

    Both arguments are tuples of tensors; fused multi-tensor ops apply the
    update to all parameters in just two kernel launches.
    """
    one_minus_decay = 1.0 - decay
    torch._foreach_mul_(ema_model_tuple, decay)
    torch._foreach_add_(ema_model_tuple, current_model_tuple, alpha=one_minus_decay)
|
||||
|
||||
|
||||
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
    """Apply an EMA update on host tensors once pending device copies finish.

    Args:
        ema_model_tuple: Tuple of EMA tensors, updated in place.
        current_model_tuple: Tuple of current model tensors.
        decay: EMA decay factor.
        pre_sync_stream: Optional CUDA stream carrying the async device-to-host
            copies of ``current_model_tuple``; synchronized before reading.
    """
    stream = pre_sync_stream
    if stream is not None:
        # Ensure the copied tensors are fully materialized on the host.
        stream.synchronize()
    ema_update(ema_model_tuple, current_model_tuple, decay)
|
||||
|
||||
|
||||
class EMAOptimizer(torch.optim.Optimizer):
    r"""EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average
    of parameters registered in the optimizer.

    EMA parameters are automatically updated after every step of the optimizer
    with the following formula:

        ema_weight = decay * ema_weight + (1 - decay) * training_weight

    To access EMA parameters, use ``swap_ema_weights()`` context manager to
    perform a temporary in-place swap of regular parameters with EMA
    parameters.

    Notes:
        - EMAOptimizer is not compatible with APEX AMP O2.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap
        device (torch.device): device for EMA parameters
        decay (float): decay factor
        every_n_steps (int): apply the EMA update only every N optimizer steps
        current_step (int): initial value of the step counter (useful on resume)

    Returns:
        returns an instance of torch.optim.Optimizer that computes EMA of
        parameters

    Example:
        model = Model().to(device)
        opt = torch.optim.Adam(model.parameters())

        opt = EMAOptimizer(opt, device, 0.9999)

        for epoch in range(epochs):
            training_loop(model, opt)

            regular_eval_accuracy = evaluate(model)

            with opt.swap_ema_weights():
                ema_eval_accuracy = evaluate(model)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        current_step: int = 0,
    ):
        # NOTE: torch.optim.Optimizer.__init__ is deliberately not called;
        # optimizer behavior is delegated to the wrapped optimizer via
        # __getattr__ below.
        self.optimizer = optimizer
        self.decay = decay
        self.device = device
        self.current_step = current_step
        self.every_n_steps = every_n_steps
        # When True, state_dict() returns the wrapped optimizer's state
        # unchanged (used while saving the main, non-EMA checkpoint).
        self.save_original_optimizer_state = False

        self.first_iteration = True
        # Set whenever the parameter set may have changed (e.g. a new param
        # group); step() then appends EMA copies for the new parameters.
        self.rebuild_ema_params = True
        # CUDA side stream used to overlap the EMA update with training work.
        self.stream = None
        # Background thread used for CPU-side EMA updates.
        self.thread = None

        # Tuple of EMA tensors, aligned one-to-one with all_parameters().
        self.ema_params = ()
        # True while the model temporarily holds the EMA weights for saving.
        self.in_saving_ema_model_context = False

    def all_parameters(self) -> Iterable[torch.Tensor]:
        """Return an iterator over all parameters in the optimizer."""
        return (param for group in self.param_groups for param in group["params"])

    def step(self, closure=None, grad_scaler=None, **kwargs):
        """Perform a single optimization step."""
        # Wait for any in-flight EMA update (stream or thread) before the
        # parameters are modified again.
        self.join()

        if self.first_iteration:
            # Lazily create the side stream only when training on CUDA.
            if any(p.is_cuda for p in self.all_parameters()):
                self.stream = torch.cuda.Stream()

            self.first_iteration = False

        if self.rebuild_ema_params:
            opt_params = list(self.all_parameters())

            # Append EMA copies only for parameters added since the last rebuild;
            # existing EMA tensors keep their accumulated averages.
            self.ema_params += tuple(
                copy.deepcopy(param.data.detach()).to(self.device)
                for param in opt_params[len(self.ema_params) :]
            )
            self.rebuild_ema_params = False

        # Pass the grad scaler through only when the wrapped optimizer
        # advertises AMP-scaling support.
        if (
            getattr(self.optimizer, "_step_supports_amp_scaling", False)
            and grad_scaler is not None
        ):
            loss = self.optimizer.step(closure=closure, grad_scaler=grad_scaler)
        else:
            loss = self.optimizer.step(closure)

        if self._should_update_at_step():
            self.update()
        self.current_step += 1
        return loss

    def _should_update_at_step(self) -> bool:
        """Check if the EMA parameters should be updated at the current step."""
        return self.current_step % self.every_n_steps == 0

    @torch.no_grad()
    def update(self):
        """Update the EMA parameters."""
        if self.stream is not None:
            # The EMA stream must observe the parameter values produced by the
            # just-finished optimizer step on the current stream.
            self.stream.wait_stream(torch.cuda.current_stream())

        # torch.cuda.stream(None) is a no-op context, so this is safe on CPU.
        with torch.cuda.stream(self.stream):
            current_model_state = tuple(
                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
            )

            if self.device.type == "cuda":
                ema_update(self.ema_params, current_model_state, self.decay)

        if self.device.type == "cpu":
            # Run the CPU update in a background thread so training is not
            # blocked; join() synchronizes with it before the next step.
            self.thread = threading.Thread(
                target=run_ema_update_cpu,
                args=(
                    self.ema_params,
                    current_model_state,
                    self.decay,
                    self.stream,
                ),
            )
            self.thread.start()

    def swap_tensors(self, tensor1, tensor2):
        """Swaps the tensors in-place."""
        # In-place copies (rather than rebinding) keep optimizer/module
        # references to the original tensor objects valid.
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
        """Switches the main parameter weights with the EMA weights."""
        self.join()
        self.in_saving_ema_model_context = saving_ema_model
        for param, ema_param in zip(self.all_parameters(), self.ema_params):
            self.swap_tensors(param.data, ema_param)

    @contextlib.contextmanager
    def swap_ema_weights(self, enabled: bool = True):
        r"""A context manager to in-place swap regular parameters with EMA parameters. It swaps back
        to the original regular parameters on context manager exit.

        Args:
            enabled (bool): whether the swap should be performed
        """

        if enabled:
            self.switch_main_parameter_weights()
        try:
            yield
        finally:
            if enabled:
                self.switch_main_parameter_weights()

    def __getattr__(self, name):
        """Forward all other attribute calls to the optimizer."""
        # Only reached for names not found on this wrapper itself.
        return getattr(self.optimizer, name)

    def join(self):
        """Wait for the update to complete."""
        if self.stream is not None:
            self.stream.synchronize()

        if self.thread is not None:
            self.thread.join()

    def state_dict(self):
        """Return the state dict for the optimizer."""
        # Make sure no asynchronous EMA update is still mutating ema_params.
        self.join()

        if self.save_original_optimizer_state:
            return self.optimizer.state_dict()

        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
        ema_params = (
            self.ema_params
            if not self.in_saving_ema_model_context
            else list(self.all_parameters())
        )
        state_dict = {
            "opt": self.optimizer.state_dict(),
            "ema": ema_params,
            "current_step": self.current_step,
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
        }
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the state dict for the optimizer."""
        self.join()

        self.optimizer.load_state_dict(state_dict["opt"])
        # Deep-copy so the restored EMA tensors do not alias the checkpoint dict.
        self.ema_params = tuple(
            param.to(self.device) for param in copy.deepcopy(state_dict["ema"])
        )
        self.current_step = state_dict["current_step"]
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        # The restored EMA copies already match the parameter set exactly.
        self.rebuild_ema_params = False

    def add_param_group(self, param_group):
        """Add a param group to the optimizer."""
        self.optimizer.add_param_group(param_group)
        # Defer creating EMA copies for the new group until the next step().
        self.rebuild_ema_params = True
|
||||
|
||||
|
||||
class EMAModelCheckpoint(ModelCheckpoint):
    """``ModelCheckpoint`` variant that mirrors every checkpoint with an EMA copy.

    Whenever a regular checkpoint is written, a sibling file carrying the EMA
    weights (suffixed ``-EMA``) is saved alongside it and removed together with it.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``ModelCheckpoint``."""
        super().__init__(**kwargs)

    def _ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
        """Return the trainer's EMA callback, or ``None`` when none is registered.

        If several EMA callbacks are present, the last one wins (matching the
        original scan order).
        """
        found: Optional[EMA] = None
        for candidate in trainer.callbacks:
            if isinstance(candidate, EMA):
                found = candidate
        return found

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Write the checkpoint, plus a ``-EMA`` companion when EMA is enabled."""
        ema_callback = self._ema_callback(trainer)
        if ema_callback is None:
            super()._save_checkpoint(trainer, filepath)
            return

        # Persist the original (non-EMA) optimizer state in the main file.
        with ema_callback.save_original_optimizer_state(trainer):
            super()._save_checkpoint(trainer, filepath)

        # Then write a second checkpoint holding the EMA weights.
        with ema_callback.save_ema_model(trainer):
            filepath = self._ema_format_filepath(filepath)
            if self.verbose:
                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
            super()._save_checkpoint(trainer, filepath)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Delete the checkpoint and, when EMA is enabled, its ``-EMA`` companion."""
        super()._remove_checkpoint(trainer, filepath)
        ema_callback = self._ema_callback(trainer)
        if ema_callback is None:
            return
        # Remove the EMA copy of the state dict as well.
        super()._remove_checkpoint(trainer, self._ema_format_filepath(filepath))

    def _ema_format_filepath(self, filepath: str) -> str:
        """Insert ``-EMA`` in front of the checkpoint file extension."""
        return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}")

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        """Report whether at least one of *checkpoints* is an EMA checkpoint."""
        return any(map(self._is_ema_filepath, checkpoints))

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        """Report whether *filepath* names an EMA checkpoint."""
        return str(filepath).endswith(f"-EMA{self.FILE_EXTENSION}")

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """Yield every ``*.ckpt`` file below the checkpoint directory."""
        return Path(self.dirpath).rglob("*.ckpt")
|
||||
857
flowdock/models/components/cpm.py
Normal file
857
flowdock/models/components/cpm.py
Normal file
@@ -0,0 +1,857 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import random
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, Optional, Tuple
|
||||
from omegaconf import DictConfig
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.models.components.embedding import (
|
||||
GaussianFourierEncoding1D,
|
||||
RelativeGeometryEncoding,
|
||||
)
|
||||
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||
from flowdock.models.components.modules import (
|
||||
BiDirectionalTriangleAttention,
|
||||
TransformerLayer,
|
||||
)
|
||||
from flowdock.utils import RankedLogger
|
||||
from flowdock.utils.frame_utils import cartesian_to_internal, get_frame_matrix
|
||||
from flowdock.utils.model_utils import GELUMLP
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
STATE_DICT = Dict[str, Any]
|
||||
|
||||
|
||||
class ProtFormer(torch.nn.Module):
    """Protein relational reasoning with downsampled edges.

    Interleaves graph-transformer blocks over dense residue-residue edges ("ab")
    with triangle-attention blocks over a coarser grid of sampled patch-patch
    edges ("AB"), exchanging information between the two edge sets via
    cross-attention in every block.
    """

    def __init__(
        self,
        dim: int,
        pair_dim: int,
        n_blocks: int = 4,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize the ProtFormer model.

        Args:
            dim: Node (residue) feature width.
            pair_dim: Edge (pair) feature width.
            n_blocks: Number of interleaved graph/triangle blocks.
            n_heads: Attention head count.
            dropout: Dropout rate used by the MLPs and transformer layers.
        """
        super().__init__()
        self.dim = dim
        self.pair_dim = pair_dim
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        # 16 Fourier pairs -> 32-dim timestep encoding, consumed by res_in_mlp
        # (hence the ``dim + 32`` input width below).
        self.time_encoding = GaussianFourierEncoding1D(16)
        self.res_in_mlp = GELUMLP(dim + 32, dim, dropout=dropout)
        self.chain_pos_encoding = GaussianFourierEncoding1D(self.pair_dim // 4)

        self.rel_geom_enc = RelativeGeometryEncoding(15, self.pair_dim)
        # Template node encoder input: 37 atoms x 3 local coords + 64-dim
        # binding-site encoding.
        self.template_nenc = GELUMLP(64 + 37 * 3, self.dim, n_hidden_feats=128)
        self.template_eenc = RelativeGeometryEncoding(15, self.pair_dim)
        self.template_binding_site_enc = torch.nn.Linear(1, 64, bias=False)
        # Pair embed input: geometry PE + chain PE (row/col halves) + both
        # endpoint residue representations.
        self.pp_edge_embed = GELUMLP(
            pair_dim + self.pair_dim // 4 * 2 + dim * 2,
            self.pair_dim,
            n_hidden_feats=dim,
            dropout=dropout,
        )

        self.graph_stacks = torch.nn.ModuleList(
            [
                TransformerLayer(
                    dim,
                    n_heads,
                    head_dim=pair_dim // n_heads,
                    edge_channels=pair_dim,
                    edge_update=True,
                    dropout=dropout,
                )
                for _ in range(self.n_blocks)
            ]
        )
        # Cross-attention between dense residue-pair edges (ab) and sampled
        # grid edges (AB); shared across blocks.
        self.ABab_mha = TransformerLayer(
            pair_dim,
            n_heads,
            bidirectional=True,
        )

        self.triangle_stacks = torch.nn.ModuleList(
            [
                BiDirectionalTriangleAttention(pair_dim, pair_dim // n_heads, n_heads)
                for _ in range(self.n_blocks)
            ]
        )
        # (relation name, row index key, column index key, src node type, dst node type)
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
        ]

    def compute_chain_pe(
        self,
        residue_index,
        res_chain_index,
        src_rid,
        dst_rid,
    ):
        """Compute chain positional encoding for a pair of residues."""
        # Signed sequence separation between the two residues of each edge.
        chain_disp = residue_index[src_rid] - residue_index[dst_rid]
        # Fourier-encode the (scaled) displacement, damped by 1/(|disp|/8 + 1)
        # so distant pairs contribute progressively weaker encodings.
        chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
            chain_disp.div(8).abs().add(1).unsqueeze(-1)
        )
        # Mask cross-chain entries
        chain_mask = res_chain_index[src_rid] == res_chain_index[dst_rid]
        chain_rope = chain_rope * chain_mask[..., None]
        return chain_rope

    def compute_chain_pair_pe(
        self,
        residue_index,
        res_chain_index,
        AB_broadcasted_rid,
        ab_rid,
        AB_broadcasted_cid,
        ab_cid,
    ):
        """Compute chain positional encoding for a pair of residues.

        Like :meth:`compute_chain_pe`, but encodes the row and column
        displacements jointly and flattens them into one feature vector.
        """
        chain_disp_row = residue_index[AB_broadcasted_rid] - residue_index[ab_rid]
        chain_disp_col = residue_index[AB_broadcasted_cid] - residue_index[ab_cid]
        chain_disp = torch.stack([chain_disp_row, chain_disp_col], dim=-1)
        chain_rope = self.chain_pos_encoding(chain_disp.div(8).unsqueeze(-1)).div(
            chain_disp.div(8).abs().add(1).unsqueeze(-1)
        )
        # Mask cross-chain entries
        chain_mask_row = res_chain_index[AB_broadcasted_rid] == res_chain_index[ab_rid]
        chain_mask_col = res_chain_index[AB_broadcasted_cid] == res_chain_index[ab_cid]
        chain_mask = torch.stack([chain_mask_row, chain_mask_col], dim=-1)
        chain_rope = (chain_rope * chain_mask[..., None]).flatten(-2, -1)
        return chain_rope

    def eval_protein_template_encodings(self, batch, edge_idx, use_plddt=False):
        """Evaluate template encodings for protein residues.

        Returns:
            Tuple of (node features, edge features) derived from the apo
            template structure, zeroed wherever the template alignment mask is 0.
        """
        # NOTE(review): only the geometric preprocessing runs under no_grad so
        # the template encoders below stay trainable — confirm against upstream.
        with torch.no_grad():
            # assumes apo_res_atom_positions is (n_res, 37, 3) with N/CA/C as
            # the first three atoms — TODO confirm.
            template_bb_coords = batch["features"]["apo_res_atom_positions"][:, :3]
            template_bb_frames = get_frame_matrix(
                template_bb_coords[:, 0, :],
                template_bb_coords[:, 1, :],
                template_bb_coords[:, 2, :],
            )
            # Add template local representations & lddt
            template_local_coords = cartesian_to_internal(
                batch["features"]["apo_res_atom_positions"],
                template_bb_frames.unsqueeze(1),
            )
            # Zero out coordinates for atoms absent from the template.
            template_local_coords[~batch["features"]["apo_res_atom_mask"].bool()] = 0
            # if use_plddt:
            #     template_plddt_enc = F.one_hot(
            #         torch.bucketize(
            #             batch["features"]["apo_pLDDT"],
            #             torch.linspace(0, 1, 65, device=template_bb_coords.device)[:-1],
            #             right=True,
            #         )
            #         - 1,
            #         num_classes=64,
            #     )
            # else:
            #     template_plddt_enc = torch.zeros(
            #         template_local_coords.shape[0], 64, device=template_bb_coords.device
            #     )
        if self.training:
            # Randomly drop side-chain coordinates during training (augmentation).
            use_sidechain_coords = random.randint(0, 1)  # nosec
            template_local_coords = template_local_coords * use_sidechain_coords
            # use_plddt_input = random.randint(0, 1)
            # template_plddt_enc = template_plddt_enc * use_plddt_input
        if "binding_site_mask" in batch["features"].keys():
            # Externally-specified binding residue list
            binding_site_enc = self.template_binding_site_enc(
                batch["features"]["binding_site_mask"][:, None].float()
            )
        else:
            binding_site_enc = torch.zeros(
                template_local_coords.shape[0], 64, device=template_bb_coords.device
            )
        template_nfeat = self.template_nenc(
            torch.cat([template_local_coords.flatten(-2, -1), binding_site_enc], dim=-1)
        )
        template_efeat = self.template_eenc(template_bb_frames, edge_idx)
        template_alignment_mask = batch["features"]["apo_res_alignment_mask"].float()
        if self.training:
            # template_alignment_mask = template_alignment_mask * use_template
            # Randomly mask 0-10% of aligned residues as augmentation.
            nomasking_rate = random.randint(9, 10) / 10  # nosec
            template_alignment_mask = template_alignment_mask * (
                torch.rand_like(template_alignment_mask) < nomasking_rate
            )
        template_nfeat = template_nfeat * template_alignment_mask.unsqueeze(-1)
        # An edge feature survives only when both endpoints are aligned.
        template_efeat = (
            template_efeat
            * template_alignment_mask[edge_idx[0]].unsqueeze(-1)
            * template_alignment_mask[edge_idx[1]].unsqueeze(-1)
        )
        return template_nfeat, template_efeat

    def forward(self, batch, **kwargs):
        """Forward pass of the ProtFormer model."""
        return self.forward_prot_sample(batch, **kwargs)

    def forward_prot_sample(
        self,
        batch,
        embed_coords=True,
        in_attr_suffix="",
        out_attr_suffix="",
        use_template=False,
        use_plddt=False,
        **kwargs,
    ):
        """Forward pass of the ProtFormer model for a single protein sample.

        Mutates ``batch`` in place: writes the refined residue/pair features
        (suffixed by ``out_attr_suffix``) into ``batch["features"]`` and the
        AB/ab gather indices into ``batch["indexer"]``, then returns ``batch``.

        NOTE(review): ``in_attr_suffix`` is accepted but unused in this method;
        subclasses appear to consume it — confirm.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device

        time_encoding = self.time_encoding(features["timestep_encoding_prot"])
        if not embed_coords:
            # Coordinate-free mode: suppress the diffusion-timestep signal too.
            time_encoding = torch.zeros_like(time_encoding)

        # Residual update of the incoming residue embeddings with the
        # timestep-conditioned MLP.
        residue_rep = (
            self.res_in_mlp(
                torch.cat(
                    [
                        features["res_embedding_in"],
                        time_encoding,
                    ],
                    dim=1,
                )
            )
            + features["res_embedding_in"]
        )
        batch_size = metadata["num_structid"]

        # Prepare indexers
        # Use max to ensure segmentation faults are 100% invoked
        # in case there are any bad indices
        max(metadata["num_a_per_sample"])
        n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]

        indexer["gather_idx_pid_b"] = indexer["gather_idx_pid_a"]
        # Evaluate gather_idx_AB_a and gather_idx_AB_b
        # Assign a to rows and b to columns
        # Simple broadcasting for single-structure batches
        indexer["gather_idx_AB_a"] = (
            indexer["gather_idx_pid_a"]
            .view(batch_size, n_protein_patches)[:, :, None]
            .expand(-1, -1, n_protein_patches)
            .contiguous()
            .flatten()
        )
        indexer["gather_idx_AB_b"] = (
            indexer["gather_idx_pid_b"]
            .view(batch_size, n_protein_patches)[:, None, :]
            .expand(-1, n_protein_patches, -1)
            .contiguous()
            .flatten()
        )

        # Handle all batch offsets here
        graph_batcher = make_multi_relation_graph_batcher(self.graph_relations, indexer, metadata)
        merged_edge_idx = graph_batcher.collate_idx_list(indexer)

        # assumes input_protein_coords is (n_res, n_atoms, 3) with N/CA/C as
        # the first three atoms — TODO confirm against the dataloader.
        input_protein_coords_padded = features["input_protein_coords"]
        backbone_frames = get_frame_matrix(
            input_protein_coords_padded[:, 0, :],
            input_protein_coords_padded[:, 1, :],
            input_protein_coords_padded[:, 2, :],
        )
        batch["features"]["backbone_frames"] = backbone_frames
        # Adding geometrical info to pair representations

        chain_pe = self.compute_chain_pe(
            features["residue_index"],
            features["res_chain_id"],
            merged_edge_idx[0],
            merged_edge_idx[1],
        )
        geometry_pe = self.rel_geom_enc(backbone_frames, merged_edge_idx)
        if not embed_coords:
            geometry_pe = torch.zeros_like(geometry_pe)
        merged_edge_reps = self.pp_edge_embed(
            torch.cat(
                [
                    geometry_pe,
                    chain_pe,
                    residue_rep[merged_edge_idx[0]],
                    residue_rep[merged_edge_idx[1]],
                ],
                dim=-1,
            )
        )
        if use_template and "apo_res_atom_positions" in features.keys():
            (
                template_res_encodings,
                template_geom_encodings,
            ) = self.eval_protein_template_encodings(batch, merged_edge_idx, use_plddt=use_plddt)
            residue_rep = residue_rep + template_res_encodings
            merged_edge_reps = merged_edge_reps + template_geom_encodings
        edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        node_reps = {"prot_res": residue_rep}

        gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
        # Pointer: AB->AB, ab->AB
        # Maps every dense residue-pair edge (ab) to its flattened cell in the
        # (structure, row patch, column patch) grid.
        gather_idx_ab_AB = (
            indexer["gather_idx_ab_structid"] * n_protein_patches**2
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_a"]]
            * n_protein_patches
            + (gather_idx_res_protpatch % n_protein_patches)[indexer["gather_idx_ab_b"]]
        )

        # Intertwine graph iterations and triangle iterations
        for block_id in range(self.n_blocks):
            # Communicate between atomistic and patch resolutions
            # Up-sampling for interface edge embeddings
            rec_pair_rep = edge_reps["residue_to_residue"]
            AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
            # Upper-left block: intra-window visual-attention
            # Cross-attention between random and grid edges
            rec_pair_rep, AB_grid_attr_flat = self.ABab_mha(
                rec_pair_rep,
                AB_grid_attr_flat,
                (
                    torch.arange(metadata["num_ab"], device=device),
                    gather_idx_ab_AB,
                ),
            )
            AB_grid_attr = AB_grid_attr_flat.view(
                batch_size,
                n_protein_patches,
                n_protein_patches,
                self.pair_dim,
            )

            # Inter-patch triangle attentions, refining intermolecular edges
            _, AB_grid_attr = self.triangle_stacks[block_id](
                AB_grid_attr,
                AB_grid_attr,
                AB_grid_attr.unsqueeze(-4),
            )

            # Transfer grid-formatted representations to edges
            edge_reps["residue_to_residue"] = rec_pair_rep
            edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)
            merged_node_reps = graph_batcher.collate_node_attr(node_reps)
            merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)

            # Graph transformer iteration
            _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
                merged_node_reps,
                merged_node_reps,
                merged_edge_idx,
                merged_edge_reps,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        # Publish refined representations and the index maps for downstream stacks.
        batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
        batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
        batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
            "sampled_residue_to_sampled_residue"
        ]
        batch["indexer"]["gather_idx_AB_a"] = indexer["gather_idx_AB_a"]
        batch["indexer"]["gather_idx_AB_b"] = indexer["gather_idx_AB_b"]
        batch["indexer"]["gather_idx_ab_AB"] = gather_idx_ab_AB
        return batch
|
||||
|
||||
|
||||
class BindingFormer(ProtFormer):
    """Edge inference on protein-ligand graphs.

    Extends ``ProtFormer`` with ligand "triplet" (patch) nodes and
    protein-ligand edge types.  Each of the ``n_blocks`` iterations
    interleaves:

    1. a bidirectional attention exchange between patch-level (``AJ``) and
       residue-level (``aJ``) protein-ligand edge grids (``self.AaJ_mha``),
    2. triangle attention over the merged (protein-patch + ligand-patch)
       pair grid (``self.triangle_stacks``), and
    3. a heterogeneous graph-transformer update (``self.graph_stacks``).

    Index-name convention (inferred from the gather-index keys — confirm
    against the data pipeline): lower-case ``a``/``b`` index protein
    residues, upper-case ``A``/``B`` index sampled protein patches, and
    ``I``/``J`` index sampled ligand triplets/patches.
    """

    def __init__(
        self,
        dim: int,
        pair_dim: int,
        n_blocks: int = 4,
        n_heads: int = 8,
        n_ligand_patches: int = 16,
        dropout: float = 0.0,
    ):
        """Initialize the BindingFormer model.

        :param dim: Node (single-representation) feature dimension.
        :param pair_dim: Pair/edge feature dimension.
        :param n_blocks: Number of interleaved graph/triangle blocks.
        :param n_heads: Number of attention heads.
        :param n_ligand_patches: Nominal number of ligand patches.  Stored on
            the module but not read by the visible code — ``forward``
            recomputes the per-batch value from ``metadata``.
        :param dropout: Dropout probability.
        """
        super().__init__(
            dim,
            pair_dim,
            n_blocks,
            n_heads,
            dropout,
        )
        self.dim = dim
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.n_ligand_patches = n_ligand_patches
        # Embeds concatenated (residue, ligand-patch) node features into an
        # initial residue-to-ligand-patch (aJ) edge representation.
        self.pl_edge_embed = GELUMLP(dim * 2, self.pair_dim, n_hidden_feats=dim, dropout=dropout)
        # Bidirectional attention used to synchronize the patch-level (AJ) and
        # residue-level (aJ) protein-ligand edge grids.
        self.AaJ_mha = TransformerLayer(pair_dim, n_heads, bidirectional=True)

        # Heterogeneous-graph relation table.  Each entry is
        # (relation name, source gather-index key, destination gather-index
        # key, source node type, destination node type).  The first two
        # relations are protein-only; ``forward`` uses the [:2] slice when the
        # batch contains no ligand.
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_lig_triplet",
                "gather_idx_AJ_a",
                "gather_idx_AJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_sampled_residue",
                "gather_idx_AJ_J",
                "gather_idx_AJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "residue_to_sampled_lig_triplet",
                "gather_idx_aJ_a",
                "gather_idx_aJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_residue",
                "gather_idx_aJ_J",
                "gather_idx_aJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "sampled_lig_triplet_to_sampled_lig_triplet",
                "gather_idx_IJ_I",
                "gather_idx_IJ_J",
                "lig_trp",
                "lig_trp",
            ),
        ]

    def forward(
        self,
        batch,
        observed_block_contacts=None,
        in_attr_suffix: str = "",
        out_attr_suffix: str = "",
    ):
        """Run edge inference over the protein(-ligand) graph.

        :param batch: Nested batch dict with "features", "indexer",
            "metadata" and "misc" sub-dicts.  Mutated in place (indexer
            entries, output features and ligand-patch metadata are written
            back) and also returned.
        :param observed_block_contacts: Optional tensor added onto the
            patch-level protein-ligand edge grid as generative feedback from
            sampled block contacts — assumed already shaped/embedded like
            ``AJ_grid_attr``; confirm against the caller.
        :param in_attr_suffix: Suffix selecting which input feature set to read.
        :param out_attr_suffix: Suffix under which output features are stored.
        :return: The same ``batch`` dict with updated node/edge features.
        """
        features = batch["features"]
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        device = features["res_type"].device
        # Synchronize with a language model
        residue_rep = features[f"rec_res_attr{in_attr_suffix}"]
        rec_pair_rep = features[f"res_res_pair_attr{in_attr_suffix}"]
        # Inherit the last-layer pair representations from protein encoder
        AB_grid_attr_flat = features[f"res_res_grid_attr_flat{in_attr_suffix}"]

        # Prepare indexers
        batch_size = metadata["num_structid"]
        n_a_per_sample = max(metadata["num_a_per_sample"])
        n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]

        if not batch["misc"]["protein_only"]:
            n_ligand_patches = max(metadata["num_I_per_sample"])
            # NOTE(review): the result of this expression is discarded —
            # looks like a leftover statement; confirm it has no intended
            # side effect before removing.
            max(metadata["num_molid_per_sample"])
            lig_frame_rep = features[f"lig_trp_attr{in_attr_suffix}"]
            UI_grid_attr = features["lig_af_grid_attr_projected"]
            # Symmetrize the ligand patch-patch grid.
            IJ_grid_attr = (UI_grid_attr + UI_grid_attr.transpose(1, 2)) / 2

            # Residue-to-ligand-patch (aJ) edges: embed each (residue,
            # ligand-patch) feature pair via an MLP.
            aJ_grid_attr = self.pl_edge_embed(
                torch.cat(
                    [
                        residue_rep.view(batch_size, n_a_per_sample, self.dim)[:, :, None].expand(
                            -1, -1, n_ligand_patches, -1
                        ),
                        lig_frame_rep.view(batch_size, n_ligand_patches, self.dim)[
                            :, None, :
                        ].expand(-1, n_a_per_sample, -1, -1),
                    ],
                    dim=-1,
                )
            )
            AJ_grid_attr = IJ_grid_attr.new_zeros(
                batch_size, n_protein_patches, n_ligand_patches, self.pair_dim
            )
            # Identity gather indices for ligand patches and residues; the
            # broadcasted views below build dense all-pairs edge index lists.
            gather_idx_I_I = torch.arange(
                batch_size * n_ligand_patches, device=AJ_grid_attr.device
            )
            gather_idx_a_a = torch.arange(batch_size * n_a_per_sample, device=AJ_grid_attr.device)
            # Note: off-diagonal (AJ) blocks are zero-initialized in the prior stack
            indexer["gather_idx_IJ_I"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_IJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_ligand_patches, -1)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_AJ_a"] = (
                indexer["gather_idx_pid_a"]
                .view(batch_size, n_protein_patches)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_AJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_protein_patches, -1)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_aJ_a"] = (
                gather_idx_a_a.view(batch_size, n_a_per_sample)[:, :, None]
                .expand(-1, -1, n_ligand_patches)
                .contiguous()
                .flatten()
            )
            indexer["gather_idx_aJ_J"] = (
                gather_idx_I_I.view(batch_size, n_ligand_patches)[:, None, :]
                .expand(-1, n_a_per_sample, -1)
                .contiguous()
                .flatten()
            )
            batch["indexer"] = indexer

            if observed_block_contacts is not None:
                # Generative feedback from block one-hot sampling
                # AJ_grid_attr = (
                #     AJ_grid_attr
                #     + observed_block_contacts.transpose(1, 2)
                #     .contiguous()
                #     .flatten(0, 1)[indexer["gather_idx_I_molid"]]
                #     .view(batch_size, n_ligand_patches, n_protein_patches, -1)
                #     .transpose(1, 2)
                #     .contiguous()
                # )
                AJ_grid_attr = AJ_grid_attr + observed_block_contacts

            graph_batcher = make_multi_relation_graph_batcher(
                self.graph_relations, indexer, metadata
            )
            merged_edge_idx = graph_batcher.collate_idx_list(indexer)
            node_reps = {
                "prot_res": residue_rep,
                "lig_trp": lig_frame_rep,
            }
            # Symmetric relations (e.g. AJ vs. JA) share one underlying grid,
            # flattened to a per-edge attribute list.
            edge_reps = {
                "residue_to_residue": rec_pair_rep,
                "sampled_residue_to_sampled_residue": AB_grid_attr_flat,
                "sampled_lig_triplet_to_sampled_residue": AJ_grid_attr.flatten(0, 2),
                "sampled_residue_to_sampled_lig_triplet": AJ_grid_attr.flatten(0, 2),
                "sampled_lig_triplet_to_residue": aJ_grid_attr.flatten(0, 2),
                "residue_to_sampled_lig_triplet": aJ_grid_attr.flatten(0, 2),
                "sampled_lig_triplet_to_sampled_lig_triplet": IJ_grid_attr.flatten(0, 2),
            }
            # NOTE(review): padding width is ``self.dim`` although edge
            # attributes are pair-dimensional — fine if dim == pair_dim in
            # all configs; confirm.
            edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)
        else:
            # Protein-only batch: restrict to the two protein-protein
            # relations.
            graph_batcher = make_multi_relation_graph_batcher(
                self.graph_relations[:2], indexer, metadata
            )
            merged_edge_idx = graph_batcher.collate_idx_list(indexer)

            node_reps = {
                "prot_res": residue_rep,
            }
            edge_reps = {
                "residue_to_residue": rec_pair_rep,
                "sampled_residue_to_sampled_residue": AB_grid_attr_flat,
            }
            edge_reps = graph_batcher.zero_pad_edge_attr(edge_reps, self.dim, device)

        # Intertwine graph iterations and triangle iterations
        gather_idx_res_protpatch = indexer["gather_idx_a_pid"]
        for block_id in range(self.n_blocks):
            # Communicate between atomistic and patch resolutions
            # Up-sampling for interface edge embeddings
            rec_pair_rep = edge_reps["residue_to_residue"]
            AB_grid_attr_flat = edge_reps["sampled_residue_to_sampled_residue"]
            AB_grid_attr = AB_grid_attr_flat.view(
                batch_size,
                n_protein_patches,
                n_protein_patches,
                self.pair_dim,
            )

            if not batch["misc"]["protein_only"]:
                # Symmetrize off-diagonal blocks
                AJ_grid_attr_flat_ = (
                    edge_reps["sampled_residue_to_sampled_lig_triplet"]
                    + edge_reps["sampled_lig_triplet_to_sampled_residue"]
                ) / 2
                AJ_grid_attr = AJ_grid_attr_flat_.contiguous().view(
                    batch_size, n_protein_patches, n_ligand_patches, -1
                )
                aJ_grid_attr_flat_ = (
                    edge_reps["residue_to_sampled_lig_triplet"]
                    + edge_reps["sampled_lig_triplet_to_residue"]
                ) / 2
                aJ_grid_attr = aJ_grid_attr_flat_.contiguous().view(
                    batch_size, n_a_per_sample, n_ligand_patches, -1
                )
                IJ_grid_attr = (
                    edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"]
                    .contiguous()
                    .view(batch_size, n_ligand_patches, n_ligand_patches, -1)
                )
                # Exchange information between the patch-level (AJ) and
                # residue-level (aJ) protein-ligand grids; the index pair
                # maps each residue row to its protein patch.
                AJ_grid_attr_temp_, aJ_grid_attr_temp_ = self.AaJ_mha(
                    AJ_grid_attr.flatten(0, 1),
                    aJ_grid_attr.flatten(0, 1),
                    (
                        gather_idx_res_protpatch,
                        torch.arange(gather_idx_res_protpatch.shape[0], device=device),
                    ),
                )
                AJ_grid_attr = AJ_grid_attr_temp_.contiguous().view(
                    batch_size, n_protein_patches, n_ligand_patches, -1
                )
                aJ_grid_attr = aJ_grid_attr_temp_.contiguous().view(
                    batch_size, n_a_per_sample, n_ligand_patches, -1
                )
                # Assemble the joint [[AB, AJ], [JA, IJ]] block grid for
                # triangle attention.
                merged_grid_rep = torch.cat(
                    [
                        torch.cat([AB_grid_attr, AJ_grid_attr], dim=2),
                        torch.cat([AJ_grid_attr.transpose(1, 2), IJ_grid_attr], dim=2),
                    ],
                    dim=1,
                )
            else:
                merged_grid_rep = AB_grid_attr

            # Inter-patch triangle attentions
            _, merged_grid_rep = self.triangle_stacks[block_id](
                merged_grid_rep,
                merged_grid_rep,
                merged_grid_rep.unsqueeze(-4),
            )

            # Dis-assemble the grid representation
            AB_grid_attr = merged_grid_rep[:, :n_protein_patches, :n_protein_patches]
            # Transfer grid-formatted representations to edges
            # (the residue-residue pair rep passes through this stage
            # unchanged).
            edge_reps["residue_to_residue"] = rec_pair_rep
            edge_reps["sampled_residue_to_sampled_residue"] = AB_grid_attr.flatten(0, 2)

            if not batch["misc"]["protein_only"]:
                AJ_grid_attr = merged_grid_rep[
                    :, :n_protein_patches, n_protein_patches:
                ].contiguous()
                IJ_grid_attr = merged_grid_rep[
                    :, n_protein_patches:, n_protein_patches:
                ].contiguous()

                edge_reps["sampled_residue_to_sampled_lig_triplet"] = AJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_sampled_residue"] = AJ_grid_attr.flatten(0, 2)
                edge_reps["residue_to_sampled_lig_triplet"] = aJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_residue"] = aJ_grid_attr.flatten(0, 2)
                edge_reps["sampled_lig_triplet_to_sampled_lig_triplet"] = IJ_grid_attr.flatten(
                    0, 2
                )
            merged_node_reps = graph_batcher.collate_node_attr(node_reps)
            merged_edge_reps = graph_batcher.collate_edge_attr(edge_reps)

            # Graph transformer iteration
            _, merged_node_reps, merged_edge_reps = self.graph_stacks[block_id](
                merged_node_reps,
                merged_node_reps,
                merged_edge_idx,
                merged_edge_reps,
            )
            node_reps = graph_batcher.offload_node_attr(merged_node_reps)
            edge_reps = graph_batcher.offload_edge_attr(merged_edge_reps)

        # Write the updated representations back into the batch under the
        # requested output suffix.
        batch["features"][f"rec_res_attr{out_attr_suffix}"] = node_reps["prot_res"]
        batch["features"][f"res_res_pair_attr{out_attr_suffix}"] = edge_reps["residue_to_residue"]
        batch["features"][f"res_res_grid_attr_flat{out_attr_suffix}"] = edge_reps[
            "sampled_residue_to_sampled_residue"
        ]
        if not batch["misc"]["protein_only"]:
            batch["features"][f"lig_trp_attr{out_attr_suffix}"] = node_reps["lig_trp"]
            batch["features"][f"res_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
                "sampled_residue_to_sampled_lig_triplet"
            ]
            batch["features"][f"res_trp_pair_attr_flat{out_attr_suffix}"] = edge_reps[
                "residue_to_sampled_lig_triplet"
            ]
            batch["features"][f"trp_trp_grid_attr_flat{out_attr_suffix}"] = edge_reps[
                "sampled_lig_triplet_to_sampled_lig_triplet"
            ]
            batch["metadata"]["n_lig_patches_per_sample"] = n_ligand_patches
        return batch
||||
|
||||
|
||||
def resolve_protein_encoder(
    protein_model_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Instantiates a ProtFormer model for protein encoding.

    :param protein_model_cfg: Protein model configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Protein encoder model and residue input projector.
    """
    node_dim = protein_model_cfg.residue_dim
    model = ProtFormer(
        node_dim,
        protein_model_cfg.pair_dim,
        n_heads=protein_model_cfg.n_heads,
        n_blocks=protein_model_cfg.n_encoder_stacks,
        dropout=task_cfg.dropout,
    )
    # Residue inputs are either protein language-model embeddings or one-hot
    # amino-acid types; both are mapped to `node_dim` with a bias-free linear.
    in_feats = (
        protein_model_cfg.plm_embed_dim
        if protein_model_cfg.use_esm_embedding
        else protein_model_cfg.n_aa_types
    )
    res_in_projector = torch.nn.Linear(in_feats, node_dim, bias=False)

    if protein_model_cfg.from_pretrained and state_dict is not None:

        def _sub_state(prefix: str) -> STATE_DICT:
            # Keep only keys under `prefix` and drop the leading module name
            # (e.g. "protein_encoder.layer.weight" -> "layer.weight").
            return {
                ".".join(k.split(".")[1:]): v
                for k, v in state_dict.items()
                if k.startswith(prefix)
            }

        try:
            # NOTE: we must avoid enforcing strict key matching
            # due to the (new) weights `template_binding_site_enc.weight`
            model.load_state_dict(_sub_state("protein_encoder"), strict=False)
            log.info("Successfully loaded pretrained protein encoder weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained protein encoder weights due to: {e}.")

        projector_prefix = (
            "plm_adapter" if protein_model_cfg.use_esm_embedding else "res_in_projector"
        )
        try:
            res_in_projector.load_state_dict(_sub_state(projector_prefix))
            log.info("Successfully loaded pretrained protein input projector weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained protein input projector weights due to: {e}."
            )
    return model, res_in_projector
||||
|
||||
|
||||
def resolve_pl_contact_stack(
    protein_model_cfg: DictConfig,
    ligand_model_cfg: DictConfig,
    contact_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates a BindingFormer model for protein-ligand contact prediction.

    :param protein_model_cfg: Protein model configuration.
    :param ligand_model_cfg: Ligand model configuration.
    :param contact_cfg: Contact prediction configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Protein-ligand contact prediction model, contact code embedding, distance bins, and
        distogram head.
    """
    pair_dim = protein_model_cfg.pair_dim
    # NOTE(review): a contact-level `dropout` of 0.0 is falsy and therefore
    # falls back to `task_cfg.dropout` — confirm this is intended.
    dropout = contact_cfg.dropout if contact_cfg.get("dropout") else task_cfg.dropout
    pl_contact_stack = BindingFormer(
        protein_model_cfg.residue_dim,
        pair_dim,
        n_heads=protein_model_cfg.n_heads,
        n_blocks=contact_cfg.n_stacks,
        n_ligand_patches=ligand_model_cfg.n_patches,
        dropout=dropout,
    )
    # Binary contact code (contact / no-contact) embedded into pair space.
    contact_code_embed = torch.nn.Embedding(2, pair_dim)
    # Distogram heads: 32 fixed distance bins spanning 2-22 Angstrom-like units.
    dist_bins = torch.nn.Parameter(torch.linspace(2, 22, 32), requires_grad=False)
    dgram_head = GELUMLP(
        pair_dim,
        32,
        n_hidden_feats=pair_dim,
        zero_init=True,
    )
    if contact_cfg.from_pretrained and state_dict is not None:

        def _try_load(module: torch.nn.Module, prefix: str, label: str, **load_kwargs):
            # Best-effort load of the sub-state-dict under `prefix`, with the
            # leading module name stripped from every key.
            try:
                module.load_state_dict(
                    {
                        ".".join(k.split(".")[1:]): v
                        for k, v in state_dict.items()
                        if k.startswith(prefix)
                    },
                    **load_kwargs,
                )
                log.info(f"Successfully loaded pretrained {label}.")
            except Exception as e:
                log.warning(f"Skipping loading of pretrained {label} due to: {e}.")

        # NOTE: we must avoid enforcing strict key matching
        # due to the (new) weights `template_binding_site_enc.weight`
        _try_load(
            pl_contact_stack,
            "pl_contact_stack",
            "protein-ligand contact prediction weights",
            strict=False,
        )
        _try_load(contact_code_embed, "contact_code_embed", "contact code embedding weights")
        _try_load(dgram_head, "dgram_head", "distogram head weights")
    return pl_contact_stack, contact_code_embed, dist_bins, dgram_head
|
||||
105
flowdock/models/components/embedding.py
Normal file
105
flowdock/models/components/embedding.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import math
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Tuple
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils.frame_utils import RigidTransform
|
||||
|
||||
|
||||
class GaussianFourierEncoding1D(torch.nn.Module):
    """Random Fourier feature encoding for 1D (scalar-channel) inputs.

    Frequencies are drawn once at construction from N(0, pi^2) and learned as
    parameters; the forward pass returns concatenated sine and cosine
    features, doubling the basis size along the last dimension.
    """

    def __init__(
        self,
        n_basis: int,
        eps: float = 1e-2,
    ):
        """Initialize Gaussian Fourier Encoding.

        :param n_basis: Number of random frequencies (output is 2 * n_basis wide).
        :param eps: Stored for interface parity; not used by the forward pass.
        """
        super().__init__()
        self.eps = eps
        initial_freqs = torch.randn(n_basis) * math.pi
        self.fourier_freqs = torch.nn.Parameter(initial_freqs)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass of Gaussian Fourier Encoding.

        :param x: Input tensor whose trailing dimension broadcasts against the
            frequency vector (e.g. a trailing singleton dimension).
        :return: ``[sin(f * x), cos(f * x)]`` concatenated along the last axis.
        """
        phases = self.fourier_freqs.mul(x)
        return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)
|
||||
|
||||
|
||||
class GaussianRBFEncoding1D(torch.nn.Module):
    """Gaussian radial-basis-function encoding for 1D inputs.

    Centers are evenly spaced on ``[0, x_max]`` and frozen (registered as a
    non-trainable parameter); each input scalar is expanded into ``n_basis``
    Gaussian responses.
    """

    def __init__(
        self,
        n_basis: int,
        x_max: float,
        sigma: float = 1.0,
    ):
        """Initialize Gaussian RBF Encoding.

        :param n_basis: Number of RBF centers.
        :param x_max: Position of the last center (first is at 0).
        :param sigma: Width of each Gaussian basis function.
        """
        super().__init__()
        self.sigma = sigma
        centers = torch.linspace(0, x_max, n_basis)
        self.rbf_centers = torch.nn.Parameter(centers, requires_grad=False)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass of Gaussian RBF Encoding.

        :param x: Input tensor; a trailing axis of size ``n_basis`` is appended.
        :return: ``exp(-((x - c) / sigma)^2)`` for every center ``c``.
        """
        z = (x.unsqueeze(-1) - self.rbf_centers).div(self.sigma)
        return torch.exp(-z.square())
|
||||
|
||||
|
||||
class RelativeGeometryEncoding(torch.nn.Module):
    """Compute radial basis functions and inter-residue/pseudo-residue orientations."""

    def __init__(self, n_basis: int, out_dim: int, d_max: float = 20.0):
        """Initialize RelativeGeometryEncoding.

        :param n_basis: Number of RBF bins for the pairwise frame distances.
        :param out_dim: Output feature dimension of the projector.
        :param d_max: Largest RBF center (distance range covered by the bins).
        """
        super().__init__()
        self.rbf_encoding = GaussianRBFEncoding1D(n_basis, d_max)
        # Input width = n_basis distance feats + 3 + 3 local-direction feats
        # + 9 flattened relative-rotation entries = n_basis + 15.
        self.rel_geom_projector = torch.nn.Linear(n_basis + 15, out_dim, bias=False)

    def forward(self, frames: RigidTransform, merged_edge_idx: Tuple[torch.Tensor, torch.Tensor]):
        """Forward pass of RelativeGeometryEncoding.

        :param frames: Rigid frames with translations ``t`` and rotations ``R``.
        :param merged_edge_idx: (source, destination) index tensors selecting
            the frame pairs to encode.
        :return: Projected per-edge relative-geometry features.
        """
        frame_t, frame_R = frames.t, frames.R
        # Euclidean distance between the two frame origins of each edge.
        pair_dists = torch.norm(
            frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]],
            dim=-1,
        )
        # Displacement to the partner expressed in each endpoint's local frame,
        # softly normalized by sqrt(d^2 + 1) (avoids division by zero at d=0).
        pair_directions_l = torch.matmul(
            (frame_t[merged_edge_idx[1]] - frame_t[merged_edge_idx[0]]).unsqueeze(-2),
            frame_R[merged_edge_idx[0]],
        ).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1)
        pair_directions_r = torch.matmul(
            (frame_t[merged_edge_idx[0]] - frame_t[merged_edge_idx[1]]).unsqueeze(-2),
            frame_R[merged_edge_idx[1]],
        ).squeeze(-2) / pair_dists.square().add(1).sqrt().unsqueeze(-1)
        # Relative rotation R_i^T R_j between the two frames.
        pair_orientations = torch.matmul(
            frame_R.transpose(-2, -1).contiguous()[merged_edge_idx[0]],
            frame_R[merged_edge_idx[1]],
        )
        return self.rel_geom_projector(
            torch.cat(
                [
                    self.rbf_encoding(pair_dists),
                    pair_directions_l,
                    pair_directions_r,
                    pair_orientations.flatten(-2, -1),
                ],
                dim=-1,
            )
        )
||||
884
flowdock/models/components/esdm.py
Normal file
884
flowdock/models/components/esdm.py
Normal file
@@ -0,0 +1,884 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, Optional, Tuple
|
||||
from omegaconf import DictConfig
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.models.components.embedding import (
|
||||
GaussianFourierEncoding1D,
|
||||
RelativeGeometryEncoding,
|
||||
)
|
||||
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||
from flowdock.models.components.modules import PointSetAttention
|
||||
from flowdock.utils import RankedLogger
|
||||
from flowdock.utils.frame_utils import RigidTransform, get_frame_matrix
|
||||
from flowdock.utils.model_utils import GELUMLP, AveragePooling, SumPooling, segment_mean
|
||||
|
||||
# Type alias for a (possibly pretrained) model state dictionary.
STATE_DICT = Dict[str, Any]

# Module-level logger; emits only on the rank-zero process.
log = RankedLogger(__name__, rank_zero_only=True)
||||
|
||||
|
||||
class LocalUpdateUsingReferenceRotations(torch.nn.Module):
    """Update local geometric representations using reference rotations.

    Features are stored as a stacked (scalar + 3-vector) fiber: channel 0 of
    the second-to-last axis holds scalars and channels 1..3 hold a vector per
    feature.  Vector features are rotated into each node's local reference
    frame, mixed with the scalars (and vector norms) by an MLP, and the
    resulting vector part is rotated back to the global frame — keeping the
    update equivariant with respect to the reference rotations.
    """

    def __init__(
        self,
        fiber_dim: int,
        extra_feat_dim: int = 0,
        eps: float = 1e-4,
        dropout: float = 0.0,
        hidden_dim: Optional[int] = None,
        zero_init: bool = False,
    ):
        """Initialize the LocalUpdateUsingReferenceRotations module.

        :param fiber_dim: Number of feature channels per fiber.
        :param extra_feat_dim: Width of optional extra scalar features
            concatenated before the MLP.
        :param eps: Numerical floor added before the vector-norm sqrt.
        :param dropout: Dropout probability inside the MLP.
        :param hidden_dim: Hidden width of the MLP (GELUMLP default if None).
        :param zero_init: Zero-initialize the MLP output (identity-like start).
        """
        super().__init__()
        # MLP input = scalars (C) + rotated vector components (3C) + vector
        # norms (C) = 5C, plus any extra scalar features.
        self.dim = fiber_dim * 5 + extra_feat_dim
        self.fiber_dim = fiber_dim
        # Output 4C = 1 scalar channel + 3 vector components per feature.
        self.mlp = GELUMLP(
            self.dim,
            fiber_dim * 4,
            dropout=dropout,
            zero_init=zero_init,
            n_hidden_feats=hidden_dim,
        )
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        rotation_mats: torch.Tensor,
        extra_feats=None,
    ):
        """Forward pass of the LocalUpdateUsingReferenceRotations module.

        :param x: Fiber tensor; index 0 along dim 1 is the scalar channel,
            indices 1..3 the vector components (assumed layout — confirm
            against callers).
        :param rotation_mats: Per-node 3x3 reference rotations.
        :param extra_feats: Optional extra scalar features (width must equal
            ``extra_feat_dim``).
        :return: Updated fiber tensor with the same layout as ``x``.
        """
        # Vector norms are evaluated without applying rigid transform
        vecx_local = torch.matmul(
            x[:, 1:].transpose(-2, -1),
            rotation_mats,
        )
        x1_local = torch.cat(
            [
                x[:, 0],
                vecx_local.flatten(-2, -1),
                x[:, 1:].square().sum(dim=-2).add(self.eps).sqrt(),
            ],
            dim=-1,
        )
        if extra_feats is not None:
            x1_local = torch.cat([x1_local, extra_feats], dim=-1)
        x1_local = self.mlp(x1_local).view(-1, 4, self.fiber_dim)
        # Rotate the predicted local vector part back into the global frame.
        vecx1_out = torch.matmul(
            rotation_mats,
            x1_local[:, 1:],
        )
        x1_out = torch.cat([x1_local[:, :1], vecx1_out], dim=-2)
        return x1_out
||||
|
||||
|
||||
class LocalUpdateUsingChannelWiseGating(torch.nn.Module):
    """Update local geometric representations using channel-wise gating.

    Operates on stacked (scalar + 3-vector) fibers: scalars and vector norms
    (both rotation-invariant) feed an MLP whose output is split into new
    scalar features and per-channel gate logits; the vector part is linearly
    mixed and scaled by the sigmoid gates, preserving equivariance.
    """

    def __init__(
        self,
        fiber_dim: int,
        eps: float = 1e-4,
        dropout: float = 0.0,
        hidden_dim: Optional[int] = None,
        zero_init: bool = False,
    ):
        """Initialize the LocalUpdateUsingChannelWiseGating module.

        :param fiber_dim: Number of feature channels per fiber.
        :param eps: Numerical floor added before the vector-norm sqrt.
        :param dropout: Dropout probability inside the MLP.
        :param hidden_dim: Hidden width of the MLP (GELUMLP default if None).
        :param zero_init: Zero-initialize the MLP and vector mixing weights.
        """
        super().__init__()
        self.fiber_dim = fiber_dim
        # MLP input/output = scalars (C) + vector norms / gate logits (C).
        self.dim = fiber_dim * 2
        self.mlp = GELUMLP(
            self.dim,
            self.dim,
            dropout=dropout,
            n_hidden_feats=hidden_dim,
            zero_init=zero_init,
        )
        self.gate = torch.nn.Sigmoid()
        self.lin_out = torch.nn.Linear(fiber_dim, fiber_dim, bias=False)
        if zero_init:
            self.lin_out.weight.data.fill_(0.0)
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Forward pass of the LocalUpdateUsingChannelWiseGating module.

        :param x: Fiber tensor; index 0 along dim 1 holds scalars, indices
            1..3 the vector components.
        :return: Updated fiber tensor with the same layout as ``x``.
        """
        scalar_part = x[:, 0]
        vec_part = x[:, 1:]
        # Rotation-invariant summary: scalars plus per-channel vector norms.
        vec_norms = vec_part.square().sum(dim=-2).add(self.eps).sqrt()
        hidden = self.mlp(torch.cat([scalar_part, vec_norms], dim=-1))
        # Gated nonlinear operation on l=1 representations
        new_scalar, gate_logits = torch.split(hidden, self.fiber_dim, dim=-1)
        channel_gate = self.gate(gate_logits).unsqueeze(-2)
        new_vec = self.lin_out(vec_part).mul(channel_gate)
        return torch.cat([new_scalar.unsqueeze(-2), new_vec], dim=-2)
|
||||
|
||||
|
||||
class EquivariantTransformerBlock(torch.nn.Module):
    """Equivariant Transformer Block module.

    Wraps a point-set attention layer with a residual connection, followed by
    a residual local (per-node) update: rotation-conditioned when
    ``target_frames`` is set, channel-wise gated otherwise.  Optionally also
    updates edge features residually.
    """

    def __init__(
        self,
        fiber_dim: int,
        heads: int = 8,
        point_dim: int = 4,
        eps: float = 1e-4,
        edge_dim: Optional[int] = None,
        target_frames: bool = False,
        edge_update: bool = False,
        dropout: float = 0.0,
    ):
        """Initialize the EquivariantTransformerBlock module.

        :param fiber_dim: Number of feature channels per fiber.
        :param heads: Number of attention heads.
        :param point_dim: Number of attention points per head.
        :param eps: Numerical epsilon passed to the local update.
        :param edge_dim: Optional edge feature dimension for attention biasing.
        :param target_frames: Use reference-rotation local updates (requires
            ``R`` in ``forward``) instead of channel-wise gating.
        :param edge_update: Also produce and residually apply edge updates.
        :param dropout: Dropout probability in the local update MLP.
        """
        super().__init__()
        self.attn_conv = PointSetAttention(
            fiber_dim,
            heads=heads,
            point_dim=point_dim,
            edge_dim=edge_dim,
            edge_update=edge_update,
        )
        self.fiber_dim = fiber_dim
        self.target_frames = target_frames
        self.eps = eps
        self.edge_update = edge_update
        # Both local-update variants share the same constructor signature.
        update_cls = (
            LocalUpdateUsingReferenceRotations
            if target_frames
            else LocalUpdateUsingChannelWiseGating
        )
        self.local_update = update_cls(fiber_dim, eps=eps, dropout=dropout, zero_init=True)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.LongTensor,
        t: torch.Tensor,
        R: torch.Tensor = None,
        x_edge: torch.Tensor = None,
    ):
        """Forward pass of the EquivariantTransformerBlock module.

        :param x: Node fiber features.
        :param edge_index: Edge index tensor for the attention layer.
        :param t: Node point coordinates (used as both query and key points).
        :param R: Per-node rotations, required when ``target_frames`` is set.
        :param x_edge: Optional edge features.
        :return: Updated ``(x, x_edge)`` pair (``x_edge`` unchanged unless
            ``edge_update`` is enabled).
        """
        # Attention with residual connection; optionally update edges too.
        if self.edge_update:
            attn_out, edge_delta = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge)
            x_edge = x_edge + edge_delta
        else:
            attn_out = self.attn_conv(x, x, edge_index, t, t, x_edge=x_edge)
        x = x + attn_out
        # Residual local update (frame-conditioned or gated).
        if self.target_frames:
            x = self.local_update(x, R) + x
        else:
            x = self.local_update(x) + x
        return x, x_edge
|
||||
|
||||
|
||||
class EquivariantStructureDenoisingModule(torch.nn.Module):
|
||||
"""Equivariant Structure Denoising Module."""
|
||||
|
||||
    def __init__(
        self,
        fiber_dim: int,
        input_dim: int,
        input_pair_dim: int,
        hidden_dim: int = 1024,
        n_stacks: int = 4,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize the EquivariantStructureDenoisingModule module.

        :param fiber_dim: Per-channel width of the scalar+vector fibers used
            by the equivariant blocks.
        :param input_dim: Width of the incoming node (residue/atom) features.
        :param input_pair_dim: Width of the incoming pair features.
        :param hidden_dim: Hidden width of the residue adapter MLPs.
        :param n_stacks: Number of equivariant transformer stacks (also the
            number of per-stack output heads).
        :param n_heads: Number of attention heads per stack.
        :param dropout: Dropout probability.
        """
        super().__init__()
        self.input_dim = input_dim
        self.input_pair_dim = input_pair_dim
        self.fiber_dim = fiber_dim
        # 37 = size of the atom37 protein-atom vocabulary; used to one-hot
        # encode `protatm_to_atom37_index` in `forward`.
        self.protatm_padding_dim = 37
        self.n_blocks = n_stacks
        # Projections from input feature widths into the fiber width; the
        # vector projector emits 3 components per fiber channel.
        self.input_node_projector = torch.nn.Linear(input_dim, fiber_dim, bias=False)
        self.input_node_vec_projector = torch.nn.Linear(input_dim, fiber_dim * 3, bias=False)
        self.input_pair_projector = torch.nn.Linear(input_pair_dim, fiber_dim, bias=False)
        # Inherit the residue representations
        # (+32 accounts for the 16-basis sin+cos timestep encoding below).
        self.atm_embed = GELUMLP(input_dim + 32, fiber_dim)
        self.ipa_modules = torch.nn.ModuleList(
            [
                EquivariantTransformerBlock(
                    fiber_dim,
                    heads=n_heads,
                    point_dim=fiber_dim // (n_heads * 2),
                    edge_dim=fiber_dim,
                    target_frames=True,
                    edge_update=True,
                    dropout=dropout,
                )
                for _ in range(n_stacks)
            ]
        )
        # Per-stack residue adapters conditioned on the input residue features.
        self.res_adapters = torch.nn.ModuleList(
            [
                LocalUpdateUsingReferenceRotations(
                    fiber_dim,
                    extra_feat_dim=input_dim,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    zero_init=True,
                )
                for _ in range(n_stacks)
            ]
        )
        # Embeds (residue features, one-hot atom37 type) into pair space.
        self.protatm_type_encoding = GELUMLP(self.protatm_padding_dim + input_dim, input_pair_dim)
        self.time_encoding = GaussianFourierEncoding1D(16)
        self.rel_geom_enc = RelativeGeometryEncoding(15, fiber_dim)
        self.rel_geom_embed = GELUMLP(fiber_dim, fiber_dim, n_hidden_feats=fiber_dim)
        # [displacement, scale]
        # Per-stack output heads predicting drift (displacement) and scale for
        # residues and atoms respectively.
        self.out_drift_res = torch.nn.ModuleList(
            [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
        )
        # for i in range(n_stacks):
        #     self.out_drift_res[i].weight.data.fill_(0.0)
        self.out_scale_res = torch.nn.ModuleList(
            [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
        )
        self.out_drift_atm = torch.nn.ModuleList(
            [torch.nn.Linear(fiber_dim, 1, bias=False) for _ in range(n_stacks)]
        )
        # for i in range(n_stacks):
        #     self.out_drift_atm[i].weight.data.fill_(0.0)
        self.out_scale_atm = torch.nn.ModuleList(
            [GELUMLP(fiber_dim, 1, zero_init=True) for _ in range(n_stacks)]
        )

        # Pre-tabulated edges
        # Each entry: (relation name, source index key, destination index key,
        # source node type, destination node type).  The first six relations
        # involve only protein nodes and form `graph_relations_no_ligand`.
        self.graph_relations = [
            (
                "residue_to_residue",
                "gather_idx_ab_a",
                "gather_idx_ab_b",
                "prot_res",
                "prot_res",
            ),
            (
                "sampled_residue_to_sampled_residue",
                "gather_idx_AB_a",
                "gather_idx_AB_b",
                "prot_res",
                "prot_res",
            ),
            (
                "prot_atm_to_prot_atm_graph",
                "protatm_protatm_idx_src",
                "protatm_protatm_idx_dst",
                "prot_atm",
                "prot_atm",
            ),
            (
                "prot_atm_to_prot_atm_knn",
                "knn_idx_protatm_protatm_src",
                "knn_idx_protatm_protatm_dst",
                "prot_atm",
                "prot_atm",
            ),
            (
                "prot_atm_to_residue",
                "protatm_res_idx_protatm",
                "protatm_res_idx_res",
                "prot_atm",
                "prot_res",
            ),
            (
                "residue_to_prot_atm",
                "protatm_res_idx_res",
                "protatm_res_idx_protatm",
                "prot_res",
                "prot_atm",
            ),
            (
                "sampled_lig_triplet_to_lig_atm",
                "gather_idx_UI_I",
                "gather_idx_UI_u",
                "lig_trp",
                "lig_atm",
            ),
            (
                "lig_atm_to_sampled_lig_triplet",
                "gather_idx_UI_u",
                "gather_idx_UI_I",
                "lig_atm",
                "lig_trp",
            ),
            (
                "lig_atm_to_lig_atm_graph",
                "gather_idx_uv_u",
                "gather_idx_uv_v",
                "lig_atm",
                "lig_atm",
            ),
            (
                "sampled_residue_to_sampled_lig_triplet",
                "gather_idx_AJ_a",
                "gather_idx_AJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_sampled_residue",
                "gather_idx_AJ_J",
                "gather_idx_AJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "residue_to_sampled_lig_triplet",
                "gather_idx_aJ_a",
                "gather_idx_aJ_J",
                "prot_res",
                "lig_trp",
            ),
            (
                "sampled_lig_triplet_to_residue",
                "gather_idx_aJ_J",
                "gather_idx_aJ_a",
                "lig_trp",
                "prot_res",
            ),
            (
                "sampled_lig_triplet_to_sampled_lig_triplet",
                "gather_idx_IJ_I",
                "gather_idx_IJ_J",
                "lig_trp",
                "lig_trp",
            ),
            (
                "prot_atm_to_lig_atm_knn",
                "knn_idx_protatm_ligatm_src",
                "knn_idx_protatm_ligatm_dst",
                "prot_atm",
                "lig_atm",
            ),
            (
                "lig_atm_to_prot_atm_knn",
                "knn_idx_ligatm_protatm_src",
                "knn_idx_ligatm_protatm_dst",
                "lig_atm",
                "prot_atm",
            ),
            (
                "lig_atm_to_lig_atm_knn",
                "knn_idx_ligatm_ligatm_src",
                "knn_idx_ligatm_ligatm_dst",
                "lig_atm",
                "lig_atm",
            ),
        ]
        self.graph_relations_no_ligand = self.graph_relations[:6]
|
||||
|
||||
def init_scalar_vec_rep(self, x, x_v=None, frame=None):
|
||||
"""Initialize scalar and vector representations."""
|
||||
if frame is None:
|
||||
# Zero-initialize the vector channels
|
||||
vec_shape = (*x.shape[:-1], 3, x.shape[-1])
|
||||
res = torch.cat([x.unsqueeze(-2), torch.zeros(vec_shape, device=x.device)], dim=-2)
|
||||
else:
|
||||
x_v = x_v.view(*x.shape[:-1], 3, x.shape[-1])
|
||||
x_v_glob = torch.matmul(frame.R, x_v)
|
||||
res = torch.cat([x.unsqueeze(-2), x_v_glob], dim=-2)
|
||||
return res
|
||||
|
||||
def forward(
    self,
    batch,
    frozen_lig=False,
    frozen_prot=False,
    **kwargs,
):
    """Forward pass of the EquivariantStructureDenoisingModule module.

    Runs ``self.n_blocks`` rounds of equivariant message passing over a
    heterogeneous graph of protein residues, protein atoms and (optionally)
    ligand atoms / ligand triplet frames, iteratively predicting coordinate
    drifts and updating atom positions.

    :param batch: Model batch with "features", "indexer", "metadata" and
        "misc" entries (see callers for the expected keys).
    :param frozen_lig: If True, ligand coordinates are left unchanged.
    :param frozen_prot: If True, protein coordinates are left unchanged.
    :param kwargs: Unused; accepted for interface compatibility.
    :return: Dict with final node embeddings and denoised coordinates for the
        protein (and for the ligand, or ``None`` ligand entries when
        ``batch["misc"]["protein_only"]`` is set).
    """
    features = batch["features"]
    indexer = batch["indexer"]
    metadata = batch["metadata"]
    # NOTE: removed leftover no-op statements (`metadata["num_structid"]`,
    # `max(metadata["num_a_per_sample"])`) whose results were never used.

    prot_res_rep_in = features["rec_res_attr_decin"]
    timestep_prot = features["timestep_encoding_prot"]
    device = features["res_type"].device

    # Protein all-atom representation initialization
    protatm_padding_mask = features["res_atom_mask"]
    protatm_atom37_onehot = torch.nn.functional.one_hot(
        features["protatm_to_atom37_index"], num_classes=self.protatm_padding_dim
    )
    # Pair encoding between each protein atom and its parent residue.
    protatm_res_pair_encoding = self.protatm_type_encoding(
        torch.cat(
            [
                prot_res_rep_in[indexer["protatm_res_idx_res"]],
                protatm_atom37_onehot,
            ],
            dim=-1,
        )
    )
    # Gathered AA features from individual graphs
    prot_atm_rep_in = features["prot_atom_attr_projected"]
    prot_atm_rep_int = self.atm_embed(
        torch.cat(
            [
                prot_atm_rep_in,
                self.time_encoding(timestep_prot)[indexer["protatm_res_idx_res"]],
            ],
            dim=-1,
        )
    )

    prot_atm_coords_padded = features["input_protein_coords"]
    prot_atm_coords_flat = prot_atm_coords_padded[protatm_padding_mask]

    # Embed the rigid body node representations; the first three padded atom
    # slots define the residue backbone frame.
    backbone_frames = get_frame_matrix(
        prot_atm_coords_padded[:, 0],
        prot_atm_coords_padded[:, 1],
        prot_atm_coords_padded[:, 2],
    )
    prot_res_rep = self.init_scalar_vec_rep(
        self.input_node_projector(prot_res_rep_in),
        x_v=self.input_node_vec_projector(prot_res_rep_in),
        frame=backbone_frames,
    )
    prot_atm_rep = self.init_scalar_vec_rep(prot_atm_rep_int)
    # Node representations keyed by node type.
    node_reps = {
        "prot_res": prot_res_rep,
        "prot_atm": prot_atm_rep,
    }
    # Edge (pair) representations keyed by relation name.
    edge_reps = {
        "residue_to_residue": features["res_res_pair_attr_decin"],
        "prot_atm_to_prot_atm_graph": features["prot_atom_pair_attr_projected"],
        "prot_atm_to_prot_atm_knn": features["knn_feat_protatm_protatm"],
        "prot_atm_to_residue": protatm_res_pair_encoding,
        "residue_to_prot_atm": protatm_res_pair_encoding,
        "sampled_residue_to_sampled_residue": features["res_res_grid_attr_flat_decin"],
    }

    if not batch["misc"]["protein_only"]:
        # NOTE: removed leftover no-op `max(metadata["num_i_per_sample"])`.
        timestep_lig = features["timestep_encoding_lig"]
        lig_atm_rep_in = features["lig_atom_attr_projected"]
        lig_frame_rep_in = features["lig_trp_attr_decin"]
        # Ligand atom embedding. Two timescales
        lig_atm_rep_int = self.atm_embed(
            torch.cat(
                [lig_atm_rep_in, self.time_encoding(timestep_lig)],
                dim=-1,
            )
        )
        lig_atm_rep = self.init_scalar_vec_rep(lig_atm_rep_int)

        # Prepare ligand atom - sidechain atom indexers
        # Initialize coordinate features (cloned so in-place drift updates do
        # not mutate the batch input).
        lig_atm_coords = features["input_ligand_coords"].clone()

        # Atom-index triplets defining each sampled ligand (triplet) frame.
        lig_frame_atm_idx = (
            indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]],
            indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]],
            indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]],
        )
        ligand_trp_frames = get_frame_matrix(
            lig_atm_coords[lig_frame_atm_idx[0]],
            lig_atm_coords[lig_frame_atm_idx[1]],
            lig_atm_coords[lig_frame_atm_idx[2]],
        )
        lig_frame_rep = self.init_scalar_vec_rep(
            self.input_node_projector(lig_frame_rep_in),
            x_v=self.input_node_vec_projector(lig_frame_rep_in),
            frame=ligand_trp_frames,
        )
        node_reps.update(
            {
                "lig_atm": lig_atm_rep,
                "lig_trp": lig_frame_rep,
            }
        )
        edge_reps.update(
            {
                "lig_atm_to_lig_atm_graph": features["lig_atom_pair_attr_projected"],
                "sampled_lig_triplet_to_lig_atm": features["lig_af_pair_attr_projected"],
                "lig_atm_to_sampled_lig_triplet": features["lig_af_pair_attr_projected"],
                "sampled_residue_to_sampled_lig_triplet": features[
                    "res_trp_grid_attr_flat_decin"
                ],
                "sampled_lig_triplet_to_sampled_residue": features[
                    "res_trp_grid_attr_flat_decin"
                ],
                "residue_to_sampled_lig_triplet": features["res_trp_pair_attr_flat_decin"],
                "sampled_lig_triplet_to_residue": features["res_trp_pair_attr_flat_decin"],
                "sampled_lig_triplet_to_sampled_lig_triplet": features[
                    "trp_trp_grid_attr_flat_decin"
                ],
                "prot_atm_to_lig_atm_knn": features["knn_feat_protatm_ligatm"],
                "lig_atm_to_prot_atm_knn": features["knn_feat_ligatm_protatm"],
                "lig_atm_to_lig_atm_knn": features["knn_feat_ligatm_ligatm"],
            }
        )

    # Message passing
    protatm_res_idx_res = indexer["protatm_res_idx_res"]
    if batch["misc"]["protein_only"]:
        graph_relations = self.graph_relations_no_ligand
    else:
        graph_relations = self.graph_relations
    graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
    merged_edge_idx = graph_batcher.collate_idx_list(indexer)
    merged_node_reps = graph_batcher.collate_node_attr(node_reps)
    merged_edge_reps = graph_batcher.collate_edge_attr(
        graph_batcher.zero_pad_edge_attr(edge_reps, self.input_pair_dim, device)
    )
    merged_edge_reps = self.input_pair_projector(merged_edge_reps)
    assert merged_edge_idx[0].shape[0] == merged_edge_reps.shape[0]
    assert merged_edge_idx[1].shape[0] == merged_edge_reps.shape[0]

    # Atom nodes carry translation-only frames (R=None -> no rotation).
    dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None)
    if not batch["misc"]["protein_only"]:
        dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None)
        merged_node_t = graph_batcher.collate_node_attr(
            {
                "prot_res": backbone_frames.t,
                "prot_atm": dummy_prot_atm_frames.t,
                "lig_atm": dummy_lig_atm_frames.t,
                "lig_trp": ligand_trp_frames.t,
            }
        )
        merged_node_R = graph_batcher.collate_node_attr(
            {
                "prot_res": backbone_frames.R,
                "prot_atm": dummy_prot_atm_frames.R,
                "lig_atm": dummy_lig_atm_frames.R,
                "lig_trp": ligand_trp_frames.R,
            }
        )
    else:
        merged_node_t = graph_batcher.collate_node_attr(
            {
                "prot_res": backbone_frames.t,
                "prot_atm": dummy_prot_atm_frames.t,
            }
        )
        merged_node_R = graph_batcher.collate_node_attr(
            {
                "prot_res": backbone_frames.R,
                "prot_atm": dummy_prot_atm_frames.R,
            }
        )
    merged_node_frames = RigidTransform(merged_node_t, merged_node_R)
    # Inject relative-geometry features into the merged edge representations.
    merged_edge_reps = merged_edge_reps + (
        self.rel_geom_embed(
            self.rel_geom_enc(merged_node_frames, merged_edge_idx) + merged_edge_reps
        )
    )

    # No need to reassign embeddings but need to update point coordinates & frames
    for block_id in range(self.n_blocks):
        # Recompute per-iteration frames from the (possibly updated) coordinates.
        dummy_prot_atm_frames = RigidTransform(prot_atm_coords_flat, R=None)
        if not batch["misc"]["protein_only"]:
            dummy_lig_atm_frames = RigidTransform(lig_atm_coords, R=None)
            merged_node_t = graph_batcher.collate_node_attr(
                {
                    "prot_res": backbone_frames.t,
                    "prot_atm": dummy_prot_atm_frames.t,
                    "lig_atm": dummy_lig_atm_frames.t,
                    "lig_trp": ligand_trp_frames.t,
                }
            )
            merged_node_R = graph_batcher.collate_node_attr(
                {
                    "prot_res": backbone_frames.R,
                    "prot_atm": dummy_prot_atm_frames.R,
                    "lig_atm": dummy_lig_atm_frames.R,
                    "lig_trp": ligand_trp_frames.R,
                }
            )
        else:
            merged_node_t = graph_batcher.collate_node_attr(
                {
                    "prot_res": backbone_frames.t,
                    "prot_atm": dummy_prot_atm_frames.t,
                }
            )
            merged_node_R = graph_batcher.collate_node_attr(
                {
                    "prot_res": backbone_frames.R,
                    "prot_atm": dummy_prot_atm_frames.R,
                }
            )
        # PredictDrift iteration
        merged_node_reps, merged_edge_reps = self.ipa_modules[block_id](
            merged_node_reps,
            merged_edge_idx,
            t=merged_node_t,
            R=merged_node_R,
            x_edge=merged_edge_reps,
        )
        offloaded_node_reps = graph_batcher.offload_node_attr(merged_node_reps)
        # Residual adapters conditioned on the current frames and input features.
        if "lig_trp" in offloaded_node_reps:
            lig_frame_rep = offloaded_node_reps["lig_trp"]
            offloaded_node_reps["lig_trp"] = lig_frame_rep + self.res_adapters[block_id](
                lig_frame_rep, ligand_trp_frames.R, extra_feats=lig_frame_rep_in
            )
        prot_res_rep = offloaded_node_reps["prot_res"]
        offloaded_node_reps["prot_res"] = prot_res_rep + self.res_adapters[block_id](
            prot_res_rep, backbone_frames.R, extra_feats=prot_res_rep_in
        )
        merged_node_reps = graph_batcher.collate_node_attr(offloaded_node_reps)
        # Displacement vectors in the global coordinate system
        if not batch["misc"]["protein_only"]:
            # Rigid (per-molecule averaged) drift from the ligand triplet frames,
            # gated by a sigmoid scale and amplified by a factor of 10.
            drift_trp = (
                self.out_drift_res[block_id](offloaded_node_reps["lig_trp"][:, 1:]).squeeze(-1)
                * torch.sigmoid(
                    self.out_scale_res[block_id](offloaded_node_reps["lig_trp"][:, 0])
                )
                * 10
            )
            drift_trp_gathered = segment_mean(
                drift_trp,
                indexer["gather_idx_I_molid"],
                metadata["num_molid"],
            )[indexer["gather_idx_i_molid"]]
            drift_atm = self.out_drift_atm[block_id](
                offloaded_node_reps["lig_atm"][:, 1:]
            ).squeeze(-1) * torch.sigmoid(
                self.out_scale_atm[block_id](offloaded_node_reps["lig_atm"][:, 0])
            )
            if not frozen_lig:
                lig_atm_coords = lig_atm_coords + drift_atm + drift_trp_gathered
            # Frames follow the updated ligand coordinates.
            ligand_trp_frames = get_frame_matrix(
                lig_atm_coords[lig_frame_atm_idx[0]],
                lig_atm_coords[lig_frame_atm_idx[1]],
                lig_atm_coords[lig_frame_atm_idx[2]],
            )

        # Backbone (per-residue) drift broadcast to the residue's atoms, plus a
        # per-atom refinement drift.
        drift_bb = (
            self.out_drift_res[block_id](offloaded_node_reps["prot_res"][:, 1:]).squeeze(-1)
            * torch.sigmoid(
                self.out_scale_res[block_id](offloaded_node_reps["prot_res"][:, 0])
            )
            * 10
        )
        drift_bb_gathered = drift_bb[protatm_res_idx_res]
        drift_prot_atm_int = self.out_drift_atm[block_id](
            offloaded_node_reps["prot_atm"][:, 1:]
        ).squeeze(-1) * torch.sigmoid(
            self.out_scale_atm[block_id](offloaded_node_reps["prot_atm"][:, 0])
        )
        if not frozen_prot:
            prot_atm_coords_flat = (
                prot_atm_coords_flat + drift_prot_atm_int + drift_bb_gathered
            )

        # Re-pad the flat coordinates and refresh the backbone frames.
        prot_atm_coords_padded = torch.zeros_like(features["input_protein_coords"])
        prot_atm_coords_padded[protatm_padding_mask] = prot_atm_coords_flat
        backbone_frames = get_frame_matrix(
            prot_atm_coords_padded[:, 0],
            prot_atm_coords_padded[:, 1],
            prot_atm_coords_padded[:, 2],
        )

    ret = {
        "final_embedding_prot_atom": offloaded_node_reps["prot_atm"],
        "final_embedding_prot_res": offloaded_node_reps["prot_res"],
        "final_coords_prot_atom": prot_atm_coords_flat,
        "final_coords_prot_atom_padded": prot_atm_coords_padded,
    }
    if not batch["misc"]["protein_only"]:
        ret["final_embedding_lig_atom"] = offloaded_node_reps["lig_atm"]
        ret["final_coords_lig_atom"] = lig_atm_coords
    else:
        ret["final_embedding_lig_atom"] = None
        ret["final_coords_lig_atom"] = None
    return ret
def resolve_score_head(
    protein_model_cfg: DictConfig,
    score_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
    """Instantiates an EquivariantStructureDenoisingModule model for protein-ligand complex
    structure denoising.

    :param protein_model_cfg: Protein model configuration.
    :param score_cfg: Score configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model.
    """
    score_head = EquivariantStructureDenoisingModule(
        score_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=score_cfg.hidden_dim,
        n_stacks=score_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    if score_cfg.from_pretrained and state_dict is not None:
        try:
            # Checkpoint keys are namespaced as "score_head.<param>"; strip the
            # leading component before loading.
            score_head.load_state_dict(
                {
                    key.partition(".")[2]: tensor
                    for key, tensor in state_dict.items()
                    if key.startswith("score_head")
                }
            )
            log.info("Successfully loaded pretrained score weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained score weights due to: {e}.")
    return score_head
def resolve_confidence_head(
    protein_model_cfg: DictConfig,
    confidence_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for confidence prediction.

    :param protein_model_cfg: Protein model configuration.
    :param confidence_cfg: Confidence configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model and plDDT gram head weights.
    """
    confidence_head = EquivariantStructureDenoisingModule(
        confidence_cfg.fiber_dim,
        input_dim=protein_model_cfg.residue_dim,
        input_pair_dim=protein_model_cfg.pair_dim,
        hidden_dim=confidence_cfg.hidden_dim,
        n_stacks=confidence_cfg.n_stacks,
        n_heads=protein_model_cfg.n_heads,
        dropout=task_cfg.dropout,
    )
    # Zero-initialized MLP mapping pair features to 8 plDDT bins.
    plddt_gram_head = GELUMLP(
        protein_model_cfg.pair_dim,
        8,
        n_hidden_feats=protein_model_cfg.pair_dim,
        zero_init=True,
    )
    if confidence_cfg.from_pretrained and state_dict is not None:

        def _sub_state_dict(prefix):
            # Checkpoint keys look like "<prefix>.<param>"; drop the prefix.
            return {
                key.partition(".")[2]: tensor
                for key, tensor in state_dict.items()
                if key.startswith(prefix)
            }

        try:
            confidence_head.load_state_dict(_sub_state_dict("confidence_head"))
            log.info("Successfully loaded pretrained confidence weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained confidence weights due to: {e}.")

        try:
            plddt_gram_head.load_state_dict(_sub_state_dict("plddt_gram_head"))
            log.info("Successfully loaded pretrained pLDDT gram head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained pLDDT gram head weights due to: {e}.")
    return confidence_head, plddt_gram_head
def resolve_affinity_head(
    ligand_model_cfg: DictConfig,
    affinity_cfg: DictConfig,
    task_cfg: DictConfig,
    learnable_pooling: bool = True,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates an EquivariantStructureDenoisingModule model for affinity prediction.

    :param ligand_model_cfg: Ligand model configuration.
    :param affinity_cfg: Affinity configuration.
    :param task_cfg: Task configuration.
    :param learnable_pooling: Whether to use learnable ligand pooling modules.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: EquivariantStructureDenoisingModule model as well as a ligand pooling module and
        projection head.
    """
    # Prefer an affinity-specific dropout rate when one is configured.
    dropout = affinity_cfg.dropout if affinity_cfg.get("dropout") else task_cfg.dropout
    affinity_head = EquivariantStructureDenoisingModule(
        affinity_cfg.fiber_dim,
        input_dim=ligand_model_cfg.node_channels,
        input_pair_dim=ligand_model_cfg.pair_channels,
        hidden_dim=affinity_cfg.hidden_dim,
        n_stacks=affinity_cfg.n_stacks,
        n_heads=ligand_model_cfg.n_heads,
        dropout=dropout,
    )
    # Select the operator that pools per-atom ligand features into one vector.
    if affinity_cfg.ligand_pooling in ("sum", "add", "summation", "addition"):
        pooling_cls = SumPooling
    elif affinity_cfg.ligand_pooling in ("mean", "avg", "average"):
        pooling_cls = AveragePooling
    else:
        raise NotImplementedError(
            f"Unsupported ligand pooling method: {affinity_cfg.ligand_pooling}"
        )
    ligand_pooling = pooling_cls(learnable=learnable_pooling, hidden_dim=affinity_cfg.fiber_dim)
    # Zero-initialized scalar projection producing the final affinity value.
    affinity_proj_head = GELUMLP(
        affinity_cfg.fiber_dim,
        1,
        n_hidden_feats=affinity_cfg.fiber_dim,
        zero_init=True,
    )
    if affinity_cfg.from_pretrained and state_dict is not None:

        def _sub_state_dict(prefix):
            # Checkpoint keys look like "<prefix>.<param>"; drop the prefix.
            return {
                key.partition(".")[2]: tensor
                for key, tensor in state_dict.items()
                if key.startswith(prefix)
            }

        try:
            affinity_head.load_state_dict(_sub_state_dict("affinity_head"))
            log.info("Successfully loaded pretrained affinity head weights.")
        except Exception as e:
            log.warning(f"Skipping loading of pretrained affinity head weights due to: {e}.")

        if learnable_pooling:
            try:
                ligand_pooling.load_state_dict(_sub_state_dict("ligand_pooling"))
                log.info("Successfully loaded pretrained ligand pooling weights.")
            except Exception as e:
                log.warning(
                    f"Skipping loading of pretrained ligand pooling weights due to: {e}."
                )

        try:
            affinity_proj_head.load_state_dict(_sub_state_dict("affinity_proj_head"))
            log.info("Successfully loaded pretrained affinity projection head weights.")
        except Exception as e:
            log.warning(
                f"Skipping loading of pretrained affinity projection head weights due to: {e}."
            )
    return affinity_head, ligand_pooling, affinity_proj_head
2029
flowdock/models/components/flowdock.py
Normal file
2029
flowdock/models/components/flowdock.py
Normal file
File diff suppressed because it is too large
Load Diff
166
flowdock/models/components/hetero_graph.py
Normal file
166
flowdock/models/components/hetero_graph.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from beartype.typing import Dict, List, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Relation:
    """Immutable description of one edge type in a multi-relation graph."""

    # Name of this edge type; used as the key into edge-attribute dictionaries.
    edge_type: str
    # Indexer key holding the source-node index of each edge.
    edge_rev_name: str
    # Indexer key holding the destination-node index of each edge.
    edge_frd_name: str
    # Node type of the edge source (matched against "num_<type>" metadata).
    src_node_type: str
    # Node type of the edge destination (matched against "num_<type>" metadata).
    dst_node_type: str
    # Number of edges of this type in the batch.
    num_edges: int
class MultiRelationGraphBatcher:
    """Collate sub-graphs of different node/edge types into a single instance.

    Returned multi-relation edge indices are stored in LongTensor of shape [2, N_edges].
    Each node type is assigned a contiguous index range in the merged node
    tensor; edge indices are shifted by the corresponding range offsets.
    """

    def __init__(
        self,
        relation_forms: List[Relation],
        graph_metadata: Dict[str, int],
    ):
        """Initialize the batcher.

        :param relation_forms: One ``Relation`` describing each edge type to collate.
        :param graph_metadata: Mapping that must contain a ``num_<node_type>``
            entry for every node type referenced by ``relation_forms``.
        """
        self._relation_forms = relation_forms
        self._make_offset_dict(graph_metadata)

    def _make_offset_dict(self, graph_metadata):
        """Create offset dictionaries for node and edge types.

        Assigns each node type a contiguous [lower, upper) index range in the
        merged node tensor, in a fixed (list-stabilized) node-type order.
        """
        self._node_chunk_sizes = {}
        # NOTE(review): never populated anywhere in this class; retained for
        # backward compatibility with external accessors, if any.
        self._edge_chunk_sizes = {}
        self._offsets_lower = {}
        self._offsets_upper = {}
        all_node_types = set()
        for relation in self._relation_forms:
            assert (
                f"num_{relation.src_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.src_node_type}"
            # Fixed: this error message previously reported the *source* node
            # type even though the check is on the destination node type.
            assert (
                f"num_{relation.dst_node_type}" in graph_metadata.keys()
            ), f"Missing metadata: num_{relation.dst_node_type}"
            all_node_types.add(relation.src_node_type)
            all_node_types.add(relation.dst_node_type)
        offset = 0
        # Fix node type ordering
        self.all_node_types = list(all_node_types)
        for node_type in self.all_node_types:
            self._offsets_lower[node_type] = offset
            self._node_chunk_sizes[node_type] = graph_metadata[f"num_{node_type}"]
            new_offset = offset + self._node_chunk_sizes[node_type]
            self._offsets_upper[node_type] = new_offset
            offset = new_offset

    def collate_single_relation_graphs(self, indexer, node_attr_dict, edge_attr_dict):
        """Collate sub-graphs of different node/edge types into a single instance.

        :return: Dict with merged "node_attr", "edge_attr" and "edge_index" tensors.
        """
        return {
            "node_attr": self.collate_node_attr(node_attr_dict),
            "edge_attr": self.collate_edge_attr(edge_attr_dict),
            "edge_index": self.collate_idx_list(indexer),
        }

    def collate_idx_list(
        self,
        indexer: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Collate edge indices for all relations.

        :param indexer: Mapping from relation index names to 1D node-index tensors.
        :return: LongTensor of shape [2, N_edges] with globally-offset indices.
        """
        ret_eidxs_rev, ret_eidxs_frd = [], []
        for relation in self._relation_forms:
            assert relation.edge_rev_name in indexer.keys()
            assert relation.edge_frd_name in indexer.keys()
            assert indexer[relation.edge_rev_name].dim() == 1
            assert indexer[relation.edge_frd_name].dim() == 1
            # Local (per-type) indices must stay within their node-type chunk.
            assert torch.all(
                indexer[relation.edge_rev_name] < self._node_chunk_sizes[relation.src_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            assert torch.all(
                indexer[relation.edge_frd_name] < self._node_chunk_sizes[relation.dst_node_type]
            ), f"Node index on edge exceeding boundary: {relation.edge_type}, {self._node_chunk_sizes[relation.src_node_type]}, {self._node_chunk_sizes[relation.dst_node_type]}, {max(indexer[relation.edge_rev_name])}, {max(indexer[relation.edge_frd_name])}"
            # Shift local indices into the global merged-node index space.
            ret_eidxs_rev.append(
                indexer[relation.edge_rev_name] + self._offsets_lower[relation.src_node_type]
            )
            ret_eidxs_frd.append(
                indexer[relation.edge_frd_name] + self._offsets_lower[relation.dst_node_type]
            )
        ret_eidxs_rev = torch.cat(ret_eidxs_rev, dim=0)
        ret_eidxs_frd = torch.cat(ret_eidxs_frd, dim=0)
        return torch.stack([ret_eidxs_rev, ret_eidxs_frd], dim=0)

    def collate_node_attr(self, node_attr_dict: Dict[str, torch.Tensor]):
        """Collate node attributes for all node types.

        Concatenation order follows ``self.all_node_types`` so that offsets
        computed in ``_make_offset_dict`` remain valid.
        """
        for node_type in self.all_node_types:
            assert (
                node_attr_dict[node_type].shape[0] == self._node_chunk_sizes[node_type]
            ), f"Node count mismatch: {node_type}, {node_attr_dict[node_type].shape[0]}, {self._node_chunk_sizes[node_type]}"
        return torch.cat([node_attr_dict[node_type] for node_type in self.all_node_types], dim=0)

    def collate_edge_attr(self, edge_attr_dict: Dict[str, torch.Tensor]):
        """Collate edge attributes for all relations, in relation-form order."""
        return torch.cat(
            [edge_attr_dict[relation.edge_type] for relation in self._relation_forms],
            dim=0,
        )

    def zero_pad_edge_attr(
        self,
        edge_attr_dict: Dict[str, torch.Tensor],
        embedding_dim: int,
        device: torch.device,
    ):
        """Zero pad edge attributes for all relations.

        Replaces ``None`` entries with zero tensors of shape
        ``(num_edges, embedding_dim)``; mutates and returns ``edge_attr_dict``.
        """
        for relation in self._relation_forms:
            if edge_attr_dict[relation.edge_type] is None:
                edge_attr_dict[relation.edge_type] = torch.zeros(
                    (relation.num_edges, embedding_dim),
                    device=device,
                )
        return edge_attr_dict

    def offload_node_attr(self, cat_node_attr: torch.Tensor):
        """Split merged node attributes back into a per-node-type dictionary."""
        node_chunk_sizes = [self._node_chunk_sizes[node_type] for node_type in self.all_node_types]
        node_attr_split = torch.split(cat_node_attr, node_chunk_sizes)
        return {
            self.all_node_types[i]: node_attr_split[i] for i in range(len(self.all_node_types))
        }

    def offload_edge_attr(self, cat_edge_attr: torch.Tensor):
        """Split merged edge attributes back into a per-relation dictionary."""
        edge_chunk_sizes = [relation.num_edges for relation in self._relation_forms]
        edge_attr_split = torch.split(cat_edge_attr, edge_chunk_sizes)
        return {
            self._relation_forms[i].edge_type: edge_attr_split[i]
            for i in range(len(self._relation_forms))
        }
def make_multi_relation_graph_batcher(
    list_of_relations: List[Tuple[str, str, str, str, str]],
    indexer,
    metadata,
):
    """Make a multi-relation graph batcher.

    :param list_of_relations: Tuples of (edge_type, edge_rev_name,
        edge_frd_name, src_node_type, dst_node_type).
    :param indexer: Index dictionary; used once here to derive per-relation
        edge counts from the source-index tensors.
    :param metadata: Graph metadata with ``num_<node_type>`` entries.
    :return: A configured ``MultiRelationGraphBatcher``.
    """
    relation_forms = [
        Relation(
            edge_type=edge_type,
            edge_rev_name=rev_name,
            edge_frd_name=frd_name,
            src_node_type=src_type,
            dst_node_type=dst_type,
            num_edges=indexer[rev_name].shape[0],
        )
        for edge_type, rev_name, frd_name, src_type, dst_type in list_of_relations
    ]
    return MultiRelationGraphBatcher(relation_forms, metadata)
1439
flowdock/models/components/losses.py
Normal file
1439
flowdock/models/components/losses.py
Normal file
File diff suppressed because it is too large
Load Diff
364
flowdock/models/components/mht_encoder.py
Normal file
364
flowdock/models/components/mht_encoder.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, Optional, Tuple
|
||||
from omegaconf import DictConfig
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.models.components.hetero_graph import make_multi_relation_graph_batcher
|
||||
from flowdock.models.components.modules import TransformerLayer
|
||||
from flowdock.utils import RankedLogger
|
||||
from flowdock.utils.model_utils import GELUMLP, segment_softmax, segment_sum
|
||||
|
||||
MODEL_BATCH = Dict[str, Any]
|
||||
STATE_DICT = Dict[str, Any]
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class PathConvStack(torch.nn.Module):
    """Path integral convolution stack for ligand encoding."""

    def __init__(
        self,
        pair_channels: int,
        n_heads: int = 8,
        max_pi_length: int = 8,
        dropout: float = 0.0,
    ):
        """Initialize PathConvStack model.

        :param pair_channels: Width of the pair representations.
        :param n_heads: Number of propagation heads.
        :param max_pi_length: Number of propagation hops to accumulate.
        :param dropout: Dropout rate of the update MLP.
        """
        super().__init__()
        self.pair_channels = pair_channels
        self.max_pi_length = max_pi_length
        self.n_heads = n_heads

        # Per-head value projection and per-triangle kernel projection.
        self.prop_value_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
        self.triangle_pair_kernel_layer = torch.nn.Linear(pair_channels, n_heads, bias=False)
        # Fuses all accumulated hops back to the pair-channel width.
        self.prop_update_mlp = GELUMLP(
            n_heads * (max_pi_length + 1), pair_channels, dropout=dropout
        )

    def forward(
        self,
        prop_attr: torch.Tensor,
        stereo_attr: torch.Tensor,
        indexer: Dict[str, torch.LongTensor],
        metadata: Dict[str, Any],
    ) -> torch.Tensor:
        """Forward pass for PathConvStack model.

        :param prop_attr: Atom-frame pair attributes.
        :param stereo_attr: Stereochemistry attributes.
        :param indexer: A dictionary of indices.
        :param metadata: A dictionary of metadata.
        :return: Updated atom-frame pair attributes (residual update).
        """
        # Per-triangle attention logits computed from stereochemistry features.
        kernel_logits = self.triangle_pair_kernel_layer(stereo_attr)
        # Segment-wise softmax, normalized over outgoing triangles.
        alpha = segment_softmax(
            kernel_logits, indexer["gather_idx_ijkl_jkl"], metadata["num_ijk"]
        )
        # Broadcast the normalized kernel onto the Uijkl propagation index
        # (Uijk,ijkl->ujkl pair representation update).
        prop_kernel = alpha[indexer["gather_idx_Uijkl_ijkl"]]
        # Accumulate hop 0..max_pi_length of the propagation recursion.
        hop_features = [self.prop_value_layer(prop_attr)]
        for _ in range(self.max_pi_length):
            previous_hop = hop_features[-1][indexer["gather_idx_Uijkl_Uijk"]]
            hop_features.append(
                segment_sum(
                    prop_kernel * previous_hop,
                    indexer["gather_idx_Uijkl_ujkl"],
                    metadata["num_Uijk"],
                )
            )
        stacked_hops = torch.cat(hop_features, dim=-1)
        # Residual connection back onto the input pair attributes.
        return self.prop_update_mlp(stacked_hops) + prop_attr
||||
class PIFormer(torch.nn.Module):
|
||||
"""PIFormer model for ligand encoding."""
|
||||
|
||||
def __init__(
    self,
    node_channels: int,
    pair_channels: int,
    n_atom_encodings: int,
    n_bond_encodings: int,
    n_atom_pos_encodings: int,
    n_stereo_encodings: int,
    heads: int,
    head_dim: int,
    max_path_length: int = 4,
    n_transformer_stacks: int = 4,
    hidden_dim: Optional[int] = None,
    dropout: float = 0.0,
):
    """Initialize PIFormer model.

    :param node_channels: Width of per-atom node embeddings.
    :param pair_channels: Width of pair embeddings.
    :param n_atom_encodings: Size of the input atom one-hot/feature encoding.
    :param n_bond_encodings: Size of the input bond encoding; the atom-pair
        encoding adds 4 extra channels on top of it.
    :param n_atom_pos_encodings: Size of the atom positional encoding
        (stored but not consumed in this constructor).
    :param n_stereo_encodings: Size of the stereochemistry encoding.
    :param heads: Number of attention heads in each transformer layer.
    :param head_dim: Per-head dimensionality in each transformer layer.
    :param max_path_length: Number of path-integral propagation hops.
    :param n_transformer_stacks: Number of stacked conv/transformer blocks.
    :param hidden_dim: Optional hidden width of the transformer feed-forward.
    :param dropout: Dropout rate shared by the stacks.
    """
    super().__init__()
    self.node_channels = node_channels
    self.pair_channels = pair_channels
    self.max_pi_length = max_path_length
    self.n_transformer_stacks = n_transformer_stacks
    self.n_atom_encodings = n_atom_encodings
    self.n_bond_encodings = n_bond_encodings
    # Atom-pair encoding width: bond encoding plus 4 additional channels.
    self.n_atom_pair_encodings = n_bond_encodings + 4
    self.n_atom_pos_encodings = n_atom_pos_encodings

    # Input projections for atoms, atom pairs, stereochemistry and
    # atom-frame propagation features (the latter takes 3 stacked pair
    # encodings).
    self.input_atom_layer = torch.nn.Linear(n_atom_encodings, node_channels)
    self.input_pair_layer = torch.nn.Linear(self.n_atom_pair_encodings, pair_channels)
    self.input_stereo_layer = torch.nn.Linear(n_stereo_encodings, pair_channels)
    self.input_prop_layer = GELUMLP(
        self.n_atom_pair_encodings * 3,
        pair_channels,
    )
    # One path-integral convolution stack per transformer stack
    # (uses PathConvStack's default n_heads).
    self.path_integral_stacks = torch.nn.ModuleList(
        [
            PathConvStack(
                pair_channels,
                max_pi_length=max_path_length,
                dropout=dropout,
            )
            for _ in range(n_transformer_stacks)
        ]
    )
    # Graph transformer layers with edge (pair) representation updates.
    self.graph_transformer_stacks = torch.nn.ModuleList(
        [
            TransformerLayer(
                node_channels,
                heads,
                head_dim=head_dim,
                edge_channels=pair_channels,
                hidden_dim=hidden_dim,
                dropout=dropout,
                edge_update=True,
            )
            for _ in range(n_transformer_stacks)
        ]
    )
|
||||
def forward(self, batch: MODEL_BATCH, masking_rate: float = 0.0) -> MODEL_BATCH:
|
||||
"""Forward pass for PIFormer model.
|
||||
|
||||
:param batch: A batch dictionary.
|
||||
:param masking_rate: Masking rate.
|
||||
:return: A batch dictionary.
|
||||
"""
|
||||
features = batch["features"]
|
||||
indexer = batch["indexer"]
|
||||
metadata = batch["metadata"]
|
||||
features["atom_encodings"] = features["atom_encodings"]
|
||||
atom_attr = features["atom_encodings"]
|
||||
atom_pair_attr = features["atom_pair_encodings"]
|
||||
af_pair_attr = features["atom_frame_pair_encodings"]
|
||||
stereo_enc = features["stereo_chemistry_encodings"]
|
||||
batch["features"]["lig_atom_token"] = atom_attr.detach().clone()
|
||||
batch["features"]["lig_pair_token"] = atom_pair_attr.detach().clone()
|
||||
|
||||
atom_mask = torch.rand(atom_attr.shape[0], device=atom_attr.device) > masking_rate
|
||||
stereo_mask = torch.rand(stereo_enc.shape[0], device=stereo_enc.device) > masking_rate
|
||||
atom_pair_mask = (
|
||||
torch.rand(atom_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
|
||||
)
|
||||
af_pair_mask = (
|
||||
torch.rand(af_pair_attr.shape[0], device=atom_pair_attr.device) > masking_rate
|
||||
)
|
||||
atom_attr = atom_attr * atom_mask[:, None]
|
||||
stereo_enc = stereo_enc * stereo_mask[:, None]
|
||||
atom_pair_attr = atom_pair_attr * atom_pair_mask[:, None]
|
||||
af_pair_attr = af_pair_attr * af_pair_mask[:, None]
|
||||
|
||||
# Embedding blocks
|
||||
metadata["num_atom"] = metadata["num_u"]
|
||||
metadata["num_frame"] = metadata["num_ijk"]
|
||||
atom_attr = self.input_atom_layer(atom_attr)
|
||||
atom_pair_attr = self.input_pair_layer(atom_pair_attr)
|
||||
triangle_attr = atom_attr.new_zeros(metadata["num_frame"], self.node_channels)
|
||||
# Initialize atom-frame pair attributes. Reusing uv indices
|
||||
prop_attr = self.input_prop_layer(af_pair_attr)
|
||||
stereo_attr = self.input_stereo_layer(stereo_enc)
|
||||
|
||||
graph_relations = [
|
||||
("atom_to_atom", "gather_idx_uv_u", "gather_idx_uv_v", "atom", "atom"),
|
||||
(
|
||||
"atom_to_frame",
|
||||
"gather_idx_Uijk_u",
|
||||
"gather_idx_Uijk_ijk",
|
||||
"atom",
|
||||
"frame",
|
||||
),
|
||||
(
|
||||
"frame_to_atom",
|
||||
"gather_idx_Uijk_ijk",
|
||||
"gather_idx_Uijk_u",
|
||||
"frame",
|
||||
"atom",
|
||||
),
|
||||
(
|
||||
"frame_to_frame",
|
||||
"gather_idx_ijkl_ijk",
|
||||
"gather_idx_ijkl_jkl",
|
||||
"frame",
|
||||
"frame",
|
||||
),
|
||||
]
|
||||
|
||||
graph_batcher = make_multi_relation_graph_batcher(graph_relations, indexer, metadata)
|
||||
merged_edge_idx = graph_batcher.collate_idx_list(indexer)
|
||||
node_reps = {"atom": atom_attr, "frame": triangle_attr}
|
||||
edge_reps = {
|
||||
"atom_to_atom": atom_pair_attr,
|
||||
"atom_to_frame": prop_attr,
|
||||
"frame_to_atom": prop_attr,
|
||||
"frame_to_frame": stereo_attr,
|
||||
}
|
||||
|
||||
# Graph path integral recursion
|
||||
for block_id in range(self.n_transformer_stacks):
|
||||
merged_node_attr = graph_batcher.collate_node_attr(node_reps)
|
||||
merged_edge_attr = graph_batcher.collate_edge_attr(edge_reps)
|
||||
_, merged_node_attr, merged_edge_attr = self.graph_transformer_stacks[block_id](
|
||||
merged_node_attr,
|
||||
merged_node_attr,
|
||||
merged_edge_idx,
|
||||
merged_edge_attr,
|
||||
)
|
||||
node_reps = graph_batcher.offload_node_attr(merged_node_attr)
|
||||
edge_reps = graph_batcher.offload_edge_attr(merged_edge_attr)
|
||||
prop_attr = edge_reps["atom_to_frame"]
|
||||
stereo_attr = edge_reps["frame_to_frame"]
|
||||
prop_attr = prop_attr + self.path_integral_stacks[block_id](
|
||||
prop_attr,
|
||||
stereo_attr,
|
||||
indexer,
|
||||
metadata,
|
||||
)
|
||||
edge_reps["atom_to_frame"] = prop_attr
|
||||
|
||||
node_reps["sampled_frame"] = node_reps["frame"][indexer["gather_idx_I_ijk"]]
|
||||
|
||||
batch["metadata"]["num_lig_atm"] = metadata["num_u"]
|
||||
batch["metadata"]["num_lig_trp"] = metadata["num_I"]
|
||||
|
||||
batch["features"]["lig_atom_attr"] = node_reps["atom"]
|
||||
# Downsampled ligand frames
|
||||
batch["features"]["lig_trp_attr"] = node_reps["sampled_frame"]
|
||||
batch["features"]["lig_atom_pair_attr"] = edge_reps["atom_to_atom"]
|
||||
batch["features"]["lig_prop_attr"] = edge_reps["atom_to_frame"]
|
||||
edge_reps["sampled_atom_to_sampled_frame"] = edge_reps["atom_to_frame"][
|
||||
indexer["gather_idx_UI_Uijk"]
|
||||
]
|
||||
batch["features"]["lig_af_pair_attr"] = edge_reps["sampled_atom_to_sampled_frame"]
|
||||
return batch
|
||||
|
||||
|
||||
def resolve_ligand_encoder(
    ligand_model_cfg: DictConfig,
    task_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> torch.nn.Module:
    """Instantiates a PIFormer model for ligand encoding.

    :param ligand_model_cfg: Ligand model configuration.
    :param task_cfg: Task configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Ligand encoder model.
    """
    model = PIFormer(
        ligand_model_cfg.node_channels,
        ligand_model_cfg.pair_channels,
        ligand_model_cfg.n_atom_encodings,
        ligand_model_cfg.n_bond_encodings,
        ligand_model_cfg.n_atom_pos_encodings,
        ligand_model_cfg.n_stereo_encodings,
        ligand_model_cfg.n_attention_heads,
        ligand_model_cfg.attention_head_dim,
        hidden_dim=ligand_model_cfg.hidden_dim,
        max_path_length=ligand_model_cfg.max_path_integral_length,
        n_transformer_stacks=ligand_model_cfg.n_transformer_stacks,
        dropout=task_cfg.dropout,
    )
    if ligand_model_cfg.from_pretrained and state_dict is not None:
        # Checkpoint keys are namespaced as "ligand_encoder.<param>"; strip
        # the leading component before loading into the bare model.
        pretrained_weights = {
            key.partition(".")[2]: value
            for key, value in state_dict.items()
            if key.startswith("ligand_encoder")
        }
        try:
            model.load_state_dict(pretrained_weights)
            log.info(
                "Successfully loaded pretrained ligand Molecular Heat Transformer (MHT) weights."
            )
        except Exception as e:
            # Best-effort: fall back to random initialization on any mismatch.
            log.warning(
                f"Skipping loading of pretrained ligand Molecular Heat Transformer (MHT) weights due to: {e}."
            )
    return model
|
||||
|
||||
def _load_pretrained_submodule(
    module: torch.nn.Module,
    state_dict: STATE_DICT,
    key_prefix: str,
    description: str,
) -> None:
    """Best-effort load of `module` weights from checkpoint entries under `key_prefix`.

    Checkpoint keys are namespaced as `<key_prefix>.<param_name>`; the leading
    component is stripped before loading. Any failure is logged and swallowed so
    that model construction proceeds with randomly-initialized weights.

    :param module: Module whose state dictionary should be populated.
    :param state_dict: Full (potentially-pretrained) checkpoint state dictionary.
    :param key_prefix: Prefix selecting this module's entries in `state_dict`.
    :param description: Human-readable module description used in log messages.
    """
    try:
        module.load_state_dict(
            {
                ".".join(k.split(".")[1:]): v
                for k, v in state_dict.items()
                if k.startswith(key_prefix)
            }
        )
        log.info(f"Successfully loaded pretrained {description} weights.")
    except Exception as e:
        log.warning(f"Skipping loading of pretrained {description} weights due to: {e}.")


def resolve_relational_reasoning_module(
    protein_model_cfg: DictConfig,
    ligand_model_cfg: DictConfig,
    relational_reasoning_cfg: DictConfig,
    state_dict: Optional[STATE_DICT] = None,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiates relational reasoning module for ligand encoding.

    :param protein_model_cfg: Protein model configuration.
    :param ligand_model_cfg: Ligand model configuration.
    :param relational_reasoning_cfg: Relational reasoning configuration.
    :param state_dict: Optional (potentially-pretrained) state dictionary.
    :return: Relational reasoning modules for ligand encoding.
    """
    # Project ligand graph node/pair features into the protein model's widths.
    molgraph_single_projector = torch.nn.Linear(
        ligand_model_cfg.node_channels, protein_model_cfg.residue_dim, bias=False
    )
    molgraph_pair_projector = torch.nn.Linear(
        ligand_model_cfg.pair_channels, protein_model_cfg.pair_dim, bias=False
    )
    # Two-way embedding table; presumably flags covalent vs. non-covalent
    # protein-ligand pairs — confirm against callers.
    covalent_embed = torch.nn.Embedding(2, protein_model_cfg.pair_dim)
    if relational_reasoning_cfg.from_pretrained and state_dict is not None:
        _load_pretrained_submodule(
            molgraph_single_projector,
            state_dict,
            "molgraph_single_projector",
            "ligand graph single projector",
        )
        _load_pretrained_submodule(
            molgraph_pair_projector,
            state_dict,
            "molgraph_pair_projector",
            "ligand graph pair projector",
        )
        _load_pretrained_submodule(
            covalent_embed,
            state_dict,
            "covalent_embed",
            "ligand covalent embedding",
        )
    return molgraph_single_projector, molgraph_pair_projector, covalent_embed
||||
423
flowdock/models/components/modules.py
Normal file
423
flowdock/models/components/modules.py
Normal file
@@ -0,0 +1,423 @@
|
||||
import math
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from beartype.typing import Optional, Tuple, Union
|
||||
from openfold.model.primitives import Attention
|
||||
from openfold.utils.tensor_utils import permute_final_dims
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils.model_utils import GELUMLP, segment_softmax
|
||||
|
||||
|
||||
class MultiHeadAttentionConv(torch.nn.Module):
    """Native Pytorch implementation.

    Multi-head attention over graph edges: queries come from destination
    nodes, keys/values from source nodes, with an optional per-head scalar
    edge bias. Aggregation is a plain (softmax-weighted) scatter-add.
    """

    def __init__(
        self,
        dim: Union[int, Tuple[int, int]],
        head_dim: int,
        edge_dim: Optional[int] = None,
        n_heads: int = 1,
        dropout: float = 0.0,
        edge_lin: bool = True,
        **kwargs,
    ):
        """Multi-Head Attention Convolution layer.

        :param dim: Input feature size, or a (source, destination) pair of sizes.
        :param head_dim: Per-head feature dimension.
        :param edge_dim: Edge feature size (required when `edge_lin` is True).
        :param n_heads: Number of attention heads.
        :param dropout: Dropout rate applied to attention weights.
        :param edge_lin: Whether to project edge features into per-head biases.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.dropout = dropout
        self.edge_dim = edge_dim
        self.edge_lin = edge_lin
        # Stashes the most recent pre-softmax attention logits during forward.
        self._alpha = None

        if isinstance(dim, int):
            dim = (dim, dim)

        self.lin_key = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
        self.lin_query = torch.nn.Linear(dim[1], n_heads * head_dim, bias=False)
        self.lin_value = torch.nn.Linear(dim[0], n_heads * head_dim, bias=False)
        if edge_lin is True:
            self.lin_edge = torch.nn.Linear(edge_dim, n_heads, bias=False)
        else:
            # register_parameter returns None, so self.lin_edge ends up None;
            # forward() relies on that None check.
            self.lin_edge = self.register_parameter("lin_edge", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the layer."""
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_lin:
            self.lin_edge.reset_parameters()

    def forward(
        self,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of the Multi-Head Attention Convolution layer.

        :param x (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]): The input features
            (source features, destination features); a single tensor is used for both.
        :param edge_index (torch.Tensor): The edge index tensor
            (edge_index[0] = source nodes, edge_index[1] = destination nodes).
        :param edge_attr (torch.Tensor, optional): The edge attribute tensor.
        :param return_attention_weights (bool, optional): If set to `True`,
            will additionally return the tuple
            `(edge_index, attention_weights)`, holding the computed
            attention weights for each edge. Default is `None`.
        :return: The output features or the tuple
            `(output_features, (edge_index, attention_weights))`.
        """

        H, C = self.n_heads, self.head_dim

        if isinstance(x, torch.Tensor):
            x = (x, x)

        # Per-head projections: [..., H, C].
        query = self.lin_query(x[1]).view(*x[1].shape[:-1], H, C)
        key = self.lin_key(x[0]).view(*x[0].shape[:-1], H, C)
        value = self.lin_value(x[0]).view(*x[0].shape[:-1], H, C)

        attended_values = self.message(key, query, value, edge_attr, edge_index)
        out = self.aggregate(attended_values, edge_index[1], query.shape[0])

        # message() stashed the pre-softmax logits; hand them out and clear.
        alpha = self._alpha
        self._alpha = None

        out = out.contiguous().view(*out.shape[:-2], H * C)

        # NOTE: any bool (even False) triggers returning the weights.
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            return out, (edge_index, alpha)
        else:
            return out

    def message(
        self,
        key: torch.Tensor,
        query: torch.Tensor,
        value: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        index: torch.Tensor,
    ) -> torch.Tensor:
        """Add the relative positional encodings to attention scores.

        Computes per-edge scaled dot-product attention (optionally biased by
        projected edge features), softmax-normalized per destination node, and
        returns the attention-weighted source values. Side effect: stores the
        pre-softmax logits in `self._alpha`.

        :param key (torch.Tensor): The key tensor.
        :param query (torch.Tensor): The query tensor.
        :param value (torch.Tensor): The value tensor.
        :param edge_attr (torch.Tensor): The edge attribute tensor.
        :param index (torch.Tensor): The edge index tensor.
        :return: The output tensor.
        """
        edge_bias = 0
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_bias = self.lin_edge(edge_attr)

        # Scaled dot-product logits per edge and head: [E, H].
        _alpha_z = (query[index[1]] * key[index[0]]).sum(dim=-1) / math.sqrt(
            self.head_dim
        ) + edge_bias
        self._alpha = _alpha_z
        # Softmax over edges sharing the same destination node.
        alpha = segment_softmax(_alpha_z, index[1], query.shape[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Advanced indexing copies, so the in-place multiply is safe.
        out = value[index[0]]
        out *= alpha.unsqueeze(-1)
        return out

    def aggregate(
        self, src: torch.Tensor, dst_idx: torch.Tensor, dst_size: int
    ) -> torch.Tensor:
        """Aggregate the source tensor to the destination tensor.

        Sum-scatters per-edge messages into per-destination-node slots.

        :param src (torch.Tensor): The source tensor.
        :param dst_idx (torch.Tensor): The destination index tensor.
        :param dst_size (torch.Size): The destination size tensor.
        :return: The output tensor.
        """
        out = torch.zeros(
            dst_size,
            *src.shape[1:],
            dtype=src.dtype,
            device=src.device,
        ).index_add_(0, dst_idx, src)
        return out
||||
|
||||
|
||||
class TransformerLayer(torch.nn.Module):
    """A single layer of a transformer model.

    Attention over graph edges followed by a pre-norm residual MLP; can
    optionally run a second (reversed-edge) attention pass and/or update the
    edge representation from the attention logits.
    """

    def __init__(
        self,
        node_dim: int,
        n_heads: int,
        head_dim: Optional[int] = None,
        hidden_dim: Optional[int] = None,
        bidirectional: bool = False,
        edge_channels: Optional[int] = None,
        dropout: float = 0.0,
        edge_update: bool = False,
    ):
        """Initialize the transformer layer."""
        super().__init__()
        self.edge_update = edge_update
        # Default to splitting the node width evenly across heads.
        if head_dim is None:
            head_dim = node_dim // n_heads
        self.conv = MultiHeadAttentionConv(
            node_dim,
            head_dim,
            edge_dim=edge_channels,
            n_heads=n_heads,
            edge_lin=edge_channels is not None,
            dropout=dropout,
        )
        self.bidirectional = bidirectional
        self.projector = torch.nn.Linear(head_dim * n_heads, node_dim, bias=False)
        self.norm = torch.nn.LayerNorm(node_dim)
        self.mlp = GELUMLP(
            node_dim,
            node_dim,
            n_hidden_feats=hidden_dim,
            dropout=dropout,
            zero_init=True,
        )
        if edge_update:
            self.mlpe = GELUMLP(
                n_heads + edge_channels, edge_channels, dropout=dropout, zero_init=True
            )

    def forward(
        self,
        x_s: torch.Tensor,
        x_a: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through the transformer layer.

        :param x_s: Source node features.
        :param x_a: Target node features.
        :param edge_index: Edge index tensor (row 0 = source, row 1 = target).
        :param edge_attr: Optional edge attribute tensor.
        :return: Updated (x_s, x_a), plus updated edge_attr when edge_update is on.
        """
        # Source -> target attention, keeping the per-edge logits around.
        attn_out, (edge_index, attn_logits) = self.conv(
            (x_s, x_a),
            edge_index,
            edge_attr,
            return_attention_weights=True,
        )
        x_a = x_a + self.projector(attn_out)
        x_a = x_a + self.mlp(self.norm(x_a))

        if self.bidirectional:
            # Second pass in the reverse direction, using the freshly updated x_a.
            flipped_index = (edge_index[1], edge_index[0])
            reverse_out = self.conv((x_a, x_s), flipped_index, edge_attr)
            x_s = x_s + self.projector(reverse_out)
            x_s = x_s + self.mlp(self.norm(x_s))

        if not self.edge_update:
            return x_s, x_a
        # Residual edge update from the forward-pass attention logits.
        edge_attr = edge_attr + self.mlpe(torch.cat([attn_logits, edge_attr], dim=-1))
        return x_s, x_a, edge_attr
||||
|
||||
|
||||
class PointSetAttention(torch.nn.Module):
    """PointSetAttention module.

    Invariant-point-attention-style layer: combines scalar dot-product
    attention with a squared-distance attention term over 3D points anchored
    at per-node centers, with optional edge biasing and edge updates.
    """

    def __init__(
        self,
        fiber_dim: int,
        heads: int = 8,
        point_dim: int = 4,
        edge_dim: Optional[int] = None,
        edge_update: bool = False,
        dropout: float = 0.0,
    ):
        """Initialize the PointSetAttention module.

        :param fiber_dim: Width of the per-node fiber representation.
        :param heads: Number of attention heads.
        :param point_dim: Number of points per head.
        :param edge_dim: Optional edge feature width for attention biasing.
        :param edge_update: Whether to also return updated edge features.
        :param dropout: Dropout rate applied to attention weights.
        """
        super().__init__()
        self.fiber_dim = fiber_dim
        self.edge_dim = edge_dim
        self.heads = heads
        self.point_dim = point_dim
        self.dropout = dropout
        self.edge_update = edge_update
        self.distance_scaling = 10  # 1 nm

        # num attention contributions
        num_attn_logits = 2

        self.lin_query = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        self.lin_key = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        self.lin_value = torch.nn.Linear(fiber_dim, point_dim * heads, bias=False)
        if edge_dim is not None:
            self.lin_edge = torch.nn.Linear(edge_dim, heads, bias=False)
        if edge_update:
            # NOTE(review): requires edge_dim to be set — confirm callers.
            self.edge_update_mlp = GELUMLP(heads + edge_dim, edge_dim)

        # qkv projection for scalar attention (normal)
        self.scalar_attn_logits_scale = (num_attn_logits * point_dim) ** -0.5

        # qkv projection for point attention (coordinate and orientation aware)
        # Inverse-softplus of 1.0 so softplus(point_weights) starts at 1.
        # (The original registered this parameter twice; the duplicate was removed.)
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.0)) - 1.0)
        self.point_weights = torch.nn.Parameter(point_weight_init_value)

        self.point_attn_logits_scale = ((num_attn_logits * point_dim) * (9 / 2)) ** -0.5

        # combine out - point dim * 4
        self.to_out = torch.nn.Linear(heads * point_dim, fiber_dim, bias=False)

    def forward(
        self,
        x_k: torch.Tensor,
        x_q: torch.Tensor,
        edge_index: torch.LongTensor,
        point_centers_k: torch.Tensor,
        point_centers_q: torch.Tensor,
        x_edge: Optional[torch.Tensor] = None,
    ):
        """Forward pass of the PointSetAttention module.

        :param x_k: Key-side node features (channel 0 scalar, channels 1-3 point).
        :param x_q: Query-side node features.
        :param edge_index: Edge index (row 0 = key side, row 1 = query side).
        :param point_centers_k: 3D anchor coordinates for key-side nodes.
        :param point_centers_q: 3D anchor coordinates for query-side nodes.
        :param x_edge: Optional edge features (required when edge_dim is set).
        :return: Updated node features, plus updated edge features when
            `edge_update` is enabled.
        """
        H, P = self.heads, self.point_dim

        q = self.lin_query(x_q)
        k = self.lin_key(x_k)
        v = self.lin_value(x_k)

        # Channel 0 carries the scalar track: [N, H, P].
        scalar_q = q[..., 0, :].view(-1, H, P)
        scalar_k = k[..., 0, :].view(-1, H, P)
        scalar_v = v[..., 0, :].view(-1, H, P)

        # Channels 1-3 carry per-head local point offsets: [N, 3, H, P].
        point_q_local = q[..., 1:, :].view(-1, 3, H, P)
        point_k_local = k[..., 1:, :].view(-1, 3, H, P)
        point_v_local = v[..., 1:, :].view(-1, 3, H, P)

        # Shift local points to global frame (centers rescaled to the same units).
        point_q = point_q_local + point_centers_q[..., None, None] / self.distance_scaling
        point_k = point_k_local + point_centers_k[..., None, None] / self.distance_scaling
        point_v = point_v_local + point_centers_k[..., None, None] / self.distance_scaling

        if self.edge_dim is not None:
            edge_bias = self.lin_edge(x_edge)
        else:
            edge_bias = 0

        attn_logits, attentions = self.compute_attention(
            scalar_k, scalar_q, point_k, point_q, edge_bias, edge_index
        )
        res_scalar = self.aggregate(
            attentions[:, :, None] * scalar_v[edge_index[0]],
            edge_index[1],
            scalar_q.shape[0],
        )
        res_points = self.aggregate(
            attentions[:, None, :, None] * point_v[edge_index[0]],
            edge_index[1],
            point_q.shape[0],
        )
        # Return aggregated points to the query-local frame.
        res_points_local = res_points - point_centers_q[..., None, None] / self.distance_scaling

        # [N, H, P], [N, 3, H, P] -> [N, 4, C]
        res = torch.cat([res_scalar.unsqueeze(-3), res_points_local], dim=-3).flatten(-2, -1)
        out = self.to_out(res)  # [N, 4, C]
        if self.edge_update:
            edge_out = self.edge_update_mlp(torch.cat([attn_logits, x_edge], dim=-1))
            return out, edge_out
        return out

    def compute_attention(self, scalar_k, scalar_q, point_k, point_q, edge_bias, index):
        """Compute the attention scores.

        Returns the per-edge pre-softmax logits and the dropout-regularized,
        per-destination softmax-normalized attention weights.
        """
        # Gather per-edge endpoints.
        scalar_q = scalar_q[index[1]]
        scalar_k = scalar_k[index[0]]
        point_q = point_q[index[1]]
        point_k = point_k[index[0]]

        scalar_logits = (scalar_q * scalar_k).sum(dim=-1) * self.scalar_attn_logits_scale
        point_weights = F.softplus(self.point_weights).unsqueeze(0)
        # Squared distance between query and key point clouds, per head.
        point_logits = (
            torch.square(point_q - point_k).sum(dim=(-3, -1)) * self.point_attn_logits_scale
        )

        logits = scalar_logits - 1 / 2 * point_logits * point_weights + edge_bias
        # NOTE(review): scalar_q was re-indexed above, so shape[0] here is the
        # edge count, not the node count (cf. MultiHeadAttentionConv.message,
        # which passes the node count) — confirm segment_softmax's contract.
        alpha = segment_softmax(logits, index[1], scalar_q.shape[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return logits, alpha

    @staticmethod
    def aggregate(src, dst_idx, dst_size):
        """Aggregate the source tensor to the destination tensor via scatter-add."""
        out = torch.zeros(
            dst_size,
            *src.shape[1:],
            dtype=src.dtype,
            device=src.device,
        ).index_add_(0, dst_idx, src)
        return out
|
||||
|
||||
|
||||
class BiDirectionalTriangleAttention(torch.nn.Module):
    """
    Adapted from https://github.com/aqlaboratory/openfold
    supports rectangular pair representation tensors

    Runs two openfold multi-head attentions with a shared triangle bias
    derived from the pair representation: x1 attends to x2 and x2 attends
    to x1 (with the bias transposed for the reverse direction).
    """

    def __init__(self, c_in: int, c_hidden: int, no_heads: int, inf: float = 1e9):
        """Initialize the Bi-Directional Triangle Attention layer.

        :param c_in: Input channel width of the node/pair representations.
        :param c_hidden: Hidden channel width of the attention.
        :param no_heads: Number of attention heads.
        :param inf: Large constant used to mask out invalid positions.
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Projects pair features to one bias scalar per head.
        self.linear = torch.nn.Linear(c_in, self.no_heads, bias=False)

        # Separate attention modules per direction.
        self.mha_1 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
        self.mha_2 = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
        # NOTE(review): a single LayerNorm instance normalizes both x1 and x2
        # (shared weights) — confirm this sharing is intentional.
        self.layer_norm = torch.nn.LayerNorm(self.c_in)

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        x_pair: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_lma: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the Bi-Directional Triangle Attention layer.

        :param x1: First representation, shape [*, I, J, C_in].
        :param x2: Second representation, shape [*, I, K, C_in].
        :param x_pair: Pair representation providing the triangle bias.
        :param mask: Optional validity mask over x_pair's leading dims.
        :param use_lma: Whether to use low-memory attention in openfold.
        :return: Residual-updated (x1, x2).
        """
        if mask is None:
            # [*, I, J, K]
            mask = x_pair.new_ones(
                x_pair.shape[:-1],
            )

        # [*, I, J, C_in]
        x1 = self.layer_norm(x1)
        # [*, I, K, C_in]
        x2 = self.layer_norm(x2)

        # [*, I, 1, J, K]
        mask_bias = (self.inf * (mask - 1))[..., :, None, :, :]

        # [*, I, H, J, K]
        triangle_bias = permute_final_dims(self.linear(x_pair), [0, 3, 1, 2])

        biases_J2I = [mask_bias, triangle_bias]

        x1_out = self.mha_1(q_x=x1, kv_x=x2, biases=biases_J2I, use_lma=use_lma)
        x1 = x1 + x1_out

        # transpose the triangle bias for I->J attention.
        mask_bias_T_ = mask_bias.transpose(-2, -1).contiguous()
        triangle_bias_T_ = triangle_bias.transpose(-2, -1).contiguous()
        biases_I2J = [mask_bias_T_, triangle_bias_T_]
        x2_out = self.mha_2(q_x=x2, kv_x=x1, biases=biases_I2J, use_lma=use_lma)
        x2 = x2 + x2_out

        return x1, x2
|
||||
443
flowdock/models/components/noise.py
Normal file
443
flowdock/models/components/noise.py
Normal file
@@ -0,0 +1,443 @@
|
||||
import numpy as np
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.models.components.transforms import LatentCoordinateConverter
|
||||
from flowdock.utils import RankedLogger
|
||||
from flowdock.utils.model_utils import segment_mean
|
||||
|
||||
# Type alias for the nested dictionary batches passed between model components.
MODEL_BATCH = Dict[str, Any]

# Module-level logger that only emits on rank zero in distributed runs.
log = RankedLogger(__name__, rank_zero_only=True)
||||
|
||||
|
||||
class DiffusionSDE:
    """Diffusion SDE class.

    Adapted from: https://github.com/HannesStark/FlowSite

    An Ornstein-Uhlenbeck-style SDE parameterized by the stationary standard
    deviation ``sigma`` via the precision ``lamb = 1 / sigma**2``.
    """

    def __init__(self, sigma: torch.Tensor, tau_factor: float = 5.0):
        """Initialize the Diffusion SDE class.

        :param sigma: Stationary standard deviation of the process.
        :param tau_factor: Multiplier defining the maximum integration time.
        """
        # Precision of the stationary Gaussian.
        self.lamb = 1 / sigma**2
        self.tau_factor = tau_factor

    def var(self, t: torch.Tensor) -> torch.Tensor:
        """Calculate the variance of the diffusion SDE."""
        decay = torch.exp(-self.lamb * t)
        return (1 - decay) / self.lamb

    def max_t(self) -> float:
        """Calculate the maximum time of the diffusion SDE."""
        # tau_factor relaxation times of the process.
        return self.tau_factor / self.lamb

    def mu_factor(self, t: torch.Tensor) -> torch.Tensor:
        """Calculate the mu factor of the diffusion SDE."""
        return torch.exp(-self.lamb * t / 2)
||||
|
||||
|
||||
class HarmonicSDE:
|
||||
"""Harmonic SDE class.
|
||||
|
||||
Adapted from: https://github.com/HannesStark/FlowSite
|
||||
"""
|
||||
|
||||
def __init__(self, J: Optional[torch.Tensor] = None, diagonalize: bool = True):
|
||||
"""Initialize the Harmonic SDE class."""
|
||||
self.l_index = 1
|
||||
self.use_cuda = False
|
||||
if not diagonalize:
|
||||
return
|
||||
if J is not None:
|
||||
self.D, self.P = np.linalg.eigh(J)
|
||||
self.N = self.D.size
|
||||
|
||||
@staticmethod
|
||||
def diagonalize(
|
||||
N,
|
||||
ptr: torch.Tensor,
|
||||
edges: Optional[List[Tuple[int, int]]] = None,
|
||||
antiedges: Optional[List[Tuple[int, int]]] = None,
|
||||
a=1,
|
||||
b=0.3,
|
||||
lamb: Optional[torch.Tensor] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""Diagonalize using the Harmonic SDE."""
|
||||
device = device or ptr.device
|
||||
J = torch.zeros((N, N), device=device)
|
||||
if edges is None:
|
||||
for i, j in zip(np.arange(N - 1), np.arange(1, N)):
|
||||
J[i, i] += a
|
||||
J[j, j] += a
|
||||
J[i, j] = J[j, i] = -a
|
||||
else:
|
||||
for i, j in edges:
|
||||
J[i, i] += a
|
||||
J[j, j] += a
|
||||
J[i, j] = J[j, i] = -a
|
||||
if antiedges is not None:
|
||||
for i, j in antiedges:
|
||||
J[i, i] -= b
|
||||
J[j, j] -= b
|
||||
J[i, j] = J[j, i] = b
|
||||
if edges is not None:
|
||||
J += torch.diag(lamb)
|
||||
|
||||
Ds, Ps = [], []
|
||||
for start, end in zip(ptr[:-1], ptr[1:]):
|
||||
D, P = torch.linalg.eigh(J[start:end, start:end])
|
||||
D_ = D
|
||||
if edges is None:
|
||||
D_inv = 1 / D
|
||||
D_inv[0] = 0
|
||||
D_ = D_inv
|
||||
Ds.append(D_)
|
||||
Ps.append(P)
|
||||
return torch.cat(Ds), torch.block_diag(*Ps)
|
||||
|
||||
def eigens(self, t):
|
||||
"""Calculate the eigenvalues of `sigma_t` using the Harmonic SDE."""
|
||||
np_ = torch if self.use_cuda else np
|
||||
D = 1 / self.D * (1 - np_.exp(-t * self.D))
|
||||
t = torch.tensor(t, device="cuda").float() if self.use_cuda else t
|
||||
return np_.where(D != 0, D, t)
|
||||
|
||||
def conditional(self, mask, x2):
|
||||
"""Calculate the conditional distribution using the Harmonic SDE."""
|
||||
J_11 = self.J[~mask][:, ~mask]
|
||||
J_12 = self.J[~mask][:, mask]
|
||||
h = -J_12 @ x2
|
||||
mu = np.linalg.inv(J_11) @ h
|
||||
D, P = np.linalg.eigh(J_11)
|
||||
z = np.random.randn(*mu.shape)
|
||||
return (P / D**0.5) @ z + mu
|
||||
|
||||
def A(self, t, invT=False):
|
||||
"""Calculate the matrix `A` using the Harmonic SDE."""
|
||||
D = self.eigens(t)
|
||||
A = self.P * (D**0.5)
|
||||
if not invT:
|
||||
return A
|
||||
AinvT = self.P / (D**0.5)
|
||||
return A, AinvT
|
||||
|
||||
def Sigma_inv(self, t):
|
||||
"""Calculate the inverse of the covariance matrix `Sigma` using the Harmonic SDE."""
|
||||
D = 1 / self.eigens(t)
|
||||
return (self.P * D) @ self.P.T
|
||||
|
||||
def Sigma(self, t):
|
||||
"""Calculate the covariance matrix `Sigma` using the Harmonic SDE."""
|
||||
D = self.eigens(t)
|
||||
return (self.P * D) @ self.P.T
|
||||
|
||||
@property
|
||||
def J(self):
|
||||
"""Return the matrix `J`."""
|
||||
return (self.P * self.D) @ self.P.T
|
||||
|
||||
def rmsd(self, t):
|
||||
"""Calculate the root mean square deviation using the Harmonic SDE."""
|
||||
l_index = self.l_index
|
||||
D = 1 / self.D * (1 - np.exp(-t * self.D))
|
||||
return np.sqrt(3 * D[l_index:].mean())
|
||||
|
||||
def sample(self, t, x=None, score=False, k=None, center=True, adj=False):
|
||||
"""Sample from the Harmonic SDE."""
|
||||
l_index = self.l_index
|
||||
np_ = torch if self.use_cuda else np
|
||||
if x is None:
|
||||
if self.use_cuda:
|
||||
x = torch.zeros((self.N, 3), device="cuda").float()
|
||||
else:
|
||||
x = np.zeros((self.N, 3))
|
||||
if t == 0:
|
||||
return x
|
||||
z = (
|
||||
np.random.randn(self.N, 3)
|
||||
if not self.use_cuda
|
||||
else torch.randn(self.N, 3, device="cuda").float()
|
||||
)
|
||||
D = self.eigens(t)
|
||||
xx = self.P.T @ x
|
||||
if center:
|
||||
z[0] = 0
|
||||
xx[0] = 0
|
||||
if k:
|
||||
z[k + l_index :] = 0
|
||||
xx[k + l_index :] = 0
|
||||
|
||||
out = np_.exp(-t * self.D / 2)[:, None] * xx + np_.sqrt(D)[:, None] * z
|
||||
|
||||
if score:
|
||||
score = -(1 / np_.sqrt(D))[:, None] * z
|
||||
if adj:
|
||||
score = score + self.D[:, None] * out
|
||||
return self.P @ out, self.P @ score
|
||||
return self.P @ out
|
||||
|
||||
def score_norm(self, t, k=None, adj=False):
|
||||
"""Calculate the score norm using the Harmonic SDE."""
|
||||
if k == 0:
|
||||
return 0
|
||||
l_index = self.l_index
|
||||
np_ = torch if self.use_cuda else np
|
||||
k = k or self.N - 1
|
||||
D = 1 / self.eigens(t)
|
||||
if adj:
|
||||
D = D * np_.exp(-self.D * t)
|
||||
return (D[l_index : k + l_index].sum() / self.N) ** 0.5
|
||||
|
||||
def inject(self, t, modes):
|
||||
"""Inject noise along the given modes using the Harmonic SDE."""
|
||||
z = (
|
||||
np.random.randn(self.N, 3)
|
||||
if not self.use_cuda
|
||||
else torch.randn(self.N, 3, device="cuda").float()
|
||||
)
|
||||
z[~modes] = 0
|
||||
A = self.A(t, invT=False)
|
||||
return A @ z
|
||||
|
||||
def score(self, x0, xt, t):
|
||||
"""Calculate the score of the diffusion kernel using the Harmonic SDE."""
|
||||
Sigma_inv = self.Sigma_inv(t)
|
||||
mu_t = (self.P * np.exp(-t * self.D / 2)) @ (self.P.T @ x0)
|
||||
return Sigma_inv @ (mu_t - xt)
|
||||
|
||||
def project(self, X, k, center=False):
|
||||
"""Project onto the first `k` nonzero modes using the Harmonic SDE."""
|
||||
l_index = self.l_index
|
||||
D = self.P.T @ X
|
||||
D[k + l_index :] = 0
|
||||
if center:
|
||||
D[0] = 0
|
||||
return self.P @ D
|
||||
|
||||
def unproject(self, X, mask, k, return_Pinv=False):
|
||||
"""Find the vector along the first k nonzero modes whose mask is closest to `X`"""
|
||||
l_index = self.l_index
|
||||
PP = self.P[mask, : k + l_index]
|
||||
Pinv = np.linalg.pinv(PP)
|
||||
out = self.P[:, : k + l_index] @ Pinv @ X
|
||||
if return_Pinv:
|
||||
return out, Pinv
|
||||
return out
|
||||
|
||||
def energy(self, X):
|
||||
"""Calculate the energy using the Harmonic SDE."""
|
||||
l_index = self.l_index
|
||||
return (self.D[:, None] * (self.P.T @ X) ** 2).sum(-1)[l_index:] / 2
|
||||
|
||||
@property
|
||||
def free_energy(self):
|
||||
"""Calculate the free energy using the Harmonic SDE."""
|
||||
l_index = self.l_index
|
||||
return 3 * np.log(self.D[l_index:]).sum() / 2
|
||||
|
||||
def KL_H(self, t):
    """Compute the Kullback-Leibler divergence at time ``t`` under the Harmonic SDE.

    :param t: diffusion time.
    :return: the KL divergence summed over the nonzero modes.
    """
    rates = self.D[self.l_index :]
    decay = np.exp(-rates * t)
    return -3 * 0.5 * (np.log(1 - decay) + decay).sum(0)
|
||||
|
||||
|
||||
def sample_gaussian_prior(
    x0: torch.Tensor,
    latent_converter: LatentCoordinateConverter,
    sigma: float,
    x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample noise from a Gaussian prior distribution.

    A single standard-normal draw is shared between both outputs: the
    ground truth receives a small perturbation (scale ``x0_sigma``) while the
    prior sample is the same noise scaled by ``sigma``.

    :param x0: ground-truth tensor
    :param latent_converter: The latent coordinate converter (retained for
        interface parity with the other prior samplers; not otherwise used)
    :param sigma: standard deviation of the Gaussian noise
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
    """
    prior = torch.randn_like(x0)
    x_int_0 = x0 + prior * x0_sigma  # add small Gaussian noise to the ground-truth tensor
    # NOTE: the previous implementation split ``prior * sigma`` into the
    # Ca/other-atom/ligand segments via the latent converter and immediately
    # re-concatenated them along dim=1, which reconstructs the identical
    # tensor; that dead round-trip has been removed.
    x_int_1 = prior * sigma
    return x_int_0, x_int_1
|
||||
|
||||
|
||||
def sample_protein_harmonic_prior(
    protein_ca_x0: torch.Tensor,
    protein_cother_x0: torch.Tensor,
    batch: MODEL_BATCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample protein noise from a harmonic prior distribution.
    Adapted from: https://github.com/bjing2016/alphaflow

    Note that this function represents non-Ca atoms as Gaussian noise
    centered around each harmonically-noised Ca atom.

    :param protein_ca_x0: ground-truth protein Ca-atom tensor
    :param protein_cother_x0: ground-truth protein other-atom tensor
    :param batch: A batch dictionary
    :return: tuple of harmonic protein Ca atom noise and Gaussian protein other atom noise
    """
    indexer = batch["indexer"]
    metadata = batch["metadata"]
    # per-atom structure (batch) IDs for the protein Ca atoms
    protein_bid = indexer["gather_idx_a_structid"]
    protein_num_nodes = protein_ca_x0.size(0) * protein_ca_x0.size(1)
    # CSR-style chain offsets: cumulative atom counts per structure, prefixed with 0
    ptr = torch.cumsum(torch.bincount(protein_bid), dim=0)
    ptr = torch.cat((torch.tensor([0], device=protein_bid.device), ptr))
    try:
        # a = 3 / 3.8^2 — the spring constant derived from the ~3.8 Angstrom
        # Ca-Ca virtual bond length (see the FlowDock/AlphaFlow harmonic prior)
        D_inv, P = HarmonicSDE.diagonalize(
            protein_num_nodes,
            ptr,
            a=3 / (3.8**2),
        )
    except Exception as e:
        log.error(
            f"Failed to call HarmonicSDE.diagonalize() for protein(s) {metadata['sample_ID_per_sample']} due to: {e}"
        )
        raise e
    noise = torch.randn((protein_num_nodes, 3), device=protein_ca_x0.device)
    # color the white noise by the harmonic spectrum: P @ (sqrt(D_inv) * z)
    # (D_inv is presumably the inverse eigenvalue vector returned by
    # HarmonicSDE.diagonalize — confirm against its definition)
    harmonic_ca_noise = P @ (torch.sqrt(D_inv)[:, None] * noise)
    # non-Ca atoms: unit Gaussian noise centered on each atom's noised Ca position
    gaussian_cother_noise = (
        torch.randn_like(protein_cother_x0.flatten(0, 1))
        + harmonic_ca_noise[indexer["gather_idx_a_cotherid"]]
    )
    return (
        harmonic_ca_noise.view(protein_ca_x0.size()).contiguous(),
        gaussian_cother_noise.view(protein_cother_x0.size()).contiguous(),
    )
|
||||
|
||||
|
||||
def sample_ligand_harmonic_prior(
    lig_x0: torch.Tensor, protein_ca_x0: torch.Tensor, batch: MODEL_BATCH, sigma: float = 1.0
) -> torch.Tensor:
    """
    Sample ligand noise from a harmonic prior distribution.
    Adapted from: https://github.com/HannesStark/FlowSite

    :param lig_x0: ground-truth ligand tensor
    :param protein_ca_x0: ground-truth protein Ca-atom tensor
    :param batch: A batch dictionary
    :param sigma: standard deviation of the harmonic noise
    :return: tensor of harmonic noise
    """
    indexer = batch["indexer"]
    metadata = batch["metadata"]
    lig_num_nodes = lig_x0.size(0) * lig_x0.size(1)
    num_molid_per_sample = max(metadata["num_molid_per_sample"])
    # NOTE: here, we distinguish each ligand chain in a complex for harmonic chain sampling
    lig_bid = indexer["gather_idx_i_molid"]
    # per-structure RMS of the protein Ca coordinates, repeated per ligand chain;
    # this scales the ligand noise to the protein's spatial extent
    protein_sigma = (
        segment_mean(
            torch.square(protein_ca_x0).flatten(0, 1),
            indexer["gather_idx_a_structid"],
            metadata["num_structid"],
        ).mean(dim=-1)
        ** 0.5
    ).repeat_interleave(num_molid_per_sample)
    sde = DiffusionSDE(protein_sigma * sigma)
    # ligand bond graph as a 2 x E index tensor
    edges = torch.stack(
        (
            indexer["gather_idx_ij_i"],
            indexer["gather_idx_ij_j"],
        )
    )
    edges = edges[:, edges[0] < edges[1]]  # de-duplicate edges
    # CSR-style chain offsets: cumulative atom counts per ligand chain, prefixed with 0
    ptr = torch.cumsum(torch.bincount(lig_bid), dim=0)
    ptr = torch.cat((torch.tensor([0], device=lig_bid.device), ptr))
    try:
        D, P = HarmonicSDE.diagonalize(
            lig_num_nodes,
            ptr,
            edges=edges.T,
            lamb=sde.lamb[lig_bid],
        )
    except Exception as e:
        log.error(
            f"Failed to call HarmonicSDE.diagonalize() for ligand(s) {metadata['sample_ID_per_sample']} due to: {e}"
        )
        raise e
    noise = torch.randn((lig_num_nodes, 3), device=lig_x0.device)
    # color the white noise by the inverse square-root spectrum: P @ (z / sqrt(D))
    prior = P @ (noise / torch.sqrt(D)[:, None])
    return prior.view(lig_x0.size()).contiguous()
|
||||
|
||||
|
||||
def sample_complex_harmonic_prior(
    x0: torch.Tensor,
    latent_converter: LatentCoordinateConverter,
    batch: MODEL_BATCH,
    x0_sigma: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample protein-ligand complex noise from a harmonic prior distribution.
    From: https://github.com/bjing2016/alphaflow

    :param x0: ground-truth tensor
    :param latent_converter: The latent coordinate converter
    :param batch: A batch dictionary
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian and harmonic prior noise, respectively
    """
    # unpack the latent tensor into its Ca / other-atom / ligand heavy-atom segments,
    # using the segment sizes cached on the converter by its `to_latent` call
    ca_lat, cother_lat, lig_lat = x0.split(
        [
            latent_converter._n_res_per_sample,
            latent_converter._n_cother_per_sample,
            latent_converter._n_ligha_per_sample,
        ],
        dim=1,
    )
    # protein noise: harmonic Ca noise plus Gaussian noise around each noised Ca
    harmonic_ca_lat, gaussian_cother_lat = sample_protein_harmonic_prior(
        ca_lat,
        cother_lat,
        batch,
    )
    # ligand noise is scaled by the (noised) protein Ca extent
    harmonic_lig_lat = sample_ligand_harmonic_prior(lig_lat, harmonic_ca_lat, batch)
    # re-pack in the same segment order expected by the latent converter
    x1 = torch.cat(
        [
            # NOTE: the following normalization steps assume that `self.latent_model == "default"`
            harmonic_ca_lat / latent_converter.ca_scale,
            gaussian_cother_lat / latent_converter.other_scale,
            harmonic_lig_lat / latent_converter.other_scale,
        ],
        dim=1,
    )
    gaussian_prior = torch.randn_like(x0)
    return x0 + gaussian_prior * x0_sigma, x1
|
||||
|
||||
|
||||
def sample_esmfold_prior(
    x0: torch.Tensor, x1: torch.Tensor, sigma: float, x0_sigma: float = 1e-4
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample noise from an ESMFold prior distribution.

    A single Gaussian draw is shared: the ground truth is perturbed at scale
    ``x0_sigma`` while the ESMFold prediction is perturbed at scale ``sigma``.

    :param x0: ground-truth tensor
    :param x1: predicted tensor
    :param sigma: standard deviation of the ESMFold prior's additive Gaussian noise
    :param x0_sigma: standard deviation of the Gaussian noise for the ground-truth tensor
    :return: tuple of ground-truth and predicted tensors with additive Gaussian prior noise
    """
    shared_noise = torch.randn_like(x0)
    noisy_gt = x0 + shared_noise * x0_sigma
    noisy_pred = x1 + shared_noise * sigma
    return noisy_gt, noisy_pred
|
||||
241
flowdock/models/components/transforms.py
Normal file
241
flowdock/models/components/transforms.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils.model_utils import segment_mean
|
||||
|
||||
|
||||
class LatentCoordinateConverter:
    """Transform the batched feature dict to latent coordinate arrays."""

    def __init__(self, config, prot_atom37_namemap, lig_namemap):
        """Initialize the converter with its config and protein/ligand name mappings."""
        super().__init__()
        self.config = config
        self.prot_namemap = prot_atom37_namemap
        self.lig_namemap = lig_namemap
        # cached noise and last predicted Ca trace (populated elsewhere)
        self.cached_noise = None
        self._last_pred_ca_trace = None

    @staticmethod
    def nested_get(dic, keys):
        """Fetch the value reached by following ``keys`` through a nested dictionary."""
        node = dic
        for key in keys:
            node = node[key]
        return node

    @staticmethod
    def nested_set(dic, keys, value):
        """Store ``value`` under the nested key path ``keys``, creating levels as needed."""
        *parents, leaf = keys
        node = dic
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value

    def to_latent(self, batch):
        """Convert the batched feature dict to latent coordinates (base-class no-op)."""
        return None

    def assign_to_batch(self, batch, x_int):
        """Write latent coordinates back into the batched feature dict (base-class no-op)."""
        return None
|
||||
|
||||
|
||||
class DefaultPLCoordinateConverter(LatentCoordinateConverter):
    """Minimal conversion, using internal coords for sidechains and global coords for others."""

    def __init__(self, config, prot_atom37_namemap, lig_namemap):
        """Initialize the converter.

        :param config: model configuration providing the latent scale parameters.
        :param prot_atom37_namemap: (read-path, write-path) key tuples for protein atom37 coords.
        :param lig_namemap: (read-path, write-path) key tuples for ligand heavy-atom coords.
        """
        super().__init__(config, prot_atom37_namemap, lig_namemap)
        # Scale parameters in Angstrom
        self.ca_scale = config.global_max_sigma
        self.other_scale = config.internal_max_sigma

    def to_latent(self, batch: dict):
        """Convert the batched feature dict to latent coordinates.

        Packs, per sample and along dim=1: holo Ca (global), apo Ca (global),
        holo non-Ca (Ca-centroid-relative), apo non-Ca, and — when ligands are
        present — the holo/apo Ca centroids and the ligand heavy atoms.
        Segment sizes are cached on ``self`` for ``assign_to_batch``.

        :param batch: a batched feature dictionary.
        :return: the packed latent coordinate tensor of shape (batch, n, 3).
        """
        indexer = batch["indexer"]
        metadata = batch["metadata"]
        self._batch_size = metadata["num_structid"]
        atom37_mask = batch["features"]["res_atom_mask"].bool()
        # mask of all non-Ca atoms (atom37 slot 1 is the Ca atom)
        self._cother_mask = atom37_mask.clone()
        self._cother_mask[:, 1] = False
        atom37_coords = self.nested_get(batch, self.prot_namemap[0])
        try:
            # apo coordinates live under the same path with an "apo_" leaf prefix
            apo_available = True
            apo_atom37_coords = self.nested_get(
                batch, self.prot_namemap[0][:-1] + ("apo_" + self.prot_namemap[0][-1],)
            )
        except KeyError:
            apo_available = False
            apo_atom37_coords = torch.zeros_like(atom37_coords)
        ca_atom_centroid_coords = segment_mean(
            # NOTE: in contrast to NeuralPLexer, we center all coordinates at the origin using the Ca atom centroids
            atom37_coords[:, 1],
            indexer["gather_idx_a_structid"],
            self._batch_size,
        )
        if apo_available:
            apo_ca_atom_centroid_coords = segment_mean(
                apo_atom37_coords[:, 1],
                indexer["gather_idx_a_structid"],
                self._batch_size,
            )
        else:
            apo_ca_atom_centroid_coords = torch.zeros_like(ca_atom_centroid_coords)
        # global Ca coordinates, centered per structure
        ca_coords_glob = (
            (atom37_coords[:, 1] - ca_atom_centroid_coords[indexer["gather_idx_a_structid"]])
            .contiguous()
            .view(self._batch_size, -1, 3)
        )
        if apo_available:
            apo_ca_coords_glob = (
                (
                    apo_atom37_coords[:, 1]
                    - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"]]
                )
                .contiguous()
                .view(self._batch_size, -1, 3)
            )
        else:
            apo_ca_coords_glob = torch.zeros_like(ca_coords_glob)
        # non-Ca atoms expressed relative to the per-structure Ca centroid
        cother_coords_int = (
            (atom37_coords - ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None])[
                self._cother_mask
            ]
            .contiguous()
            .view(self._batch_size, -1, 3)
        )
        if apo_available:
            apo_cother_coords_int = (
                (
                    apo_atom37_coords
                    - apo_ca_atom_centroid_coords[indexer["gather_idx_a_structid"], None]
                )[self._cother_mask]
                .contiguous()
                .view(self._batch_size, -1, 3)
            )
        else:
            apo_cother_coords_int = torch.zeros_like(cother_coords_int)
        self._n_res_per_sample = ca_coords_glob.shape[1]
        self._n_cother_per_sample = cother_coords_int.shape[1]
        if batch["misc"]["protein_only"]:
            # no ligand segments: pack only the protein coordinates
            self._n_ligha_per_sample = 0
            x_int = torch.cat(
                [
                    ca_coords_glob / self.ca_scale,
                    apo_ca_coords_glob / self.ca_scale,
                    cother_coords_int / self.other_scale,
                    apo_cother_coords_int / self.other_scale,
                ],
                dim=1,
            )
            return x_int
        lig_ha_coords = self.nested_get(batch, self.lig_namemap[0])
        # ligand heavy atoms relative to the holo Ca centroid of their structure
        lig_ha_coords_int = (
            lig_ha_coords - ca_atom_centroid_coords[indexer["gather_idx_i_structid"]]
        )
        lig_ha_coords_int = lig_ha_coords_int.contiguous().view(self._batch_size, -1, 3)
        ca_atom_centroid_coords = ca_atom_centroid_coords.contiguous().view(
            self._batch_size, -1, 3
        )
        apo_ca_atom_centroid_coords = apo_ca_atom_centroid_coords.contiguous().view(
            self._batch_size, -1, 3
        )
        x_int = torch.cat(
            [
                ca_coords_glob / self.ca_scale,
                apo_ca_coords_glob / self.ca_scale,
                cother_coords_int / self.other_scale,
                apo_cother_coords_int / self.other_scale,
                ca_atom_centroid_coords / self.ca_scale,
                apo_ca_atom_centroid_coords / self.ca_scale,
                lig_ha_coords_int / self.other_scale,
            ],
            dim=1,
        )
        # NOTE: since we use the Ca atom centroids for centralization, we have only one molid per sample
        self._n_molid_per_sample = ca_atom_centroid_coords.shape[1]
        self._n_ligha_per_sample = lig_ha_coords_int.shape[1]
        return x_int

    def assign_to_batch(self, batch: dict, x_lat: torch.Tensor):
        """Assign the latent coordinates to the batched feature dict.

        Unpacks ``x_lat`` using the segment sizes cached by ``to_latent``,
        rescales each segment, rebuilds the atom37 protein coordinates and
        (when present) ligand coordinates, and writes them back via the
        configured name maps. Clears the cached sizes afterwards.

        :param batch: the batched feature dictionary to update in place.
        :param x_lat: packed latent coordinates as produced by ``to_latent``.
        :return: the updated batch.
        """
        indexer = batch["indexer"]
        new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
        apo_new_atom37_coords = x_lat.new_zeros(self._batch_size * self._n_res_per_sample, 37, 3)
        if batch["misc"]["protein_only"]:
            ca_lat, apo_ca_lat, cother_lat, apo_cother_lat = torch.split(
                x_lat,
                [
                    self._n_res_per_sample,
                    self._n_res_per_sample,
                    self._n_cother_per_sample,
                    self._n_cother_per_sample,
                ],
                dim=1,
            )
        else:
            (
                ca_lat,
                apo_ca_lat,
                cother_lat,
                apo_cother_lat,
                ca_cent_lat,
                _,  # apo Ca centroid segment is unused on the way back
                lig_lat,
            ) = torch.split(
                x_lat,
                [
                    self._n_res_per_sample,
                    self._n_res_per_sample,
                    self._n_cother_per_sample,
                    self._n_cother_per_sample,
                    self._n_molid_per_sample,
                    self._n_molid_per_sample,
                    self._n_ligha_per_sample,
                ],
                dim=1,
            )
        new_ca_glob = (ca_lat * self.ca_scale).contiguous().flatten(0, 1)
        apo_new_ca_glob = (apo_ca_lat * self.ca_scale).contiguous().flatten(0, 1)
        new_atom37_coords[self._cother_mask] = (
            (cother_lat * self.other_scale).contiguous().flatten(0, 1)
        )
        apo_new_atom37_coords[self._cother_mask] = (
            (apo_cother_lat * self.other_scale).contiguous().flatten(0, 1)
        )
        # NOTE: the no-op self-assignments (`new_atom37_coords = new_atom37_coords`
        # and its apo twin) that previously appeared here were removed; the
        # zero-fill of the unmasked slots is kept for safety even though the
        # tensors are allocated via `new_zeros`.
        new_atom37_coords[~self._cother_mask] = 0
        apo_new_atom37_coords[~self._cother_mask] = 0
        new_atom37_coords[:, 1] = new_ca_glob
        apo_new_atom37_coords[:, 1] = apo_new_ca_glob
        self.nested_set(batch, self.prot_namemap[1], new_atom37_coords)
        self.nested_set(
            batch,
            self.prot_namemap[1][:-1] + ("apo_" + self.prot_namemap[1][-1],),
            apo_new_atom37_coords,
        )
        if batch["misc"]["protein_only"]:
            self.nested_set(batch, self.lig_namemap[1], None)
            self.empty_cache()
            return batch
        new_ligha_coords_int = (lig_lat * self.other_scale).contiguous().flatten(0, 1)
        new_ligha_coords_cent = (ca_cent_lat * self.ca_scale).contiguous().flatten(0, 1)
        # ligand atoms: centroid-relative offsets plus the per-structure centroid
        new_ligha_coords = (
            new_ligha_coords_int + new_ligha_coords_cent[indexer["gather_idx_i_structid"]]
        )
        self.nested_set(batch, self.lig_namemap[1], new_ligha_coords)
        self.empty_cache()
        return batch

    def empty_cache(self):
        """Empty the cached variables set by `to_latent`."""
        self._batch_size = None
        self._cother_mask = None
        self._n_res_per_sample = None
        self._n_cother_per_sample = None
        self._n_ligha_per_sample = None
        self._n_molid_per_sample = None
|
||||
943
flowdock/models/flowdock_fm_module.py
Normal file
943
flowdock/models/flowdock_fm_module.py
Normal file
@@ -0,0 +1,943 @@
|
||||
import os
from typing import get_args

import esm
import numpy as np
import rootutils
import torch
from beartype.typing import Any, Dict, Literal, Optional, Union
from lightning import LightningModule
from omegaconf import DictConfig
from torchmetrics.functional.regression import (
    mean_absolute_error,
    mean_squared_error,
    pearson_corrcoef,
    spearman_corrcoef,
)

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from flowdock.models.components.losses import (
    eval_auxiliary_estimation_losses,
    eval_structure_prediction_losses,
)
from flowdock.utils import RankedLogger
from flowdock.utils.data_utils import pdb_filepath_to_protein, prepare_batch
from flowdock.utils.model_utils import extract_esm_embeddings
from flowdock.utils.sampling_utils import multi_pose_sampling
from flowdock.utils.visualization_utils import (
    construct_prot_lig_pairs,
    write_prot_lig_pairs_to_pdb_file,
)
|
||||
|
||||
# Type aliases and constants shared across the flow-matching module.
MODEL_BATCH = Dict[str, Any]  # a (nested) feature/metadata batch dictionary
MODEL_STAGE = Literal["train", "val", "test", "predict"]
LOSS_MODES = Literal[
    "structure_prediction",
    "auxiliary_estimation",
    "auxiliary_estimation_without_structure_prediction",
]
# NOTE: derived from the Literal via `get_args` so the list and the type can
# never drift apart (previously the three strings were duplicated verbatim).
LOSS_MODES_LIST = list(get_args(LOSS_MODES))
# Stages during which auxiliary (confidence/affinity) losses may be evaluated.
AUX_ESTIMATION_STAGES = ["train", "val", "test"]

log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class FlowDockFMLitModule(LightningModule):
|
||||
"""A `LightningModule` for geometric flow matching (FM) with FlowDock.
|
||||
|
||||
A `LightningModule` implements 8 key methods:
|
||||
|
||||
```python
|
||||
def __init__(self):
|
||||
# Define initialization code here.
|
||||
|
||||
def setup(self, stage):
|
||||
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
|
||||
# This hook is called on every process when using DDP.
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# The complete training step.
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# The complete validation step.
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
# The complete test step.
|
||||
|
||||
def predict_step(self, batch, batch_idx):
|
||||
# The complete predict step.
|
||||
|
||||
def configure_optimizers(self):
|
||||
# Define and configure optimizers and LR schedulers.
|
||||
```
|
||||
|
||||
Docs:
|
||||
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    net: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    compile: bool,
    cfg: DictConfig,
    **kwargs: Dict[str, Any],
):
    """Initialize a `FlowDockFMLitModule`.

    :param net: The model to train.
    :param optimizer: The optimizer to use for training.
    :param scheduler: The learning rate scheduler to use for training.
    :param compile: Whether to compile the model before training.
    :param cfg: The model configuration.
    :param kwargs: Additional keyword arguments.
    """
    super().__init__()

    # instantiate the network from its configuration
    self.net = net(cfg)

    # expose init params via `self.hparams` and persist them in checkpoints
    self.save_hyperparameters(logger=False, ignore=["net"])

    # reject unsupported loss modes up front
    if self.hparams.cfg.task.loss_mode not in LOSS_MODES_LIST:
        raise ValueError(
            f"Invalid loss mode: {self.hparams.cfg.task.loss_mode}. Must be one of {LOSS_MODES}."
        )

    # per-stage buffers for inspecting model outputs
    self.training_step_outputs = []
    self.validation_step_outputs = []
    self.test_step_outputs = []
    self.predict_step_outputs = []
|
||||
|
||||
def forward(
    self,
    batch: MODEL_BATCH,
    iter_id: Union[int, str] = 0,
    observed_block_contacts: Optional[torch.Tensor] = None,
    contact_prediction: bool = True,
    infer_geometry_prior: bool = False,
    score: bool = False,
    affinity: bool = True,
    use_template: bool = False,
    **kwargs: Dict[str, Any],
) -> MODEL_BATCH:
    """Perform a forward pass through the model.

    :param batch: A batch dictionary.
    :param iter_id: The current iteration ID.
    :param observed_block_contacts: Observed block contacts.
    :param contact_prediction: Whether to predict contacts.
    :param infer_geometry_prior: Whether to predict using a geometry prior.
    :param score: Whether to predict a denoised complex structure.
    :param affinity: Whether to predict ligand binding affinity.
    :param use_template: Whether to use a template protein structure.
    :param kwargs: Additional keyword arguments.
    :return: Batch dictionary with outputs.
    """
    # assemble the network's keyword arguments, forwarding the Lightning
    # training flag so the wrapped network can branch on train/eval mode
    net_kwargs = dict(
        iter_id=iter_id,
        observed_block_contacts=observed_block_contacts,
        contact_prediction=contact_prediction,
        infer_geometry_prior=infer_geometry_prior,
        score=score,
        affinity=affinity,
        use_template=use_template,
        training=self.training,
    )
    # double-splat keeps the original duplicate-keyword TypeError semantics
    return self.net(batch, **net_kwargs, **kwargs)
|
||||
|
||||
def model_step(
    self,
    batch: MODEL_BATCH,
    batch_idx: int,
    stage: MODEL_STAGE,
    loss_mode: Optional[LOSS_MODES] = None,
) -> MODEL_BATCH:
    """Perform a single model step on a batch of data.

    Dispatches between the auxiliary (confidence/affinity) estimation losses
    and the structure-prediction losses based on the task configuration, the
    current stage, the batch index, and an optional explicit ``loss_mode``.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :param stage: The current model stage (i.e., `train`, `val`, `test`, or `predict`).
    :param loss_mode: The loss mode to use for training.
    :return: Batch dictionary with losses.
    """
    prepare_batch(batch)
    # auxiliary outputs are produced when either head is enabled
    predicting_aux_outputs = (
        self.hparams.cfg.confidence.enabled or self.hparams.cfg.affinity.enabled
    )
    is_aux_loss_stage = stage in AUX_ESTIMATION_STAGES
    # auxiliary losses are only evaluated every `aux_batch_freq`-th batch
    is_aux_batch = batch_idx % self.hparams.cfg.task.aux_batch_freq == 0
    struct_pred_loss_mode_requested = (
        loss_mode is not None and loss_mode == "structure_prediction"
    )
    # evaluate the auxiliary losses only when at least one auxiliary head is
    # trainable: either confidence is unfrozen, or affinity is unfrozen AND
    # the batch actually carries at least one affinity label
    should_eval_aux_loss = (
        predicting_aux_outputs
        and is_aux_loss_stage
        and is_aux_batch
        and not struct_pred_loss_mode_requested
        and (
            not self.hparams.cfg.task.freeze_confidence
            or (
                not self.hparams.cfg.task.freeze_affinity
                and batch["features"]["affinity"].any().item()
            )
        )
    )
    # an explicit loss mode containing "auxiliary_estimation" (with or without
    # structure prediction) forces the auxiliary path when aux heads are on
    eval_aux_loss_mode_requested = (
        predicting_aux_outputs
        and loss_mode is not None
        and "auxiliary_estimation" in loss_mode
    )
    if should_eval_aux_loss or eval_aux_loss_mode_requested:
        return eval_auxiliary_estimation_losses(
            self, batch, stage, loss_mode, training=self.training
        )
    # default path: structure-prediction losses evaluated at t_1 = 1.0
    loss_fn = eval_structure_prediction_losses
    return loss_fn(self, batch, batch_idx, self.device, stage, t_1=1.0)
|
||||
|
||||
def on_train_start(self):
    """Lightning hook invoked once when training begins (intentionally a no-op)."""
|
||||
|
||||
def training_step(self, batch: MODEL_BATCH, batch_idx: int) -> torch.Tensor:
    """Perform a single training step on a batch of data from the training set.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :return: A tensor of losses between model predictions and targets.
    """
    # when overfitting on a single named example, skip every other sample
    overfit_name = self.hparams.cfg.task.overfitting_example_name
    if overfit_name is not None and any(
        name != overfit_name for name in batch["metadata"]["sample_ID_per_sample"]
    ):
        return None

    try:
        batch = self.model_step(batch, batch_idx, "train")
    except Exception as e:
        # best-effort training: log and skip examples whose step fails
        log.error(
            f"Failed to perform training step for batch index {batch_idx} due to: {e}. Skipping example."
        )
        return None

    # collect affinity predictions for epoch-end metric computation
    if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
        self.training_step_outputs.append(
            {
                "affinity_logits": batch["outputs"]["affinity_logits"],
                "affinity": batch["features"]["affinity"],
            }
        )

    # return loss or backpropagation will fail
    return batch["outputs"]["loss"]
|
||||
|
||||
def on_train_epoch_end(self):
    """Lightning hook that is called when a training epoch ends."""
    has_affinity_outputs = self.hparams.cfg.affinity.enabled and any(
        "affinity_logits" in output for output in self.training_step_outputs
    )
    if has_affinity_outputs:
        # gather predictions/labels only from the batches that produced them
        affinity_logits = torch.cat(
            [
                output["affinity_logits"]
                for output in self.training_step_outputs
                if "affinity_logits" in output
            ]
        )
        affinity = torch.cat(
            [
                output["affinity"]
                for output in self.training_step_outputs
                if "affinity_logits" in output
            ]
        )
        # drop examples without an affinity label
        valid = ~affinity.isnan()
        affinity_logits = affinity_logits[valid]
        affinity = affinity[valid]
        if affinity.numel() > 1:
            # NOTE: there must be at least two affinity batches to properly score the affinity predictions
            metrics = {
                "RMSE": torch.sqrt(mean_squared_error(affinity_logits, affinity)),
                "MAE": mean_absolute_error(affinity_logits, affinity),
                "Pearson": pearson_corrcoef(affinity_logits, affinity),
                "Spearman": spearman_corrcoef(affinity_logits, affinity),
            }
            for metric_name, metric_value in metrics.items():
                self.log(
                    f"train_affinity/{metric_name}",
                    metric_value.detach(),
                    on_epoch=True,
                    batch_size=len(affinity),
                    sync_dist=False,
                )
    self.training_step_outputs.clear()  # free memory
|
||||
|
||||
def on_validation_start(self):
    """Lightning hook that is called when validation begins."""
    # ensure the per-epoch validation output directory exists
    output_dir = os.path.join(self.trainer.default_root_dir, "validation_epoch_outputs")
    os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def validation_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
    """Perform a single validation step on a batch of data from the validation set.

    Runs three sampling passes (with template, without template, and with the
    exact prior), logs their metrics, evaluates the validation losses, and
    optionally records generated structures and affinity outputs for
    epoch-end inspection.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :param dataloader_idx: The index of the current dataloader.
    """
    # when overfitting on a single named example, skip every other sample
    if self.hparams.cfg.task.overfitting_example_name is not None and not all(
        name == self.hparams.cfg.task.overfitting_example_name
        for name in batch["metadata"]["sample_ID_per_sample"]
    ):
        return None

    try:
        prepare_batch(batch)
        # pass 1: sampling with the default (template-aware) settings, keeping
        # all intermediate states so trajectories can be visualized below
        sampling_stats = self.net.sample_pl_complex_structures(
            batch,
            sampler="VDODE",
            sampler_eta=1.0,
            num_steps=10,
            start_time=1.0,
            exact_prior=False,
            return_all_states=True,
            eval_input_protein=True,
        )
        # keep the frames for visualization; drop them from the logged stats
        all_frames = sampling_stats["all_frames"]
        del sampling_stats["all_frames"]
        for metric_name in sampling_stats.keys():
            log_stat = sampling_stats[metric_name].mean().detach()
            batch_size = sampling_stats[metric_name].shape[0]
            self.log(
                f"val_sampling/{metric_name}",
                log_stat,
                on_step=True,
                on_epoch=True,
                batch_size=batch_size,
            )
        # pass 2: sampling without a template protein structure
        sampling_stats = self.net.sample_pl_complex_structures(
            batch,
            sampler="VDODE",
            sampler_eta=1.0,
            num_steps=10,
            start_time=1.0,
            exact_prior=False,
            use_template=False,
        )
        for metric_name in sampling_stats.keys():
            log_stat = sampling_stats[metric_name].mean().detach()
            batch_size = sampling_stats[metric_name].shape[0]
            self.log(
                f"val_sampling_notemplate/{metric_name}",
                log_stat,
                on_step=True,
                on_epoch=True,
                batch_size=batch_size,
            )
        # pass 3: sampling from the exact prior (summary statistics only)
        sampling_stats = self.net.sample_pl_complex_structures(
            batch,
            sampler="VDODE",
            sampler_eta=1.0,
            num_steps=10,
            start_time=1.0,
            return_summary_stats=True,
            exact_prior=True,
        )
        for metric_name in sampling_stats.keys():
            log_stat = sampling_stats[metric_name].mean().detach()
            batch_size = sampling_stats[metric_name].shape[0]
            self.log(
                f"val_sampling_trueprior/{metric_name}",
                log_stat,
                on_step=True,
                on_epoch=True,
                batch_size=batch_size,
            )
        # finally, evaluate the validation losses
        batch = self.model_step(batch, batch_idx, "val")
    except Exception as e:
        # best-effort validation: log and skip examples whose step fails
        log.error(
            f"Failed to perform validation step for batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}. Skipping example."
        )
        return None

    # store model outputs for inspection
    validation_outputs = {}
    if self.hparams.cfg.task.visualize_generated_samples:
        validation_outputs = {
            "name": batch["metadata"]["sample_ID_per_sample"],
            "batch_size": batch["metadata"]["num_structid"],
            "aatype": batch["features"]["res_type"].long().cpu().numpy(),
            "res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
            # one entry per sampling frame collected in pass 1 above
            "protein_coordinates_list": [
                frame["receptor_padded"].cpu().numpy() for frame in all_frames
            ],
            "ligand_coordinates_list": [
                frame["ligands"].cpu().numpy() for frame in all_frames
            ],
            "ligand_mol": batch["metadata"]["mol_per_sample"],
            "protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"].cpu().numpy(),
            "ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"].cpu().numpy(),
            "gt_protein_coordinates": batch["features"]["res_atom_positions"].cpu().numpy(),
            "gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
            "dataloader_idx": dataloader_idx,
        }
    if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
        validation_outputs.update(
            {
                "affinity_logits": batch["outputs"]["affinity_logits"],
                "affinity": batch["features"]["affinity"],
                "dataloader_idx": dataloader_idx,
            }
        )
    if validation_outputs:
        self.validation_step_outputs.append(validation_outputs)
|
||||
|
||||
def on_validation_epoch_end(self):
    """Lightning hook that is called when a validation epoch ends.

    When sample visualization is enabled, writes each generated protein-ligand
    complex collected during validation to a PDB file under
    ``validation_epoch_outputs``. When affinity prediction is enabled, scores
    the accumulated affinity predictions against their labels and logs
    RMSE/MAE/Pearson/Spearman. Always clears the step-output buffer at the end.
    """
    if self.hparams.cfg.task.visualize_generated_samples:
        for i, outputs in enumerate(self.validation_step_outputs):
            for batch_index in range(outputs["batch_size"]):
                prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
                write_prot_lig_pairs_to_pdb_file(
                    prot_lig_pairs,
                    os.path.join(
                        self.trainer.default_root_dir,
                        "validation_epoch_outputs",
                        f"{outputs['name'][batch_index]}_validation_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
                    ),
                )
    if self.hparams.cfg.affinity.enabled and any(
        "affinity_logits" in output for output in self.validation_step_outputs
    ):
        affinity_logits = torch.cat(
            [
                output["affinity_logits"]
                for output in self.validation_step_outputs
                if "affinity_logits" in output
            ]
        )
        affinity = torch.cat(
            [
                output["affinity"]
                for output in self.validation_step_outputs
                if "affinity_logits" in output
            ]
        )
        # drop examples with missing (NaN) affinity labels before scoring
        affinity_logits = affinity_logits[~affinity.isnan()]
        affinity = affinity[~affinity.isnan()]
        if affinity.numel() > 1:
            # NOTE: there must be at least two affinity batches to properly score the affinity predictions
            affinity_metrics = {
                "RMSE": torch.sqrt(mean_squared_error(affinity_logits, affinity)),
                "MAE": mean_absolute_error(affinity_logits, affinity),
                "Pearson": pearson_corrcoef(affinity_logits, affinity),
                "Spearman": spearman_corrcoef(affinity_logits, affinity),
            }
            # log all affinity metrics uniformly instead of four copy-pasted calls
            for metric_name, metric_value in affinity_metrics.items():
                self.log(
                    f"val_affinity/{metric_name}",
                    metric_value.detach(),
                    on_epoch=True,
                    batch_size=len(affinity),
                    sync_dist=True,
                )
    self.validation_step_outputs.clear()  # free memory
|
||||
|
||||
def on_test_start(self):
    """Lightning hook that is called when testing begins.

    Ensures the directory used to persist model outputs from each test epoch
    exists before any test step runs.
    """
    output_dir = os.path.join(self.trainer.default_root_dir, "test_epoch_outputs")
    os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def test_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
    """Perform a single test step on a batch of data from the test set.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :param dataloader_idx: The index of the current dataloader.
    """
    # when overfitting on a single named example, skip every other batch
    if self.hparams.cfg.task.overfitting_example_name is not None and not all(
        name == self.hparams.cfg.task.overfitting_example_name
        for name in batch["metadata"]["sample_ID_per_sample"]
    ):
        return None

    try:
        prepare_batch(batch)
        if self.hparams.cfg.task.eval_structure_prediction:
            # pass 1: sample structures with templates, keeping all intermediate frames
            sampling_stats = self.net.sample_pl_complex_structures(
                batch,
                sampler=self.hparams.cfg.task.sampler,
                sampler_eta=self.hparams.cfg.task.sampler_eta,
                num_steps=self.hparams.cfg.task.num_steps,
                start_time=self.hparams.cfg.task.start_time,
                exact_prior=False,
                return_all_states=True,
                eval_input_protein=True,
            )
            # pull the trajectory frames out so only scalar metrics remain for logging
            all_frames = sampling_stats["all_frames"]
            del sampling_stats["all_frames"]
            for metric_name in sampling_stats.keys():
                log_stat = sampling_stats[metric_name].mean().detach()
                batch_size = sampling_stats[metric_name].shape[0]
                self.log(
                    f"test_sampling/{metric_name}",
                    log_stat,
                    on_step=True,
                    on_epoch=True,
                    batch_size=batch_size,
                )
            # pass 2: sample without templates to measure template-free performance
            sampling_stats = self.net.sample_pl_complex_structures(
                batch,
                sampler=self.hparams.cfg.task.sampler,
                sampler_eta=self.hparams.cfg.task.sampler_eta,
                num_steps=self.hparams.cfg.task.num_steps,
                start_time=self.hparams.cfg.task.start_time,
                exact_prior=False,
                use_template=False,
            )
            for metric_name in sampling_stats.keys():
                log_stat = sampling_stats[metric_name].mean().detach()
                batch_size = sampling_stats[metric_name].shape[0]
                self.log(
                    f"test_sampling_notemplate/{metric_name}",
                    log_stat,
                    on_step=True,
                    on_epoch=True,
                    batch_size=batch_size,
                )
            # pass 3: sample from the exact prior for reference metrics
            sampling_stats = self.net.sample_pl_complex_structures(
                batch,
                sampler=self.hparams.cfg.task.sampler,
                sampler_eta=self.hparams.cfg.task.sampler_eta,
                num_steps=self.hparams.cfg.task.num_steps,
                start_time=self.hparams.cfg.task.start_time,
                return_summary_stats=True,
                exact_prior=True,
            )
            for metric_name in sampling_stats.keys():
                log_stat = sampling_stats[metric_name].mean().detach()
                batch_size = sampling_stats[metric_name].shape[0]
                self.log(
                    f"test_sampling_trueprior/{metric_name}",
                    log_stat,
                    on_step=True,
                    on_epoch=True,
                    batch_size=batch_size,
                )
        # compute and log the test losses (also populates batch["outputs"])
        batch = self.model_step(
            batch, batch_idx, "test", loss_mode=self.hparams.cfg.task.loss_mode
        )
    except Exception as e:
        # NOTE: unlike validation, test failures are re-raised rather than skipped
        log.error(
            f"Failed to perform test step for {batch['metadata']['sample_ID_per_sample']} with batch index {batch_idx} of dataloader {dataloader_idx} due to: {e}."
        )
        raise e

    # store model outputs for inspection
    test_outputs = {}
    if (
        self.hparams.cfg.task.visualize_generated_samples
        and self.hparams.cfg.task.eval_structure_prediction
    ):
        # NOTE(review): `all_frames` is only bound when structure prediction ran above;
        # the `eval_structure_prediction` guard here keeps that access safe
        test_outputs.update(
            {
                "name": batch["metadata"]["sample_ID_per_sample"],
                "batch_size": batch["metadata"]["num_structid"],
                "aatype": batch["features"]["res_type"].long().cpu().numpy(),
                "res_atom_mask": batch["features"]["res_atom_mask"].cpu().numpy(),
                "protein_coordinates_list": [
                    frame["receptor_padded"].cpu().numpy() for frame in all_frames
                ],
                "ligand_coordinates_list": [
                    frame["ligands"].cpu().numpy() for frame in all_frames
                ],
                "ligand_mol": batch["metadata"]["mol_per_sample"],
                "protein_batch_indexer": batch["indexer"]["gather_idx_a_structid"].cpu().numpy(),
                "ligand_batch_indexer": batch["indexer"]["gather_idx_i_structid"].cpu().numpy(),
                "gt_protein_coordinates": batch["features"]["res_atom_positions"].cpu().numpy(),
                "gt_ligand_coordinates": batch["features"]["sdf_coordinates"].cpu().numpy(),
                "dataloader_idx": dataloader_idx,
            }
        )
    if self.hparams.cfg.affinity.enabled and "affinity_logits" in batch["outputs"]:
        test_outputs.update(
            {
                "affinity_logits": batch["outputs"]["affinity_logits"],
                "affinity": batch["features"]["affinity"],
                "dataloader_idx": dataloader_idx,
            }
        )
    if test_outputs:
        self.test_step_outputs.append(test_outputs)
|
||||
|
||||
def on_test_epoch_end(self):
    """Lightning hook that is called when a test epoch ends.

    When sample visualization and structure evaluation are both enabled,
    writes each generated protein-ligand complex collected during testing to a
    PDB file under ``test_epoch_outputs``. When affinity prediction is
    enabled, scores the accumulated affinity predictions against their labels
    and logs RMSE/MAE/Pearson/Spearman. Always clears the step-output buffer.
    """
    if (
        self.hparams.cfg.task.visualize_generated_samples
        and self.hparams.cfg.task.eval_structure_prediction
    ):
        for i, outputs in enumerate(self.test_step_outputs):
            for batch_index in range(outputs["batch_size"]):
                prot_lig_pairs = construct_prot_lig_pairs(outputs, batch_index)
                write_prot_lig_pairs_to_pdb_file(
                    prot_lig_pairs,
                    os.path.join(
                        self.trainer.default_root_dir,
                        "test_epoch_outputs",
                        f"{outputs['name'][batch_index]}_test_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{i}_batch_{batch_index}_dataloader_{outputs['dataloader_idx']}.pdb",
                    ),
                )
    if self.hparams.cfg.affinity.enabled and any(
        "affinity_logits" in output for output in self.test_step_outputs
    ):
        affinity_logits = torch.cat(
            [
                output["affinity_logits"]
                for output in self.test_step_outputs
                if "affinity_logits" in output
            ]
        )
        affinity = torch.cat(
            [
                output["affinity"]
                for output in self.test_step_outputs
                if "affinity_logits" in output
            ]
        )
        # drop examples with missing (NaN) affinity labels before scoring
        affinity_logits = affinity_logits[~affinity.isnan()]
        affinity = affinity[~affinity.isnan()]
        if affinity.numel() > 1:
            # NOTE: there must be at least two affinity batches to properly score the affinity predictions
            affinity_metrics = {
                "RMSE": torch.sqrt(mean_squared_error(affinity_logits, affinity)),
                "MAE": mean_absolute_error(affinity_logits, affinity),
                "Pearson": pearson_corrcoef(affinity_logits, affinity),
                "Spearman": spearman_corrcoef(affinity_logits, affinity),
            }
            # log all affinity metrics uniformly instead of four copy-pasted calls
            for metric_name, metric_value in affinity_metrics.items():
                self.log(
                    f"test_affinity/{metric_name}",
                    metric_value.detach(),
                    on_epoch=True,
                    batch_size=len(affinity),
                    sync_dist=True,
                )
    self.test_step_outputs.clear()  # free memory
|
||||
|
||||
def on_predict_start(self):
    """Lightning hook that is called when prediction begins.

    Creates the predict-epoch output directory and loads the pretrained ESM
    (sequence embedding) model — and, unless a single-complex template is
    provided, the ESMFold structure-prediction model — onto the CPU.
    """
    # create a directory to store model outputs from each predict epoch
    os.makedirs(
        os.path.join(self.trainer.default_root_dir, "predict_epoch_outputs"), exist_ok=True
    )

    log.info("Loading pretrained ESM model...")
    esm_model, self.esm_alphabet = esm.pretrained.load_model_and_alphabet_hub(
        self.hparams.cfg.model.cfg.protein_encoder.esm_version
    )
    # keep the encoder in eval mode and full (float32) precision on the CPU
    self.esm_model = esm_model.eval().float()
    self.esm_batch_converter = self.esm_alphabet.get_batch_converter()
    self.esm_model.cpu()

    skip_loading_esmfold_weights = (
        # skip loading ESMFold weights if the template protein structure for a single complex input is provided
        self.hparams.cfg.task.csv_path is None
        and self.hparams.cfg.task.input_template is not None
        and os.path.exists(self.hparams.cfg.task.input_template)
    )
    if not skip_loading_esmfold_weights:
        log.info("Loading pretrained ESMFold model...")
        esmfold_model = esm.pretrained.esmfold_v1()
        self.esmfold_model = esmfold_model.eval().float()
        # chunked attention trades speed for memory during folding
        self.esmfold_model.set_chunk_size(self.hparams.cfg.esmfold_chunk_size)
        self.esmfold_model.cpu()
|
||||
|
||||
def predict_step(self, batch: MODEL_BATCH, batch_idx: int, dataloader_idx: int = 0):
    """Perform a single predict step on a batch of data from the predict set.

    :param batch: A batch dictionary.
    :param batch_idx: The index of the current batch.
    :param dataloader_idx: The index of the current dataloader.
    """
    # NOTE: the predict dataloader uses batch_size=1, so each batched field is a
    # single-element sequence
    rec_path = batch["rec_path"][0]
    ligand_paths = list(
        path[0] for path in batch["lig_paths"]
    )  # unpack a list of (batched) single-element string tuples
    sample_id = batch["sample_id"][0] if "sample_id" in batch else "sample"
    input_template = batch["input_template"][0] if "input_template" in batch else None

    out_path = (
        os.path.join(self.hparams.cfg.out_path, sample_id)
        if "sample_id" in batch
        else self.hparams.cfg.out_path
    )

    # generate ESM embeddings for the protein
    protein = pdb_filepath_to_protein(rec_path)
    # one masked one-letter sequence per chain
    sequences = [
        "".join(np.array(list(chain_seq))[chain_mask])
        for (_, chain_seq, chain_mask) in protein.letter_sequences
    ]
    esm_embeddings = extract_esm_embeddings(
        self.esm_model,
        self.esm_alphabet,
        self.esm_batch_converter,
        sequences,
        device="cpu",
        esm_repr_layer=self.hparams.cfg.model.cfg.protein_encoder.esm_repr_layer,
    )
    # key embeddings by "<sequence>:<chain index>" so duplicate chains stay distinct
    sequences_to_embeddings = {
        f"{seq}:{i}": esm_embeddings[i].cpu().numpy() for i, seq in enumerate(sequences)
    }

    # generate initial ESMFold-predicted structure for the protein if a template is not provided
    apo_rec_path = None
    if input_template and os.path.exists(input_template):
        apo_protein = pdb_filepath_to_protein(input_template)
        # compare masked sequences to verify the template matches the input protein
        apo_chain_seq_masked = "".join(
            "".join(np.array(list(chain_seq))[chain_mask])
            for (_, chain_seq, chain_mask) in apo_protein.letter_sequences
        )
        chain_seq_masked = "".join(
            "".join(np.array(list(chain_seq))[chain_mask])
            for (_, chain_seq, chain_mask) in protein.letter_sequences
        )
        if apo_chain_seq_masked != chain_seq_masked:
            log.error(
                f"Provided template protein structure {input_template} does not match the input protein sequence within {rec_path}. Skipping example {sample_id} at batch index {batch_idx} of dataloader {dataloader_idx}."
            )
            return None
        log.info(f"Starting from provided template protein structure: {input_template}")
        apo_rec_path = input_template
    if apo_rec_path is None and self.hparams.cfg.prior_type == "esmfold":
        # fold all chains jointly; ESMFold separates chains with ":"
        esmfold_sequence = ":".join(sequences)
        apo_rec_path = rec_path.replace(".pdb", "_apo.pdb")
        with torch.no_grad():
            esmfold_pdb_output = self.esmfold_model.infer_pdb(esmfold_sequence)
        with open(apo_rec_path, "w") as f:
            f.write(esmfold_pdb_output)

    _, _, _, _, _, all_frames, batch_all, b_factors, plddt_rankings = multi_pose_sampling(
        rec_path,
        ligand_paths,
        self.hparams.cfg,
        self,
        out_path,
        separate_pdb=self.hparams.cfg.separate_pdb,
        apo_receptor_path=apo_rec_path,
        sample_id=sample_id,
        protein=protein,
        sequences_to_embeddings=sequences_to_embeddings,
        return_all_states=self.hparams.cfg.task.visualize_generated_samples,
        auxiliary_estimation_only=self.hparams.cfg.task.auxiliary_estimation_only,
    )
    # store model outputs for inspection
    if self.hparams.cfg.task.visualize_generated_samples:
        predict_outputs = {
            "name": batch_all["metadata"]["sample_ID_per_sample"],
            "batch_size": batch_all["metadata"]["num_structid"],
            "aatype": batch_all["features"]["res_type"].long().cpu().numpy(),
            "res_atom_mask": batch_all["features"]["res_atom_mask"].cpu().numpy(),
            "protein_coordinates_list": [
                frame["receptor_padded"].cpu().numpy() for frame in all_frames
            ],
            "ligand_coordinates_list": [
                frame["ligands"].cpu().numpy() for frame in all_frames
            ],
            "ligand_mol": batch_all["metadata"]["mol_per_sample"],
            "protein_batch_indexer": batch_all["indexer"]["gather_idx_a_structid"]
            .cpu()
            .numpy(),
            "ligand_batch_indexer": batch_all["indexer"]["gather_idx_i_structid"]
            .cpu()
            .numpy(),
            "b_factors": b_factors,
            "plddt_rankings": plddt_rankings,
        }
        self.predict_step_outputs.append(predict_outputs)
|
||||
|
||||
def on_predict_epoch_end(self):
    """Lightning hook that is called when a predict epoch ends.

    Writes every generated protein-ligand complex collected during prediction
    to a per-sample PDB file (when sample visualization is enabled), then
    clears the step-output buffer.
    """
    if self.hparams.cfg.task.visualize_generated_samples:
        for output_idx, outputs in enumerate(self.predict_step_outputs):
            for sample_idx in range(outputs["batch_size"]):
                pairs = construct_prot_lig_pairs(outputs, sample_idx)
                if "plddt_rankings" in outputs:
                    ranking = outputs["plddt_rankings"][sample_idx]
                else:
                    ranking = None
                sample_name = outputs["name"][sample_idx]
                pdb_filepath = os.path.join(
                    self.hparams.cfg.out_path,
                    sample_name,
                    "predict_epoch_outputs",
                    f"{sample_name}{f'_rank{ranking + 1}' if ranking is not None else ''}_predict_epoch_{self.current_epoch}_global_step_{self.global_step}_output_{output_idx}_batch_{sample_idx}.pdb",
                )
                write_prot_lig_pairs_to_pdb_file(pairs, pdb_filepath)
    self.predict_step_outputs.clear()  # free memory
|
||||
|
||||
def on_after_backward(self):
    """Skip updates in case of unstable gradients.

    Scans every parameter gradient after the backward pass; if any contains
    `inf` or `nan`, all gradients are zeroed so the optimizer step becomes a
    no-op for this batch.

    Reference: https://github.com/Lightning-AI/lightning/issues/4956
    """
    # lazily scan gradients, stopping at the first invalid one
    grads = (param.grad for _, param in self.named_parameters() if param.grad is not None)
    found_invalid = any(torch.isnan(g).any() or torch.isinf(g).any() for g in grads)
    if found_invalid:
        log.warning(
            "Detected `inf` or `nan` values in gradients. Not updating model parameters."
        )
        self.zero_grad()
|
||||
|
||||
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    """Override the optimizer step to dynamically update the learning rate.

    Performs the parameter update and then linearly warms up the learning
    rate over the first ``num_warmup_steps`` global steps.

    :param epoch: The current epoch.
    :param batch_idx: The index of the current batch.
    :param optimizer: The optimizer to use for training.
    :param optimizer_closure: The optimizer closure.
    """
    # update params
    # NOTE: unwrap Lightning's optimizer wrapper to step the raw optimizer directly
    optimizer = optimizer.optimizer
    optimizer.step(closure=optimizer_closure)

    # warm up learning rate; previously the step count appeared twice as a
    # magic number (1000 and 1000.0), so name it once
    num_warmup_steps = 1000
    if self.trainer.global_step < num_warmup_steps:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / num_warmup_steps)
        for pg in optimizer.param_groups:
            # NOTE: `self.hparams.optimizer.keywords["lr"]` refers to the optimizer's initial learning rate
            pg["lr"] = lr_scale * self.hparams.optimizer.keywords["lr"]
|
||||
|
||||
def setup(self, stage: str):
|
||||
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
|
||||
test, or predict.
|
||||
|
||||
This is a good hook when you need to build models dynamically or adjust something about
|
||||
them. This hook is called on every process when using DDP.
|
||||
|
||||
:param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
|
||||
"""
|
||||
if self.hparams.compile and stage == "fit":
|
||||
self.net = torch.compile(self.net)
|
||||
|
||||
def configure_optimizers(self) -> Dict[str, Any]:
    """Choose what optimizers and learning-rate schedulers to use in your optimization.
    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Examples:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
    """
    try:
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
    except TypeError:
        # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
        optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
    if self.hparams.scheduler is None:
        return {"optimizer": optimizer}
    scheduler = self.hparams.scheduler(optimizer=optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val/loss",
            "interval": "epoch",
            "frequency": 1,
        },
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke test: instantiate the module with placeholder arguments
    _ = FlowDockFMLitModule(None, None, None, None)
|
||||
284
flowdock/sample.py
Normal file
284
flowdock/sample.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import os
|
||||
|
||||
import hydra
|
||||
import lightning as L
|
||||
import lovely_tensors as lt
|
||||
import pandas as pd
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, List, Tuple
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from lightning.pytorch.strategies.strategy import Strategy
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
lt.monkey_patch()
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# the setup_root above is equivalent to:
|
||||
# - adding project root dir to PYTHONPATH
|
||||
# (so you don't need to force user to install project as a package)
|
||||
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||
# - setting up PROJECT_ROOT environment variable
|
||||
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||
# (this way all filepaths are the same no matter where you run the code)
|
||||
# - loading environment variables from ".env" in root dir
|
||||
#
|
||||
# you can remove it if you:
|
||||
# 1. either install project as a package or move entry files to project root dir
|
||||
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||
#
|
||||
# more info: https://github.com/ashleve/rootutils
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||
from flowdock.utils import (
|
||||
RankedLogger,
|
||||
extras,
|
||||
instantiate_loggers,
|
||||
log_hyperparameters,
|
||||
task_wrapper,
|
||||
)
|
||||
from flowdock.utils.data_utils import (
|
||||
create_full_pdb_with_zero_coordinates,
|
||||
create_temp_ligand_frag_files,
|
||||
)
|
||||
|
||||
# rank-zero-only logger for this script
log = RankedLogger(__name__, rank_zero_only=True)

# sampling tasks currently supported by `sample()`
AVAILABLE_SAMPLING_TASKS = ["batched_structure_sampling"]
|
||||
|
||||
|
||||
class SamplingDataset(torch.utils.data.Dataset):
    """Dataset for sampling.

    Builds a `pandas.DataFrame` with one row per sampling input — either from
    a CSV manifest (``cfg.csv_path``) or from a single receptor/ligand pair
    given directly in the config — and serves each row as a dictionary.
    """

    def __init__(self, cfg: "DictConfig"):
        """Initializes the SamplingDataset.

        :param cfg: Sampling configuration (task name, CSV path or single
            receptor/ligand/template inputs, and output directory).
        :raises NotImplementedError: If ``cfg.sampling_task`` is unsupported.
        """
        if cfg.sampling_task != "batched_structure_sampling":
            raise NotImplementedError(f"Sampling task {cfg.sampling_task} is not implemented.")
        if cfg.csv_path is not None:
            # handle variable CSV inputs, one sampling job per row
            df_rows = []
            for _, row in pd.read_csv(cfg.csv_path).iterrows():
                sample_id = row.id
                input_receptor = row.input_receptor
                input_ligand = row.input_ligand
                input_template = row.input_template
                assert input_receptor is not None, "Receptor path is required for sampling."
                ligand_paths = self._resolve_ligand_paths(input_ligand)
                input_receptor = self._ensure_receptor_pdb(
                    input_receptor, os.path.join(cfg.out_path, f"input_{sample_id}.pdb")
                )
                df_row = {
                    "sample_id": sample_id,
                    "rec_path": input_receptor,
                    "lig_paths": ligand_paths,
                }
                # NOTE(review): pandas yields NaN (not None) for missing CSV
                # cells, so this check may admit NaN templates — confirm that
                # upstream rows always supply a real path or None
                if input_template is not None:
                    df_row["input_template"] = input_template
                df_rows.append(df_row)
            self.df = pd.DataFrame(df_rows)
        else:
            # single receptor/ligand pair supplied directly via the config
            sample_id = cfg.sample_id
            ligand_paths = self._resolve_ligand_paths(cfg.input_ligand)
            input_receptor = self._ensure_receptor_pdb(
                cfg.input_receptor, os.path.join(cfg.out_path, "input.pdb")
            )
            self.df = pd.DataFrame(
                [
                    {
                        "sample_id": sample_id,
                        "rec_path": input_receptor,
                        "lig_paths": ligand_paths,
                    }
                ]
            )
            if cfg.input_template is not None:
                self.df["input_template"] = cfg.input_template

    @staticmethod
    def _resolve_ligand_paths(input_ligand):
        """Expand a ligand input into a list of per-fragment filepaths.

        An ``.sdf`` file is split into temporary per-fragment files; otherwise
        the input is treated as a ``|``-separated list of paths. ``None``
        (a `null` ligand input) passes through unchanged.
        """
        if input_ligand is None:
            return None  # handle `null` ligand input
        if input_ligand.endswith(".sdf"):
            return create_temp_ligand_frag_files(input_ligand)
        return list(input_ligand.split("|"))

    @staticmethod
    def _ensure_receptor_pdb(input_receptor, dummy_pdb_path):
        """Return a PDB filepath for the receptor input.

        A ``.pdb`` path is returned as-is; anything else is assumed to be a
        protein sequence, for which a dummy all-zero-coordinate PDB file is
        written to ``dummy_pdb_path``.
        """
        if input_receptor.endswith(".pdb"):
            return input_receptor
        log.warning(
            "Assuming the provided receptor input is a protein sequence. Creating a dummy PDB file."
        )
        create_full_pdb_with_zero_coordinates(input_receptor, dummy_pdb_path)
        return dummy_pdb_path

    def __len__(self):
        """Returns the length of the dataset."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Returns row ``idx`` as a dictionary of sampling inputs."""
        # NOTE: the previous `Tuple[str, str]` annotation was wrong; this
        # returns a dict of row fields
        return self.df.iloc[idx].to_dict()
|
||||
|
||||
|
||||
@task_wrapper
def sample(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Samples using given checkpoint on a datamodule predictset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path, "Please provide a checkpoint path with which to sample!"
    assert os.path.exists(cfg.ckpt_path), f"Checkpoint path {cfg.ckpt_path} does not exist!"
    assert (
        cfg.sampling_task in AVAILABLE_SAMPLING_TASKS
    ), f"Sampling task {cfg.sampling_task} is not one of the following available tasks: {AVAILABLE_SAMPLING_TASKS}."
    assert (cfg.input_receptor is not None and cfg.input_ligand is not None) or (
        cfg.csv_path is not None and os.path.exists(cfg.csv_path)
    ), "Please provide either an input receptor and ligand or a CSV file with receptor and ligand sequences/filepaths."

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    # Establish model input arguments
    # NOTE: `open_dict` temporarily disables OmegaConf struct mode so new keys can be set
    with open_dict(cfg):
        # NOTE: Structure trajectories will not be visualized when performing auxiliary estimation only
        cfg.model.cfg.prior_type = cfg.prior_type
        cfg.model.cfg.task.detect_covalent = cfg.detect_covalent
        cfg.model.cfg.task.use_template = cfg.use_template
        cfg.model.cfg.task.csv_path = cfg.csv_path
        cfg.model.cfg.task.input_receptor = cfg.input_receptor
        cfg.model.cfg.task.input_ligand = cfg.input_ligand
        cfg.model.cfg.task.input_template = cfg.input_template
        cfg.model.cfg.task.visualize_generated_samples = (
            cfg.visualize_sample_trajectories and not cfg.auxiliary_estimation_only
        )
        cfg.model.cfg.task.auxiliary_estimation_only = cfg.auxiliary_estimation_only
    if cfg.latent_model is not None:
        with open_dict(cfg):
            cfg.model.cfg.latent_model = cfg.latent_model
    with open_dict(cfg):
        # normalize the (possibly "auto") start time into a float
        if cfg.start_time == "auto":
            cfg.start_time = 1.0
        else:
            cfg.start_time = float(cfg.start_time)

    log.info("Converting sampling inputs into a <SamplingDataset>")
    # single-process, unshuffled loader with batch_size=1 (one complex at a time)
    dataloaders: List[DataLoader] = [
        DataLoader(
            SamplingDataset(cfg),
            batch_size=1,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
        )
    ]

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    model.hparams.cfg.update(cfg)  # update model config with the sampling config

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        # resolve mixed-precision dtype strings (e.g. for FSDP) into real dtypes
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only pass `strategy` through when one was actually configured
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            logger=logger,
            plugins=plugins,
        )
    )

    object_dict = {
        "cfg": cfg,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
||||
|
||||
|
||||
@hydra.main(version_base="1.3", config_path="../configs", config_name="sample.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for sampling.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    sample(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # register custom OmegaConf resolvers before Hydra composes the config
    register_custom_omegaconf_resolvers()
    main()
|
||||
196
flowdock/train.py
Normal file
196
flowdock/train.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import os
|
||||
|
||||
import hydra
|
||||
import lightning as L
|
||||
import lovely_tensors as lt
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
||||
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from lightning.pytorch.strategies.strategy import Strategy
|
||||
from omegaconf import DictConfig
|
||||
|
||||
lt.monkey_patch()
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# the setup_root above is equivalent to:
|
||||
# - adding project root dir to PYTHONPATH
|
||||
# (so you don't need to force user to install project as a package)
|
||||
# (necessary before importing any local modules e.g. `from flowdock import utils`)
|
||||
# - setting up PROJECT_ROOT environment variable
|
||||
# (which is used as a base for paths in "configs/paths/default.yaml")
|
||||
# (this way all filepaths are the same no matter where you run the code)
|
||||
# - loading environment variables from ".env" in root dir
|
||||
#
|
||||
# you can remove it if you:
|
||||
# 1. either install project as a package or move entry files to project root dir
|
||||
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
||||
#
|
||||
# more info: https://github.com/ashleve/rootutils
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
from flowdock import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
|
||||
from flowdock.utils import (
|
||||
RankedLogger,
|
||||
extras,
|
||||
get_metric_value,
|
||||
instantiate_callbacks,
|
||||
instantiate_loggers,
|
||||
log_hyperparameters,
|
||||
task_wrapper,
|
||||
)
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(
        f"Setting `float32_matmul_precision` to {cfg.model.cfg.task.float32_matmul_precision}."
    )
    torch.set_float32_matmul_precision(precision=cfg.model.cfg.task.float32_matmul_precision)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, stage="fit")

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    # optionally instantiate a cluster environment plugin (e.g., for SLURM/WES)
    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

    # fall back to any strategy value specified directly on the trainer config
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        if (
            "mixed_precision" in strategy.__dict__
            and getattr(strategy, "mixed_precision", None) is not None
        ):
            # resolve each configured mixed-precision dtype string (e.g.,
            # "torch.bfloat16") into its `torch.dtype` counterpart; fields
            # left unconfigured become `None`
            for dtype_attr in ("param_dtype", "reduce_dtype", "buffer_dtype"):
                dtype_name = getattr(cfg.strategy.mixed_precision, dtype_attr, None)
                setattr(
                    strategy.mixed_precision,
                    dtype_attr,
                    resolve_omegaconf_variable(dtype_name) if dtype_name is not None else None,
                )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # only forward `strategy` when one was resolved, so the trainer config's
    # own default is respected otherwise
    trainer_kwargs = {"callbacks": callbacks, "logger": logger, "plugins": plugins}
    if strategy is not None:
        trainer_kwargs["strategy"] = strategy
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, **trainer_kwargs)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        # resume from `ckpt_path` only when the file actually exists;
        # otherwise warn and start from fresh weights
        ckpt_path = None
        if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
            ckpt_path = cfg.get("ckpt_path")
        elif cfg.get("ckpt_path"):
            log.warning(
                "`ckpt_path` was given, but the path does not exist. Training with new model weights."
            )
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # prefer the best checkpoint tracked during training; fall back to
        # current in-memory weights when none was saved
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
|
||||
|
||||
|
||||
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # Optional config utilities: tag prompting, config-tree printing, etc.
    extras(cfg)

    # Train; only the metric dictionary is needed afterwards.
    metric_dict, _ = train(cfg)

    # Safely look up the metric that Hydra sweepers should optimize and
    # return it so hyperparameter optimization can consume it.
    return get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Resolvers must be registered before Hydra composes the config inside `main`.
    register_custom_omegaconf_resolvers()
    main()
|
||||
5
flowdock/utils/__init__.py
Normal file
5
flowdock/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from flowdock.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
||||
from flowdock.utils.logging_utils import log_hyperparameters
|
||||
from flowdock.utils.pylogger import RankedLogger
|
||||
from flowdock.utils.rich_utils import enforce_tags, print_config_tree
|
||||
from flowdock.utils.utils import extras, get_metric_value, task_wrapper
|
||||
1846
flowdock/utils/data_utils.py
Normal file
1846
flowdock/utils/data_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
77
flowdock/utils/frame_utils.py
Normal file
77
flowdock/utils/frame_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from beartype.typing import Optional
|
||||
|
||||
|
||||
class RigidTransform:
    """A minimal rigid-body transform holding translations ``t`` and rotations ``R``."""

    def __init__(self, t: torch.Tensor, R: Optional[torch.Tensor] = None):
        """Store the translation and (optionally) rotation tensors.

        When ``R`` is omitted, an all-zero tensor of shape ``(*t.shape, 3)``
        is allocated on the same device/dtype as ``t``.
        """
        self.t = t
        self.R = t.new_zeros(*t.shape, 3) if R is None else R

    def __getitem__(self, key):
        """Index both component tensors and wrap the result in a new transform."""
        return RigidTransform(self.t[key], self.R[key])

    def unsqueeze(self, dim):
        """Insert a singleton dimension at ``dim`` in both components."""
        return RigidTransform(self.t.unsqueeze(dim), self.R.unsqueeze(dim))

    def squeeze(self, dim):
        """Remove the singleton dimension at ``dim`` from both components."""
        return RigidTransform(self.t.squeeze(dim), self.R.squeeze(dim))

    def concatenate(self, other, dim=0):
        """Concatenate this transform with ``other`` along ``dim``."""
        cat_t = torch.cat([self.t, other.t], dim=dim)
        cat_R = torch.cat([self.R, other.R], dim=dim)
        return RigidTransform(cat_t, cat_R)
|
||||
|
||||
|
||||
def get_frame_matrix(
    ri: torch.Tensor, rj: torch.Tensor, rk: torch.Tensor, eps: float = 1e-4, strict: bool = False
):
    """Build a local frame at ``rj`` from three points via regularized Gram-Schmidt.

    The first axis points from ``rj`` toward ``ri``, the second is the
    component of ``rk - rj`` orthogonal to it, and the third is their cross
    product.  In non-strict mode, squared norms are regularized with ``eps``
    before the square root, which tolerates degenerate (nearly collinear or
    coincident) inputs.  Note that this implementation allows for shearing.
    """
    u = ri - rj
    w = rk - rj
    if strict:
        e1 = u / u.norm(dim=-1, keepdim=True)
        # Gram-Schmidt: remove the e1 component from w, then normalize.
        w_perp = w - e1.mul(e1.mul(w).sum(-1, keepdim=True))
        e2 = w_perp / w_perp.norm(dim=-1, keepdim=True)
    else:
        e1 = u / u.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
        # Gram-Schmidt with eps-regularized norms.
        w_perp = w - e1.mul(e1.mul(w).sum(-1, keepdim=True))
        e2 = w_perp / w_perp.square().sum(dim=-1, keepdim=True).add(eps).sqrt()
    e3 = torch.cross(e1, e2, dim=-1)
    # Rows index the lab frame; columns index the internal frame axes.
    rotation = torch.stack([e1, e2, e3], dim=-1)
    return RigidTransform(rj, torch.nan_to_num(rotation, 0.0))
|
||||
|
||||
|
||||
def cartesian_to_internal(rs: torch.Tensor, frames: RigidTransform):
    """Express lab-frame points ``rs`` in the local frames' coordinates.

    Translates by ``-frames.t``, then right-multiplies by ``frames.R``
    (hence rows of ``R`` map lab axes to internal axes).
    """
    centered = rs - frames.t
    projected = torch.matmul(centered.unsqueeze(-2), frames.R)
    return projected.squeeze(-2)
|
||||
|
||||
|
||||
def apply_similarity_transform(
    X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
    """Apply a batched similarity transform ``s * (X @ R) + T`` to points ``X``.

    From: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/ops/points_alignment.html

    :param X: Batched point clouds.
    :param R: Batched rotation matrices (right-multiplied onto ``X``).
    :param T: Per-batch translations.
    :param s: Per-batch scale factors.
    :return: The transformed points, same shape as ``X``.
    """
    rotated = torch.bmm(X, R)
    scaled = s[:, None, None] * rotated
    return scaled + T[:, None, :]
|
||||
4
flowdock/utils/generate_wandb_id.py
Normal file
4
flowdock/utils/generate_wandb_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import wandb
|
||||
|
||||
if __name__ == "__main__":
    # Print a fresh, unique WandB run ID (useful for pre-allocating resumable runs).
    print(f"Generated WandB run ID: {wandb.util.generate_id()}")
|
||||
55
flowdock/utils/inspect_ode_samplers.py
Normal file
55
flowdock/utils/inspect_ode_samplers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from beartype import beartype
|
||||
from beartype.typing import Literal
|
||||
|
||||
|
||||
def clamp_tensor(value: torch.Tensor, min: float = 1e-6, max: float = 1 - 1e-6) -> torch.Tensor:
    """Clamp ``value`` into the closed interval ``[min, max]``.

    NOTE: the parameter names shadow the ``min``/``max`` builtins, but they
    are part of the public keyword interface and are therefore kept.

    :param value: The tensor to clamp.
    :param min: The minimum value to clamp to. Default is `1e-6`.
    :param max: The maximum value to clamp to. Default is `1 - 1e-6`.
    :return: The clamped tensor.
    """
    return torch.clamp(value, min=min, max=max)
|
||||
|
||||
|
||||
@beartype
def main(
    start_time: float, num_steps: int, sampler: Literal["ODE", "VDODE"] = "VDODE", eta: float = 1.0
):
    """Inspect different ODE samplers by printing the left hand side (LHS) and right hand side.

    (RHS) of their time ratio schedules. Note that the LHS and RHS are clamped to the range
    `[1e-6, 1 - 1e-6]` by default.

    :param start_time: The start time of the ODE sampler.
    :param num_steps: The number of steps to take.
    :param sampler: The ODE sampler to use.
    :param eta: The variance diminishing factor.
    """
    assert 0 < start_time <= 1.0, "The argument `start_time` must be in the range (0, 1]."
    # Linearly spaced schedule from `start_time` down to 0; consecutive
    # entries (t, s) form each integration step's (current, next) times.
    schedule = torch.linspace(start_time, 0, num_steps + 1)
    for t, s in zip(schedule[:-1], schedule[1:]):
        if sampler == "ODE":
            # Baseline ODE
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - t) * x0_hat: {clamp_tensor((1 - t)):.3f}; RHS -> t * xt: {clamp_tensor(t):.3f}"
            )
        elif sampler == "VDODE":
            # Variance Diminishing ODE (VD-ODE)
            print(
                f"t: {t:.3f}; s: {s:.3f}; LHS -> (1 - ((s / t) * eta)) * x0_hat: {clamp_tensor(1 - ((s / t) * eta)):.3f}; RHS -> ((s / t) * eta) * xt: {clamp_tensor((s / t) * eta):.3f}"
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI wrapper: inspect an ODE sampler's time-ratio schedule from the command line.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--start_time", type=float, default=1.0)
    argparser.add_argument("--num_steps", type=int, default=40)
    argparser.add_argument("--sampler", type=str, choices=["ODE", "VDODE"], default="VDODE")
    argparser.add_argument("--eta", type=float, default=1.0)
    args = argparser.parse_args()
    main(args.start_time, args.num_steps, sampler=args.sampler, eta=args.eta)
|
||||
58
flowdock/utils/instantiators.py
Normal file
58
flowdock/utils/instantiators.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import hydra
|
||||
import rootutils
|
||||
from beartype.typing import List
|
||||
from lightning import Callback
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from omegaconf import DictConfig
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils import pylogger
|
||||
|
||||
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    :raises TypeError: If a non-empty ``callbacks_cfg`` is not a DictConfig.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        # FIX: message punctuation made consistent with `instantiate_loggers`.
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # Only sub-configs declaring a `_target_` are instantiated; other
    # entries are silently ignored.
    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
|
||||
|
||||
|
||||
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    :raises TypeError: If a non-empty ``logger_cfg`` is not a DictConfig.
    """
    loggers: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return loggers

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # Instantiate every sub-config that declares a `_target_`.
    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers
|
||||
59
flowdock/utils/logging_utils.py
Normal file
59
flowdock/utils/logging_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import rootutils
|
||||
from beartype.typing import Any, Dict
|
||||
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils import pylogger
|
||||
|
||||
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    # Convert the OmegaConf config into a plain (nested) dict before logging.
    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # Nothing to do when the trainer has no logger attached.
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # "model", "data", and "trainer" are required config sections (plain
    # indexing raises KeyError if absent); the rest are optional `.get`s.
    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
|
||||
88
flowdock/utils/metric_utils.py
Normal file
88
flowdock/utils/metric_utils.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import subprocess # nosec
|
||||
|
||||
import torch
|
||||
from beartype import beartype
|
||||
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
MODEL_BATCH = Dict[str, Any]
|
||||
|
||||
|
||||
@beartype
def calculate_usalign_metrics(
    pred_pdb_filepath: str,
    reference_pdb_filepath: str,
    usalign_exec_path: str,
    flags: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Calculates US-align structural metrics between predicted and reference macromolecular
    structures.

    :param pred_pdb_filepath: Filepath to predicted macromolecular structure in PDB format.
    :param reference_pdb_filepath: Filepath to reference macromolecular structure in PDB format.
    :param usalign_exec_path: Path to US-align executable.
    :param flags: Command-line flags to pass to US-align, optional.
    :raises subprocess.CalledProcessError: If the US-align process exits with a nonzero status.
    :return: Dictionary containing macromolecular US-align structural metrics and metadata.
    """
    # run US-align with subprocess and capture output
    cmd = [usalign_exec_path, pred_pdb_filepath, reference_pdb_filepath]
    if flags is not None:
        cmd += flags
    output = subprocess.check_output(cmd, text=True, stderr=subprocess.PIPE)  # nosec

    # parse US-align output to extract structural metrics
    # NOTE(review): the split indices below assume US-align's exact stdout
    # layout (e.g. "Aligned length=..., RMSD=..., Seq_ID=n_identical/n_aligned=...");
    # verify against the installed US-align version.
    metrics = {}
    for line in output.splitlines():
        line = line.strip()
        if line.startswith("Name of Structure_1:"):
            metrics["Name of Structure_1"] = line.split(": ", 1)[1]
        elif line.startswith("Name of Structure_2:"):
            metrics["Name of Structure_2"] = line.split(": ", 1)[1]
        elif line.startswith("Length of Structure_1:"):
            metrics["Length of Structure_1"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Length of Structure_2:"):
            metrics["Length of Structure_2"] = int(line.split(": ")[1].split()[0])
        elif line.startswith("Aligned length="):
            # this line packs three comma-separated "key=value" fields
            aligned_length = line.split("=")[1].split(",")[0]
            rmsd = line.split("=")[2].split(",")[0]
            seq_id = line.split("=")[4]
            metrics["Aligned length"] = int(aligned_length.strip())
            metrics["RMSD"] = float(rmsd.strip())
            metrics["Seq_ID"] = float(seq_id.strip())
        elif line.startswith("TM-score="):
            # US-align reports two TM-scores, one normalized by each structure's length
            if "normalized by length of Structure_1" in line:
                metrics["TM-score_1"] = float(line.split("=")[1].split()[0])
            elif "normalized by length of Structure_2" in line:
                metrics["TM-score_2"] = float(line.split("=")[1].split()[0])

    return metrics
|
||||
|
||||
|
||||
def compute_per_atom_lddt(
    batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes per-atom local distance difference test (LDDT) between predicted and target
    coordinates.

    :param batch: Dictionary containing metadata and target coordinates.
    :param pred_coords: Predicted atomic coordinates.
    :param target_coords: Target atomic coordinates.
    :return: Tuple of lDDT and lDDT list.
    """
    # reshape flat coordinates to one row of (x, y, z) atoms per structure
    # (assumes pred/target cover the same atoms in the same order — TODO confirm)
    pred_coords = pred_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    target_coords = target_coords.contiguous().view(batch["metadata"]["num_structid"], -1, 3)
    # pairwise atom-atom distance matrices for target and prediction
    target_dist = (target_coords[:, :, None] - target_coords[:, None, :]).norm(dim=-1)
    pred_dist = (pred_coords[:, :, None] - pred_coords[:, None, :]).norm(dim=-1)
    # only pairs within 15 (presumably Angstrom — confirm units) in the target count
    conserved_mask = target_dist < 15.0
    lddt_list = []
    thresholds = [0, 0.5, 1, 2, 4, 6, 8, 12, 1e9]
    for threshold_idx in range(8):
        # fraction of conserved pairs whose absolute distance error falls in this bin
        distdiff = (pred_dist - target_dist).abs()
        bin_fraction = (distdiff > thresholds[threshold_idx]) & (
            distdiff < thresholds[threshold_idx + 1]
        )
        lddt_list.append(
            bin_fraction.mul(conserved_mask).long().sum(dim=2) / conserved_mask.long().sum(dim=2)
        )
    lddt_list = torch.stack(lddt_list, dim=-1)
    # cumulative sums over the first four bins approximate the fractions of
    # pairs with error below 0.5/1/2/4 (exact-zero and exact-boundary errors
    # fall outside the strict bin bounds); their mean is the lDDT score
    lddt = torch.cumsum(lddt_list[:, :, :4], dim=-1).mean(dim=-1)
    return lddt, lddt_list
|
||||
446
flowdock/utils/model_utils.py
Normal file
446
flowdock/utils/model_utils.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# Adapted from: https://github.com/zrqiao/NeuralPLexer
|
||||
|
||||
import rootutils
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from beartype.typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from torch_scatter import scatter_max, scatter_min
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils import RankedLogger
|
||||
|
||||
MODEL_BATCH = Dict[str, Any]
|
||||
STATE_DICT = Dict[str, Any]
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class GELUMLP(torch.nn.Module):
    """Simple MLP with post-LayerNorm.

    With ``n_hidden_feats=None`` this is Linear -> GELU -> LayerNorm -> Linear;
    otherwise a deeper variant with dropout between the first two linear layers.
    """

    def __init__(
        self,
        n_in_feats: int,
        n_out_feats: int,
        n_hidden_feats: Optional[int] = None,
        dropout: float = 0.0,
        zero_init: bool = False,
    ):
        """Initialize the GELUMLP.

        :param n_in_feats: Input feature dimension.
        :param n_out_feats: Output feature dimension.
        :param n_hidden_feats: Hidden dimension; if None, the shallow variant is built.
        :param dropout: Dropout probability (applied to the input in `forward`,
            and between hidden layers in the deep variant).
        :param zero_init: If True, zero the final layer's weight (residual-style init).
        """
        super().__init__()
        self.dropout = dropout
        if n_hidden_feats is None:
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(n_in_feats, n_in_feats),
                torch.nn.GELU(),
                torch.nn.LayerNorm(n_in_feats),
                torch.nn.Linear(n_in_feats, n_out_feats),
            )
        else:
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(n_in_feats, n_hidden_feats),
                torch.nn.GELU(),
                torch.nn.Dropout(p=self.dropout),
                torch.nn.Linear(n_hidden_feats, n_hidden_feats),
                torch.nn.GELU(),
                torch.nn.LayerNorm(n_hidden_feats),
                torch.nn.Linear(n_hidden_feats, n_out_feats),
            )
        torch.nn.init.xavier_uniform_(self.layers[0].weight, gain=1)
        # zero init for residual branches
        # NOTE(review): only the final weight is zeroed; its bias keeps the default init.
        if zero_init:
            self.layers[-1].weight.data.fill_(0.0)
        else:
            torch.nn.init.xavier_uniform_(self.layers[-1].weight, gain=1)

    def _zero_init(self, module):
        """Zero-initialize weights and biases.

        NOTE(review): not referenced anywhere in this class's visible code.
        """
        if isinstance(module, torch.nn.Linear):
            module.weight.data.zero_()
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x: torch.Tensor):
        """Forward pass through the GELUMLP (input dropout, then the layer stack)."""
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.layers(x)
|
||||
|
||||
|
||||
class SumPooling(torch.nn.Module):
    """Segment-sum pooling with an optional learnable linear output transform."""

    def __init__(self, learnable: bool, hidden_dim: int = 1):
        """Set up the post-pooling transform: Linear if learnable, else Identity."""
        super().__init__()
        if learnable:
            self.pooled_transform = torch.nn.Linear(hidden_dim, hidden_dim)
        else:
            self.pooled_transform = torch.nn.Identity()

    def forward(self, x, dst_idx, dst_size):
        """Sum ``x`` into ``dst_size`` segments keyed by ``dst_idx``, then transform."""
        pooled = segment_sum(x, dst_idx, dst_size)
        return self.pooled_transform(pooled)
|
||||
|
||||
|
||||
class AveragePooling(torch.nn.Module):
    """Segment-mean pooling with an optional learnable linear output transform."""

    def __init__(self, learnable: bool, hidden_dim: int = 1):
        """Set up the post-pooling transform: Linear if learnable, else Identity."""
        super().__init__()
        if learnable:
            self.pooled_transform = torch.nn.Linear(hidden_dim, hidden_dim)
        else:
            self.pooled_transform = torch.nn.Identity()

    def forward(self, x, dst_idx, dst_size):
        """Average ``x`` into ``dst_size`` segments keyed by ``dst_idx``, then transform."""
        shape = (dst_size, *x.shape[1:])
        totals = torch.zeros(shape, dtype=x.dtype, device=x.device)
        totals.index_add_(0, dst_idx, x)
        counts = torch.zeros(shape, dtype=x.dtype, device=x.device)
        counts.index_add_(0, dst_idx, torch.ones_like(x))
        # The small epsilon keeps empty segments finite (they average to ~0).
        return self.pooled_transform(totals / (counts + 1e-8))
|
||||
|
||||
|
||||
def init_weights(m):
    """Apply Kaiming-uniform initialization to the weight of a Linear module.

    Non-Linear modules are left untouched; intended for use with
    ``torch.nn.Module.apply``.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_uniform_(m.weight)
|
||||
|
||||
|
||||
def segment_sum(src, dst_idx, dst_size):
    """Sum rows of ``src`` into ``dst_size`` segments selected by ``dst_idx``."""
    accumulator = torch.zeros(
        (dst_size, *src.shape[1:]), dtype=src.dtype, device=src.device
    )
    accumulator.index_add_(0, dst_idx, src)
    return accumulator
|
||||
|
||||
|
||||
def segment_mean(src, dst_idx, dst_size):
    """Average rows of ``src`` within each of ``dst_size`` segments given by ``dst_idx``."""
    shape = (dst_size, *src.shape[1:])
    totals = torch.zeros(shape, dtype=src.dtype, device=src.device)
    totals.index_add_(0, dst_idx, src)
    counts = torch.zeros(shape, dtype=src.dtype, device=src.device)
    counts.index_add_(0, dst_idx, torch.ones_like(src))
    # The epsilon keeps empty segments finite (they average to ~0).
    return totals / (counts + 1e-8)
|
||||
|
||||
|
||||
def segment_argmin(scores, dst_idx, dst_size, randomize: bool = False) -> torch.Tensor:
    """Return, per segment, the index of the minimum score.

    With ``randomize=True``, the scores are perturbed with
    ``-log(-log(U))``-style noise first, making the argmin a stochastic
    sample rather than a deterministic minimum.
    """
    if randomize:
        noise = torch.rand_like(scores)
        scores = scores - torch.log(-torch.log(noise))
    _, winner_idx = scatter_min(scores, dst_idx, dim=0, dim_size=dst_size)
    return winner_idx
|
||||
|
||||
|
||||
def segment_logsumexp(src, dst_idx, dst_size, extra_dims=None):
    """Computes the logsumexp of each segment in a tensor.

    :param src: Source tensor; row ``i`` belongs to segment ``dst_idx[i]``.
    :param dst_idx: Segment index for each row of ``src``.
    :param dst_size: Number of segments in the output.
    :param extra_dims: Optional extra dimensions to also reduce over.
    :return: Per-segment logsumexp values.
    """
    # subtract the per-segment max before exponentiating (log-sum-exp trick
    # for numerical stability)
    src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
    if extra_dims is not None:
        src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
    src = src - src_max[dst_idx]
    out = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    ).index_add_(0, dst_idx, torch.exp(src))
    if extra_dims is not None:
        out = torch.sum(out, dim=extra_dims)
    # add the max back; 1e-8 guards the log for empty segments
    return torch.log(out + 1e-8) + src_max.view(*out.shape)
|
||||
|
||||
|
||||
def segment_softmax(src, dst_idx, dst_size, extra_dims=None, floor_value=None):
    """Computes the softmax of each segment in a tensor.

    :param src: Source tensor; row ``i`` belongs to segment ``dst_idx[i]``.
    :param dst_idx: Segment index for each row of ``src``.
    :param dst_size: Number of segments.
    :param extra_dims: Optional extra dimensions to also normalize over.
    :param floor_value: If given, probabilities are clamped to at least this
        value before the second normalization pass.
    :return: Softmax weights shaped like ``src``.
    """
    # subtract the per-segment max before exponentiating (numerical stability)
    src_max, _ = scatter_max(src, dst_idx, dim=0, dim_size=dst_size)
    if extra_dims is not None:
        src_max = torch.amax(src_max, dim=extra_dims, keepdim=True)
    src = src - src_max[dst_idx]
    exp1 = torch.exp(src)
    # first pass: per-segment sums of exponentials, gathered back per row
    exp0 = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    ).index_add_(0, dst_idx, exp1)
    if extra_dims is not None:
        exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
    exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
    exp = exp1.div(exp0 + 1e-8)
    if floor_value is not None:
        # clamping breaks normalization; the second pass below restores it
        exp = exp.clamp(min=floor_value)
    # second pass: renormalize (a near no-op, up to epsilon, when no floor
    # was applied)
    # NOTE(review): in the mangled diff source the nesting of this second
    # pass relative to the `floor_value` branch is ambiguous — confirm
    # against the upstream NeuralPLexer implementation.
    exp0 = torch.zeros(
        dst_size,
        *src.shape[1:],
        dtype=src.dtype,
        device=src.device,
    ).index_add_(0, dst_idx, exp)
    if extra_dims is not None:
        exp0 = torch.sum(exp0, dim=extra_dims, keepdim=True)
    exp0 = torch.index_select(input=exp0, dim=0, index=dst_idx)
    exp = exp.div(exp0 + 1e-8)
    return exp
|
||||
|
||||
|
||||
def batched_sample_onehot(logits, dim=0, max_only=False):
    """One-hot selection along ``dim``: deterministic argmax or a Gumbel-Max sample.

    :param logits: Unnormalized log-probabilities.
    :param dim: Dimension over which to select.
    :param max_only: If True, pick the argmax instead of sampling.
    :return: Boolean tensor shaped like ``logits`` with one True per slice.
    """
    if max_only:
        chosen = torch.argmax(logits, dim=dim, keepdim=True)
    else:
        # Gumbel-Max trick: argmax of perturbed logits samples the categorical.
        uniform = torch.rand_like(logits)
        chosen = torch.argmax(logits - torch.log(-torch.log(uniform)), dim=dim, keepdim=True)
    onehot = torch.zeros_like(logits, dtype=torch.bool)
    onehot.scatter_(dim=dim, index=chosen, value=1)
    return onehot
|
||||
|
||||
|
||||
def topk_edge_mask_from_logits(scores, k, randomize=False):
    """Build a boolean edge mask keeping the top-``k`` scores per row.

    :param scores: Edge logits of shape ``[B, N, N]``.
    :param k: Maximum out-degree to keep per node (capped at N).
    :param randomize: If True, perturb scores with Gumbel-style noise before ranking.
    :return: Boolean mask of shape ``[B, N, N]``.
    """
    assert len(scores.shape) == 3, "Scores should have shape [B, N, N]"
    if randomize:
        uniform = torch.rand_like(scores)
        scores = scores - torch.log(-torch.log(uniform))
    degree = min(k, scores.shape[2])
    _, keep_idx = torch.topk(scores, degree, dim=-1, largest=True)
    mask = scores.new_zeros(scores.shape, dtype=torch.bool)
    return mask.scatter_(dim=2, index=keep_idx, value=1).bool()
|
||||
|
||||
|
||||
def sample_inplace_to_torch(sample):
    """Convert a NumPy-backed sample dict to PyTorch tensors, in place.

    Features and labels become float tensors; indexer entries become long
    tensors.  ``None`` passes through unchanged.
    """
    if sample is None:
        return None
    sample["features"] = {name: torch.FloatTensor(arr) for name, arr in sample["features"].items()}
    sample["indexer"] = {name: torch.LongTensor(arr) for name, arr in sample["indexer"].items()}
    if "labels" in sample:
        sample["labels"] = {name: torch.FloatTensor(arr) for name, arr in sample["labels"].items()}
    return sample
|
||||
|
||||
|
||||
def inplace_to_device(sample, device):
    """Move every tensor in a sample dict to ``device``, in place.

    :param sample: Sample dict with "features", "indexer", and optionally
        "labels" sub-dicts of tensors (as produced by `inplace_to_torch`).
    :param device: Target device (e.g., "cuda:0" or a `torch.device`).
    :return: The same dict, with all tensors moved to `device`.
    """
    sample["features"] = {k: v.to(device) for k, v in sample["features"].items()}
    sample["indexer"] = {k: v.to(device) for k, v in sample["indexer"].items()}
    if "labels" in sample.keys():
        # BUGFIX: "labels" is a dict of tensors (see `inplace_to_torch`), so it
        # must be moved per-entry; calling `.to(device)` on the dict itself
        # raised an AttributeError.
        sample["labels"] = {k: v.to(device) for k, v in sample["labels"].items()}
    return sample
|
||||
|
||||
|
||||
def inplace_to_torch(sample):
    """Convert a NumPy-backed sample dict to PyTorch tensors, in place.

    NOTE(review): byte-for-byte duplicate of `sample_inplace_to_torch` in
    this module — consider consolidating.
    """
    if sample is None:
        return None
    for key, ctor in (("features", torch.FloatTensor), ("indexer", torch.LongTensor)):
        sample[key] = {name: ctor(arr) for name, arr in sample[key].items()}
    if "labels" in sample.keys():
        sample["labels"] = {name: torch.FloatTensor(arr) for name, arr in sample["labels"].items()}
    return sample
|
||||
|
||||
|
||||
def distance_to_gaussian_contact_logits(
    x: torch.Tensor, contact_scale: float, cutoff: Optional[float] = None
) -> torch.Tensor:
    """Map distances to contact logits that decay toward ``-inf`` at the cutoff.

    :param x: Distance tensor.
    :param contact_scale: Contact length scale; the default cutoff is twice this.
    :param cutoff: Optional explicit distance cutoff.
    :return: Log of the clamped linear contact weight ``1 - x / cutoff``.
    """
    effective_cutoff = 2 * contact_scale if cutoff is None else cutoff
    # Clamp before the log so distances at/past the cutoff stay finite.
    weight = torch.clamp(1 - (x / effective_cutoff), min=1e-9)
    return torch.log(weight)
|
||||
|
||||
|
||||
def distogram_to_gaussian_contact_logits(
    dgram: torch.Tensor, dist_bins: torch.Tensor, contact_scale: float
) -> torch.Tensor:
    """Collapse a distance histogram (distogram) into contact logits.

    Each distance bin contributes its distogram logit plus the contact logit of
    the bin's representative distance; bins are combined via log-sum-exp over
    the last axis.

    :param dgram: A distogram (per-bin logit) matrix.
    :param dist_bins: Distances associated with the distogram bins.
    :param contact_scale: Contact scale forwarded to ``distance_to_gaussian_contact_logits``.
    :return: Contact logits with the distogram's bin axis reduced.
    """
    per_bin_logits = dgram + distance_to_gaussian_contact_logits(dist_bins, contact_scale)
    return torch.logsumexp(per_bin_logits, dim=-1)
|
||||
|
||||
|
||||
def eval_true_contact_maps(
    batch: MODEL_BATCH, contact_scale: float, **kwargs: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ground-truth residue-ligand contact maps from coordinates.

    :param batch: A batch dictionary.
    :param contact_scale: The contact scale.
    :param kwargs: Extra keyword arguments forwarded to
        ``distance_to_gaussian_contact_logits`` (e.g., ``cutoff``).
    :return: Residue-ligand distance matrix and the flattened contact logits.
    """
    indexer = batch["indexer"]
    batch_size = batch["metadata"]["num_structid"]
    with torch.no_grad():
        atom_positions = batch["features"]["res_atom_positions"]
        atom_mask = batch["features"]["res_atom_mask"].bool()
        # Residue centroids: masked mean over each residue's resolved atoms
        # (epsilon guards against residues with no resolved atoms).
        res_cent_coords = (
            atom_positions.mul(atom_mask[:, :, None])
            .sum(dim=1)
            .div(atom_mask.sum(dim=1)[:, None] + 1e-9)
        )
        lig_coords = batch["features"]["sdf_coordinates"][indexer["gather_idx_U_u"]]
        # Pairwise residue-centroid-to-ligand-atom distances, [batch, n_res, n_lig].
        res_lig_dist = (
            res_cent_coords.view(batch_size, -1, 3)[:, :, None]
            - lig_coords.view(batch_size, -1, 3)[:, None, :]
        ).norm(dim=-1)
        res_lig_contact_logit = distance_to_gaussian_contact_logits(
            res_lig_dist, contact_scale, **kwargs
        )
        return res_lig_dist, res_lig_contact_logit.flatten()
|
||||
|
||||
|
||||
def sample_reslig_contact_matrix(
    batch: MODEL_BATCH, res_lig_logits: torch.Tensor, last: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Sample a residue-ligand contact matrix, adding one new contact per sample.

    :param batch: A batch dictionary.
    :param res_lig_logits: Flattened residue-ligand contact logits.
    :param last: The previously sampled contact matrix, or ``None`` to start from
        an empty one.
    :return: Boolean contact matrix of shape ``[batch, n_res, n_frames]`` holding
        the previous contacts plus one newly sampled (valid) contact per sample.
    """
    metadata = batch["metadata"]
    batch_size = metadata["num_structid"]
    n_a_per_sample = max(metadata["num_a_per_sample"])
    n_I_per_sample = max(metadata["num_I_per_sample"])
    res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
    # Sampling from unoccupied lattice sites
    if last is None:
        last = torch.zeros_like(res_lig_logits, dtype=torch.bool)
    # Column-graph-wise masking for already sampled ligands
    sampled_frame_mask = torch.sum(last, dim=1, keepdim=True).contiguous()
    # Suppress already-occupied frame columns before drawing the next contact.
    masked_logits = res_lig_logits - sampled_frame_mask * 1e9
    sampled_block_onehot = batched_sample_onehot(masked_logits.flatten(1, 2), dim=1).view(
        batch_size, n_a_per_sample, n_I_per_sample
    )
    new_block_contact_mat = last + sampled_block_onehot
    # Remove non-contact patches (logits at/below the saturation floor are not real contacts)
    valid_logit_mask = res_lig_logits > -16.0
    new_block_contact_mat = (new_block_contact_mat * valid_logit_mask).bool()
    return new_block_contact_mat
|
||||
|
||||
|
||||
def merge_res_lig_logits_to_graph(
    batch: MODEL_BATCH,
    res_lig_logits: torch.Tensor,
    single_protein_batch: bool,
) -> torch.Tensor:
    """Patch merging [B, N_res, N_atm] -> [B, N_res, N_graph].

    :param batch: A batch dictionary.
    :param res_lig_logits: Residue-ligand contact logits.
    :param single_protein_batch: Whether to use single protein batch (must be ``True``;
        other modes are not supported).
    :return: Merged residue-ligand logits, pooled per molecular graph.
    """
    assert single_protein_batch, "Only single protein batch is supported."
    metadata = batch["metadata"]
    indexer = batch["indexer"]
    batch_size = metadata["num_structid"]
    n_mol_per_sample = max(metadata["num_molid_per_sample"])
    n_a_per_sample = max(metadata["num_a_per_sample"])
    n_I_per_sample = max(metadata["num_I_per_sample"])
    res_lig_logits = res_lig_logits.view(batch_size, n_a_per_sample, n_I_per_sample)
    # Log-sum-exp-pool frame logits that belong to the same molecular graph.
    graph_wise_logits = segment_logsumexp(
        res_lig_logits.permute(2, 0, 1),
        indexer["gather_idx_I_molid"][:n_I_per_sample],
        n_mol_per_sample,
    ).permute(1, 2, 0)
    return graph_wise_logits
|
||||
|
||||
|
||||
def sample_res_rowmask_from_contacts(
    batch: MODEL_BATCH,
    res_ligatm_logits: torch.Tensor,
    single_protein_batch: bool,
) -> torch.Tensor:
    """Sample a one-hot residue (row) mask per ligand graph from contact logits.

    :param batch: A batch dictionary.
    :param res_ligatm_logits: Residue-ligand atom contact logits.
    :param single_protein_batch: Whether to use single protein batch; forwarded to
        ``merge_res_lig_logits_to_graph`` (which requires it to be ``True``).
    :return: Sampled one-hot residue mask, flattened over the (batch, graph) rows.
    """
    lig_wise_logits = (
        merge_res_lig_logits_to_graph(batch, res_ligatm_logits, single_protein_batch)
        .permute(0, 2, 1)
        .contiguous()
    )
    sampled_res_onehot_mask = batched_sample_onehot(lig_wise_logits.flatten(0, 1), dim=1)
    return sampled_res_onehot_mask
|
||||
|
||||
|
||||
def extract_esm_embeddings(
    esm_model: torch.nn.Module,
    esm_alphabet: torch.nn.Module,
    esm_batch_converter: torch.nn.Module,
    sequences: List[str],
    device: Union[str, torch.device],
    esm_repr_layer: int = 33,
) -> List[torch.Tensor]:
    """Extract per-residue embeddings from an ESM language model.

    :param esm_model: ESM model.
    :param esm_alphabet: ESM alphabet (used to identify padding tokens).
    :param esm_batch_converter: ESM batch converter.
    :param sequences: A list of sequences.
    :param device: Device on which to run the model.
    :param esm_repr_layer: ESM representation layer index from which to extract embeddings.
    :return: A corresponding list of per-residue embeddings, one tensor per sequence.
    """
    # Disable dropout for deterministic results
    esm_model.eval()

    # Tokenize all sequences as a single padded batch
    data = [(str(i), seq) for i, seq in enumerate(sequences)]
    _, _, batch_tokens = esm_batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    # Per-sequence token counts (special tokens included, padding excluded)
    batch_lens = (batch_tokens != esm_alphabet.padding_idx).sum(1)

    # Extract token representations from the requested layer;
    # contact prediction is unused here, so skip it for efficiency
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[esm_repr_layer], return_contacts=False)
    token_representations = results["representations"][esm_repr_layer]

    # Strip special tokens to obtain per-residue representations.
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        sequence_representations.append(token_representations[i, 1 : tokens_len - 1])

    return sequence_representations
|
||||
51
flowdock/utils/pylogger.py
Normal file
51
flowdock/utils/pylogger.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import logging
|
||||
|
||||
from beartype.typing import Mapping, Optional
|
||||
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
|
||||
|
||||
|
||||
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if not self.isEnabledFor(level):
            return
        msg, kwargs = self.process(msg, kwargs)
        # `rank_zero_only.rank` is attached to the imported function by the launcher.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)
        should_emit = (
            current_rank == 0
            if self.rank_zero_only
            else (rank is None or current_rank == rank)
        )
        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
|
||||
102
flowdock/utils/rich_utils.py
Normal file
102
flowdock/utils/rich_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from pathlib import Path
|
||||
|
||||
import rich
|
||||
import rich.syntax
|
||||
import rich.tree
|
||||
import rootutils
|
||||
from beartype.typing import Sequence
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from lightning_utilities.core.rank_zero import rank_zero_only
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
from rich.prompt import Prompt
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils import pylogger
|
||||
|
||||
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)
|
||||
|
||||
|
||||
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        # Tags are mandatory for multiruns, where interactive prompting is impossible.
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        parsed_tags = [tag.strip() for tag in raw_tags.split(",") if tag != ""]

        with open_dict(cfg):
            cfg.tags = parsed_tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
|
||||
427
flowdock/utils/sampling_utils.py
Normal file
427
flowdock/utils/sampling_utils.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import rootutils
|
||||
import torch
|
||||
from beartype.typing import Any, Dict, List, Optional, Tuple
|
||||
from lightning import LightningModule
|
||||
from omegaconf import DictConfig
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import AllChem
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.data.components.mol_features import (
|
||||
collate_numpy_samples,
|
||||
process_mol_file,
|
||||
)
|
||||
from flowdock.utils import RankedLogger
|
||||
from flowdock.utils.data_utils import (
|
||||
FDProtein,
|
||||
merge_protein_and_ligands,
|
||||
pdb_filepath_to_protein,
|
||||
prepare_batch,
|
||||
process_protein,
|
||||
)
|
||||
from flowdock.utils.model_utils import inplace_to_device, inplace_to_torch, segment_mean
|
||||
from flowdock.utils.visualization_utils import (
|
||||
write_conformer_sdf,
|
||||
write_pdb_models,
|
||||
write_pdb_single,
|
||||
)
|
||||
|
||||
log = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def featurize_protein_and_ligands(
    rec_path: str,
    lig_paths: List[str],
    n_lig_patches: int,
    apo_rec_path: Optional[str] = None,
    chain_id: Optional[str] = None,
    protein: Optional[FDProtein] = None,
    sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
    enforce_sanitization: bool = False,
    discard_sdf_coords: bool = False,
    **kwargs: Dict[str, Any],
):
    """Featurize a protein-ligand complex.

    :param rec_path: Path to the receptor file.
    :param lig_paths: List of paths to the ligand files (a single path or ``None`` are also accepted).
    :param n_lig_patches: Number of ligand patches.
    :param apo_rec_path: Path to the apo receptor file.
    :param chain_id: Chain ID of the receptor.
    :param protein: Optional pre-parsed protein object (skips parsing ``rec_path``).
    :param sequences_to_embeddings: Mapping of sequences to embeddings.
    :param enforce_sanitization: Whether to enforce RDKit sanitization (raise instead of
        falling back to the unsanitized molecule).
    :param discard_sdf_coords: Whether to discard SDF coordinates.
    :param kwargs: Additional keyword arguments forwarded to ``process_protein``.
    :return: Tuple of the featurized (merged) protein-ligand sample and the combined
        RDKit molecule (``None`` when no ligands are given).
    """
    assert rec_path is not None
    # Normalize `lig_paths` into a (possibly empty) list of paths.
    if lig_paths is None:
        lig_paths = []
    if isinstance(lig_paths, str):
        lig_paths = [lig_paths]
    out_mol = None
    lig_samples = []
    for lig_path in lig_paths:
        try:
            lig_sample, mol_ref = process_mol_file(
                lig_path,
                sanitize=True,
                return_mol=True,
                discard_coords=discard_sdf_coords,
            )
        except Exception as e:
            # Fall back to loading raw (unsanitized) attributes unless sanitization is enforced.
            if enforce_sanitization:
                raise
            log.warning(
                f"RDKit sanitization failed for ligand {lig_path} due to: {e}. Loading raw attributes."
            )
            lig_sample, mol_ref = process_mol_file(
                lig_path,
                sanitize=False,
                return_mol=True,
                discard_coords=discard_sdf_coords,
            )
        lig_samples.append(lig_sample)
        # Accumulate all ligands into a single multi-fragment RDKit molecule.
        if out_mol is None:
            out_mol = mol_ref
        else:
            out_mol = AllChem.CombineMols(out_mol, mol_ref)
    protein = protein if protein is not None else pdb_filepath_to_protein(rec_path)
    rec_sample = process_protein(
        protein,
        chain_id=chain_id,
        # When an apo structure is provided, the language-model embeddings are
        # attached to the apo sample instead of the holo one.
        sequences_to_embeddings=None if apo_rec_path is not None else sequences_to_embeddings,
        **kwargs,
    )
    if apo_rec_path is not None:
        apo_protein = pdb_filepath_to_protein(apo_rec_path)
        apo_rec_sample = process_protein(
            apo_protein,
            chain_id=chain_id,
            sequences_to_embeddings=sequences_to_embeddings,
            **kwargs,
        )
        # Merge apo features into the holo sample under `apo_`-prefixed keys.
        for key in rec_sample.keys():
            for subkey, value in apo_rec_sample[key].items():
                rec_sample[key]["apo_" + subkey] = value
    merged_sample = merge_protein_and_ligands(
        lig_samples,
        rec_sample,
        n_lig_patches=n_lig_patches,
        label=None,
    )
    return merged_sample, out_mol
|
||||
|
||||
|
||||
def multi_pose_sampling(
    receptor_path: str,
    ligand_path: str,
    cfg: DictConfig,
    lit_module: LightningModule,
    out_path: str,
    save_pdb: bool = True,
    separate_pdb: bool = True,
    chain_id: Optional[str] = None,
    apo_receptor_path: Optional[str] = None,
    sample_id: Optional[str] = None,
    protein: Optional[FDProtein] = None,
    sequences_to_embeddings: Optional[Dict[str, np.ndarray]] = None,
    confidence: bool = True,
    affinity: bool = True,
    return_all_states: bool = False,
    auxiliary_estimation_only: bool = False,
    **kwargs: Dict[str, Any],
) -> Tuple[
    Optional[Chem.Mol],
    Optional[List[float]],
    Optional[List[float]],
    Optional[List[Any]],
    Optional[List[float]],
    Optional[Any],
    Optional[Any],
    Optional[np.ndarray],
    Optional[np.ndarray],
]:
    """Sample multiple poses of a protein-ligand complex.

    :param receptor_path: Path to the receptor file.
    :param ligand_path: Path to the ligand file.
    :param cfg: Config dictionary.
    :param lit_module: LightningModule instance.
    :param out_path: Path to save the output files.
    :param save_pdb: Whether to save PDB files.
    :param separate_pdb: Whether to save separate PDB files for each pose.
    :param chain_id: Chain ID of the receptor.
    :param apo_receptor_path: Path to the optional apo receptor file.
    :param sample_id: Optional sample ID.
    :param protein: Optional protein object.
    :param sequences_to_embeddings: Mapping of sequences to embeddings.
    :param confidence: Whether to estimate confidence scores.
    :param affinity: Whether to estimate affinity scores.
    :param return_all_states: Whether to return all states.
    :param auxiliary_estimation_only: Whether to only estimate auxiliary outputs (e.g., confidence,
        affinity) for the input (generated) samples (potentially derived from external sources).
    :param kwargs: Additional keyword arguments.
    :return: Reference molecule, protein plDDTs, ligand plDDTs, ligand fragment plDDTs, estimated
        binding affinities, structure trajectories, input batch, B-factors, and structure rankings.

    NOTE(review): assumes ``cfg.n_samples >= cfg.chunk_size``; otherwise the sampling
    loop below never runs — confirm callers guarantee this.
    """
    if return_all_states and auxiliary_estimation_only:
        # NOTE: If auxiliary estimation is solely enabled, structure trajectory sampling will be disabled
        return_all_states = False
    struct_res_all, lig_res_all = [], []
    plddt_all, plddt_lig_all, plddt_ligs_all, res_plddt_all = [], [], [], []
    affinity_all, ligs_affinity_all = [], []
    frames_all = []
    chunk_size = cfg.chunk_size
    for _ in range(cfg.n_samples // chunk_size):
        # Resample anchor node frames
        np_sample, mol = featurize_protein_and_ligands(
            receptor_path,
            ligand_path,
            n_lig_patches=lit_module.hparams.cfg.mol_encoder.n_patches,
            apo_rec_path=apo_receptor_path,
            chain_id=chain_id,
            protein=protein,
            sequences_to_embeddings=sequences_to_embeddings,
            discard_sdf_coords=cfg.discard_sdf_coords and not auxiliary_estimation_only,
            **kwargs,
        )
        np_sample_batched = collate_numpy_samples([np_sample for _ in range(chunk_size)])
        sample = inplace_to_device(inplace_to_torch(np_sample_batched), device=lit_module.device)
        prepare_batch(sample)
        if auxiliary_estimation_only:
            # Predict auxiliary quantities using the provided input protein and ligand structures
            if "num_molid" in sample["metadata"].keys() and sample["metadata"]["num_molid"] > 0:
                sample["misc"]["protein_only"] = False
            else:
                sample["misc"]["protein_only"] = True
            output_struct = {
                "receptor": sample["features"]["res_atom_positions"].flatten(0, 1),
                "receptor_padded": sample["features"]["res_atom_positions"],
                "ligands": sample["features"]["sdf_coordinates"],
            }
        else:
            output_struct = lit_module.net.sample_pl_complex_structures(
                sample,
                sampler=cfg.sampler,
                sampler_eta=cfg.sampler_eta,
                num_steps=cfg.num_steps,
                return_all_states=return_all_states,
                start_time=cfg.start_time,
                exact_prior=cfg.exact_prior,
            )
        frames_all.append(output_struct.get("all_frames", None))
        if mol is not None:
            ref_mol = AllChem.Mol(mol)
        out_x1 = np.split(output_struct["ligands"].cpu().numpy(), cfg.chunk_size)
        out_x2 = np.split(output_struct["receptor_padded"].cpu().numpy(), cfg.chunk_size)
        if confidence and affinity:
            assert (
                lit_module.net.confidence_cfg.enabled
            ), "Confidence estimation must be enabled in the model configuration."
            assert (
                lit_module.net.affinity_cfg.enabled
            ), "Affinity estimation must be enabled in the model configuration."
            plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
                sample,
                output_struct,
                return_avg_stats=True,
                training=False,
            )
            # Move logits off-device once, consistently with the affinity-only branch below
            aff = sample["outputs"]["affinity_logits"].cpu()
        elif confidence:
            assert (
                lit_module.net.confidence_cfg.enabled
            ), "Confidence estimation must be enabled in the model configuration."
            plddt, plddt_lig, plddt_ligs = lit_module.net.run_auxiliary_estimation(
                sample,
                output_struct,
                return_avg_stats=True,
                training=False,
            )
        elif affinity:
            assert (
                lit_module.net.affinity_cfg.enabled
            ), "Affinity estimation must be enabled in the model configuration."
            lit_module.net.run_auxiliary_estimation(
                sample, output_struct, return_avg_stats=True, training=False
            )
            plddt, plddt_lig = None, None
            aff = sample["outputs"]["affinity_logits"].cpu()

        # Map each molecule index to the structure (sample) it belongs to.
        mol_idx_i_structid = segment_mean(
            sample["indexer"]["gather_idx_i_structid"],
            sample["indexer"]["gather_idx_i_molid"],
            sample["metadata"]["num_molid"],
        ).long()
        for struct_idx in range(cfg.chunk_size):
            struct_res = {
                "features": {
                    "asym_id": np_sample["features"]["res_chain_id"],
                    "residue_index": np.arange(len(np_sample["features"]["res_type"])) + 1,
                    "aatype": np_sample["features"]["res_type"],
                },
                "structure_module": {
                    "final_atom_positions": out_x2[struct_idx],
                    "final_atom_mask": sample["features"]["res_atom_mask"].bool().cpu().numpy(),
                },
            }
            struct_res_all.append(struct_res)
            if mol is not None:
                lig_res_all.append(out_x1[struct_idx])
            if confidence:
                plddt_all.append(plddt[struct_idx].item())
                res_plddt_all.append(
                    sample["outputs"]["plddt"][
                        struct_idx, : sample["metadata"]["num_a_per_sample"][0]
                    ]
                    .cpu()
                    .numpy()
                )
                if plddt_lig is None:
                    plddt_lig_all.append(None)
                else:
                    plddt_lig_all.append(plddt_lig[struct_idx].item())
                if plddt_ligs is None:
                    plddt_ligs_all.append(None)
                else:
                    plddt_ligs_all.append(plddt_ligs[mol_idx_i_structid == struct_idx].tolist())
            if affinity:
                # collect the average affinity across all ligands in each complex
                ligs_aff = aff[mol_idx_i_structid == struct_idx]
                affinity_all.append(ligs_aff.mean().item())
                ligs_affinity_all.append(ligs_aff.tolist())
    if confidence and cfg.rank_outputs_by_confidence:
        # Ligand plDDTs are only usable for ranking when present for every sample.
        plddt_lig_predicted = all(p is not None for p in plddt_lig_all)
        if cfg.plddt_ranking_type == "protein":
            struct_plddts = np.array(plddt_all)  # rank outputs using average protein plDDT
        elif cfg.plddt_ranking_type == "ligand":
            struct_plddts = np.array(
                plddt_lig_all if plddt_lig_predicted else plddt_all
            )  # rank outputs using average ligand plDDT if available
            if not plddt_lig_predicted:
                log.warning(
                    "Ligand plDDT not available for all samples, using protein plDDT instead"
                )
        elif cfg.plddt_ranking_type == "protein_ligand":
            # rank outputs using the element-wise sum of the average protein and ligand
            # plDDTs if ligand plDDT is available (list concatenation would instead
            # produce a ranking array of the wrong length)
            struct_plddts = (
                np.array(plddt_all) + np.array(plddt_lig_all)
                if plddt_lig_predicted
                else np.array(plddt_all)
            )
            if not plddt_lig_predicted:
                log.warning(
                    "Ligand plDDT not available for all samples, using protein plDDT instead"
                )
        struct_plddt_rankings = np.argsort(
            -struct_plddts
        ).argsort()  # ensure that higher plDDTs have a higher rank (e.g., `rank1`)
    receptor_plddt = np.array(res_plddt_all) if confidence else None
    # Broadcast per-residue plDDTs across the atom dimension for use as PDB B-factors.
    b_factors = (
        np.repeat(
            receptor_plddt[..., None],
            struct_res_all[0]["structure_module"]["final_atom_mask"].shape[-1],
            axis=-1,
        )
        if confidence
        else None
    )
    if save_pdb:
        if separate_pdb:
            for struct_id, struct_res in enumerate(struct_res_all):
                if confidence and cfg.rank_outputs_by_confidence:
                    write_pdb_single(
                        struct_res,
                        out_path=os.path.join(
                            out_path,
                            f"prot_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
                        ),
                        b_factors=b_factors[struct_id] if confidence else None,
                    )
                else:
                    write_pdb_single(
                        struct_res,
                        out_path=os.path.join(
                            out_path,
                            f"prot_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.pdb",
                        ),
                        b_factors=b_factors[struct_id] if confidence else None,
                    )
        write_pdb_models(
            struct_res_all, out_path=os.path.join(out_path, "prot_all.pdb"), b_factors=b_factors
        )
    if mol is not None:
        write_conformer_sdf(ref_mol, None, out_path=os.path.join(out_path, "lig_ref.sdf"))
        lig_res_all = np.array(lig_res_all)
        write_conformer_sdf(mol, lig_res_all, out_path=os.path.join(out_path, "lig_all.sdf"))
        for struct_id in range(len(lig_res_all)):
            if confidence and cfg.rank_outputs_by_confidence:
                write_conformer_sdf(
                    mol,
                    lig_res_all[struct_id : struct_id + 1],
                    out_path=os.path.join(
                        out_path,
                        f"lig_rank{struct_plddt_rankings[struct_id] + 1}_plddt{struct_plddts[struct_id]:.7f}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
                    ),
                )
            else:
                write_conformer_sdf(
                    mol,
                    lig_res_all[struct_id : struct_id + 1],
                    out_path=os.path.join(
                        out_path,
                        f"lig_{struct_id}{f'_affinity{affinity_all[struct_id]:.7f}' if affinity else ''}.sdf",
                    ),
                )
        if confidence:
            aux_estimation_all_df = pd.DataFrame(
                {
                    "sample_id": [sample_id] * len(struct_res_all),
                    "rank": struct_plddt_rankings + 1 if cfg.rank_outputs_by_confidence else None,
                    "plddt_ligs": plddt_ligs_all,
                    "affinity_ligs": ligs_affinity_all,
                }
            )
            aux_estimation_all_df.to_csv(
                os.path.join(
                    out_path,
                    f"{sample_id if sample_id is not None else 'sample'}_auxiliary_estimation.csv",
                ),
                index=False,
            )
    else:
        ref_mol = None
    if not confidence:
        plddt_all, plddt_lig_all, plddt_ligs_all = None, None, None
    if not affinity:
        affinity_all = None
    if return_all_states and mol is not None:
        np_sample["metadata"]["sample_ID"] = sample_id if sample_id is not None else "sample"
        np_sample["metadata"]["mol"] = ref_mol
        batch_all = inplace_to_torch(
            collate_numpy_samples([np_sample for _ in range(cfg.n_samples)])
        )
        # Concatenate per-chunk trajectory frames into a single trajectory.
        merge_frames_all = frames_all[0]
        for frames in frames_all[1:]:
            for frame_index, frame in enumerate(frames):
                for key in frame.keys():
                    merge_frames_all[frame_index][key] = torch.cat(
                        [merge_frames_all[frame_index][key], frame[key]], dim=0
                    )
        frames_all = merge_frames_all
    else:
        # Covers both `return_all_states=False` and the protein-only (no-ligand) case,
        # guaranteeing `batch_all` is always defined before the return below.
        frames_all = None
        batch_all = None
    if not (confidence and cfg.rank_outputs_by_confidence):
        struct_plddt_rankings = None
    return (
        ref_mol,
        plddt_all,
        plddt_lig_all,
        plddt_ligs_all,
        affinity_all,
        frames_all,
        batch_all,
        b_factors,
        struct_plddt_rankings,
    )
|
||||
153
flowdock/utils/utils.py
Normal file
153
flowdock/utils/utils.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import warnings
|
||||
from importlib.util import find_spec
|
||||
|
||||
import rootutils
|
||||
from beartype.typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from omegaconf import DictConfig
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.utils import pylogger, rich_utils
|
||||
|
||||
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    extras_cfg = cfg.get("extras")

    # nothing to apply when no `extras` config is present
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if extras_cfg.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if extras_cfg.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if extras_cfg.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
||||
|
||||
|
||||
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
        ```
        @utils.task_wrapper
        def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            ...
            return metric_dict, object_dict
        ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            # NOTE: a bare `raise` re-raises the active exception with its original
            # traceback intact (a `raise ex` would append an extra frame here)
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
|
||||
|
||||
|
||||
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieve the value of a metric logged in a LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric.
    """
    # a missing/empty name means the caller did not request any metric
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    metric_is_absent = metric_name not in metric_dict
    if metric_is_absent:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    # the stored value is a tensor-like object; `.item()` unwraps it to a Python scalar
    retrieved_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={retrieved_value}>")

    return retrieved_value
|
||||
|
||||
|
||||
def read_strings_from_txt(path: str) -> List[str]:
    """Read a text file and return its lines as a list of strings.

    :param path: Path to the text file.
    :return: List of strings, one per line, with trailing whitespace removed.
    """
    with open(path) as handle:
        # NOTE: iterate the file lazily; each line becomes one list element
        return [entry.rstrip() for entry in handle]
|
||||
|
||||
|
||||
def fasta_to_dict(filename: str) -> Dict[str, str]:
    """Converts a FASTA file to a dictionary where the keys are the sequence IDs and the values are
    the sequences.

    Multi-line sequences are concatenated under their most recent header.

    :param filename: Path to the FASTA file.
    :return: Dictionary with sequence IDs as keys and sequences as values.
    :raises ValueError: If sequence data appears before the first `>` header line.
    """
    fasta_dict = {}
    seq_id = None  # ID of the most recently seen `>` header
    with open(filename) as file:
        for line in file:
            line = line.rstrip()  # remove trailing whitespace
            if line.startswith(">"):  # identifier line
                seq_id = line[1:]  # remove the '>' character
                fasta_dict[seq_id] = ""
            elif seq_id is None:
                # previously this path crashed with an opaque `NameError`;
                # fail loudly with a descriptive message instead
                raise ValueError(
                    f"Malformed FASTA file {filename!r}: sequence data found before the first '>' header line."
                )
            else:  # sequence line
                fasta_dict[seq_id] += line
    return fasta_dict
|
||||
365
flowdock/utils/visualization_utils.py
Normal file
365
flowdock/utils/visualization_utils.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import rootutils
|
||||
from beartype import beartype
|
||||
from beartype.typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
from openfold.np.protein import Protein as OFProtein
|
||||
from rdkit import Chem
|
||||
from rdkit.Geometry.rdGeometry import Point3D
|
||||
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
from flowdock.data.components import residue_constants
|
||||
from flowdock.utils.data_utils import (
|
||||
PDB_CHAIN_IDS,
|
||||
PDB_MAX_CHAINS,
|
||||
FDProtein,
|
||||
create_full_prot,
|
||||
get_mol_with_new_conformer_coords,
|
||||
)
|
||||
|
||||
FeatureDict = Mapping[str, np.ndarray]
|
||||
ModelOutput = Mapping[str, Any] # Is a nested dict.
|
||||
PROT_LIG_PAIRS = List[Tuple[OFProtein, Tuple[Chem.Mol, ...]]]
|
||||
|
||||
|
||||
@beartype
def _chain_end(
    atom_index: Union[int, np.int64],
    end_resname: str,
    chain_name: str,
    residue_index: Union[int, np.int64],
) -> str:
    """Build a PDB `TER` record marking the end of a chain.

    Adapted from: https://github.com/jasonkyuyim/se3_diffusion

    :param atom_index: The index of the last atom in the chain.
    :param end_resname: The residue name of the last residue in the chain.
    :param chain_name: The chain name of the last residue in the chain.
    :param residue_index: The residue index of the last residue in the chain.
    :return: A PDB `TER` record.
    """
    # NOTE: PDB is a fixed-column format; the field widths and literal spaces
    # below place each field in its expected columns
    return (
        f"{'TER':<6}{atom_index:>5}      "
        f"{end_resname:>3} {chain_name:>1}{residue_index:>4}"
    )
|
||||
|
||||
|
||||
@beartype
def res_1to3(restypes: List[str], r: Union[int, np.int64]) -> str:
    """Map a residue type index to its 3-letter residue name.

    :param restypes: List of 1-letter residue type codes.
    :param r: Residue type index into `restypes`.
    :return: The 3-letter code as a string, or `"UNK"` when the 1-letter code is unknown.
    """
    one_letter_code = restypes[r]
    return residue_constants.restype_1to3.get(one_letter_code, "UNK")
|
||||
|
||||
|
||||
@beartype
def to_pdb(prot: Union[OFProtein, FDProtein], model=1, add_end=True, add_endmdl=True) -> str:
    """Converts a `Protein` instance to a PDB string.

    Adapted from: https://github.com/jasonkyuyim/se3_diffusion

    :param prot: The protein to convert to PDB.
    :param model: The model number to use.
    :param add_end: Whether to add an `END` record.
    :param add_endmdl: Whether to add an `ENDMDL` record.
    :return: PDB string.
    """
    # residue alphabet plus `X` so out-of-range indices map to "UNK" via `res_1to3`
    restypes = residue_constants.restypes + ["X"]
    atom_types = residue_constants.atom_types

    pdb_lines = []

    # per-residue arrays pulled off the protein object
    # NOTE(review): atom arrays appear to be indexed by a fixed per-residue atom
    # layout matching `residue_constants.atom_types` — confirm against `Protein`
    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(int)
    chain_index = prot.chain_index.astype(int)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # construct a mapping from chain integer indices to chain ID strings
    chain_ids = {}
    for i in np.unique(chain_index):  # NOTE: `np.unique` gives sorted output
        if i >= PDB_MAX_CHAINS:
            raise ValueError(f"The PDB format supports at most {PDB_MAX_CHAINS} chains.")
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append(f"MODEL {model}")
    # PDB atom serial numbers are 1-based
    atom_index = 1
    last_chain_index = chain_index[0]
    # add all atom sites
    for i in range(aatype.shape[0]):
        # close the previous chain if in a multichain PDB
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(restypes, aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # NOTE: atom index increases at the `TER` symbol

        res_name_3 = res_1to3(restypes, aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            # skip atoms masked out for this residue type
            if mask < 0.5:
                continue

            record_type = "ATOM"
            # 4-character atom names start at column 13; shorter names are shifted right by one
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # NOTE: `Protein` supports only C, N, O, S, this works
            charge = ""
            # NOTE: PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1} "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f} "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    # close the final chain
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(restypes, aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        )
    )
    if add_endmdl:
        pdb_lines.append("ENDMDL")
    if add_end:
        pdb_lines.append("END")

    # pad all lines to 80 characters
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return "\n".join(pdb_lines) + "\n"  # add terminating newline
|
||||
|
||||
|
||||
@beartype
def construct_prot_lig_pairs(outputs: Dict[str, Any], batch_index: int) -> PROT_LIG_PAIRS:
    """Construct protein-ligand pairs from model outputs.

    Builds one `(protein, ligands)` pair per entry in
    `outputs["protein_coordinates_list"]` / `outputs["ligand_coordinates_list"]`,
    then (when ground-truth coordinates are present) appends one final pair
    assembled from the ground-truth coordinates.

    :param outputs: The model outputs.
    :param batch_index: The index of the current batch.
    :return: A list of protein-ligand object pairs.
    """
    # indexers mapping each residue/atom row to its batch element
    protein_batch_indexer = outputs["protein_batch_indexer"]
    ligand_batch_indexer = outputs["ligand_batch_indexer"]

    # select the rows belonging to this batch element
    protein_all_atom_mask = outputs["res_atom_mask"][protein_batch_indexer == batch_index]
    # broadcast the per-atom mask to xyz coordinates
    # NOTE(review): the hard-coded 37 implies a 37-slot per-residue atom layout — confirm
    protein_all_atom_coordinates_mask = np.broadcast_to(
        np.expand_dims(protein_all_atom_mask, -1), (protein_all_atom_mask.shape[0], 37, 3)
    )
    protein_aatype = outputs["aatype"][protein_batch_indexer == batch_index]

    # assemble predicted structures
    prot_lig_pairs = []
    for protein_coordinates, ligand_coordinates in zip(
        outputs["protein_coordinates_list"], outputs["ligand_coordinates_list"]
    ):
        # zero out coordinates of atoms that are masked for their residue type
        protein_all_atom_coordinates = (
            protein_coordinates[protein_batch_indexer == batch_index]
            * protein_all_atom_coordinates_mask
        )
        protein = create_full_prot(
            protein_all_atom_coordinates,
            protein_all_atom_mask,
            protein_aatype,
            b_factors=outputs["b_factors"][batch_index] if "b_factors" in outputs else None,
        )
        ligand = get_mol_with_new_conformer_coords(
            outputs["ligand_mol"][batch_index],
            ligand_coordinates[ligand_batch_indexer == batch_index],
        )
        # split a multi-fragment ligand molecule into its individual fragments
        ligands = tuple(Chem.GetMolFrags(ligand, asMols=True, sanitizeFrags=False))
        prot_lig_pairs.append((protein, ligands))

    # assemble ground-truth structures
    if "gt_protein_coordinates" in outputs and "gt_ligand_coordinates" in outputs:
        protein_gt_all_atom_coordinates = (
            outputs["gt_protein_coordinates"][protein_batch_indexer == batch_index]
            * protein_all_atom_coordinates_mask
        )
        gt_protein = create_full_prot(
            protein_gt_all_atom_coordinates,
            protein_all_atom_mask,
            protein_aatype,
        )
        gt_ligand = get_mol_with_new_conformer_coords(
            outputs["ligand_mol"][batch_index],
            outputs["gt_ligand_coordinates"][ligand_batch_indexer == batch_index],
        )
        gt_ligands = tuple(Chem.GetMolFrags(gt_ligand, asMols=True, sanitizeFrags=False))
        # NOTE: the ground-truth pair is appended *after* all predicted pairs
        prot_lig_pairs.append((gt_protein, gt_ligands))

    return prot_lig_pairs
|
||||
|
||||
|
||||
@beartype
def write_prot_lig_pairs_to_pdb_file(prot_lig_pairs: PROT_LIG_PAIRS, output_filepath: str):
    """Write a list of protein-ligand pairs to a multi-model PDB file.

    Each pair becomes one PDB `MODEL`: the protein chains are written first,
    followed by each ligand fragment as a `TER`-separated chain.

    :param prot_lig_pairs: List of protein-ligand object pairs, where each ligand may consist of
        multiple ligand chains.
    :param output_filepath: Output file path.
    """
    # guard: `os.path.dirname` returns "" for a bare filename, and
    # `os.makedirs("")` raises `FileNotFoundError`
    output_dir = os.path.dirname(output_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_filepath, "w") as f:
        for model_id, (prot, lig_mols) in enumerate(prot_lig_pairs, start=1):
            pdb_prot = to_pdb(prot, model=model_id, add_end=False, add_endmdl=False)
            f.write(pdb_prot)
            for lig_mol in lig_mols:
                f.write(
                    Chem.MolToPDBBlock(lig_mol).replace(
                        "END\n", "TER\n"
                    )  # enable proper ligand chain separation
                )
            f.write("END\n")
            f.write("ENDMDL\n")  # add `ENDMDL` line to separate models
|
||||
|
||||
|
||||
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = False,
) -> FDProtein:
    """Assembles a protein from a prediction.

    Args:
        features: Dictionary holding model inputs.
        result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.
        remove_leading_feature_dimension: Whether to remove the leading dimension
            of the `features` values.

    Returns:
        A protein instance.
    """
    fold_output = result["structure_module"]

    def _squeeze(arr: np.ndarray) -> np.ndarray:
        # optionally drop the leading (batch) dimension of a feature array
        return arr[0] if remove_leading_feature_dimension else arr

    aatype = _squeeze(features["aatype"])

    # fall back to a single implicit chain when no asym IDs were provided
    if "asym_id" in features:
        chain_index = _squeeze(features["asym_id"])
    else:
        chain_index = np.zeros_like(aatype)

    atom_mask = fold_output["final_atom_mask"]
    if b_factors is None:
        # default to all-zero B-factors matching the atom mask's shape
        b_factors = np.zeros_like(atom_mask)

    return FDProtein(
        letter_sequences=None,
        aatype=aatype,
        atom_positions=fold_output["final_atom_positions"],
        atom_mask=atom_mask,
        residue_index=_squeeze(features["residue_index"]),
        chain_index=chain_index,
        b_factors=b_factors,
        atomtypes=None,
    )
|
||||
|
||||
|
||||
def write_pdb_single(
    result: ModelOutput,
    out_path: str = os.path.join("test_results", "debug.pdb"),
    model: int = 1,
    b_factors: Optional[np.ndarray] = None,
):
    """Write a single model to a PDB file.

    :param result: Model results batch (must contain a `features` entry).
    :param out_path: Output path.
    :param model: Model ID.
    :param b_factors: Optional B-factors.
    """
    # guard: `os.path.dirname` returns "" for a bare filename, and
    # `os.makedirs("")` raises `FileNotFoundError`
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    protein = from_prediction(result["features"], result, b_factors=b_factors)
    out_string = to_pdb(protein, model=model)
    with open(out_path, "w") as of:
        of.write(out_string)
|
||||
|
||||
|
||||
def write_pdb_models(
    results,
    out_path: str = os.path.join("test_results", "debug.pdb"),
    b_factors: Optional[np.ndarray] = None,
):
    """Write multiple models to a PDB file.

    :param results: Iterable of model result batches (each must contain `features`).
    :param out_path: Output path.
    :param b_factors: Optional B-factors, indexed in parallel with `results`.
    """
    # guard: `os.path.dirname` returns "" for a bare filename, and
    # `os.makedirs("")` raises `FileNotFoundError`
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as of:
        for mid, result in enumerate(results):
            protein = from_prediction(
                result["features"],
                result,
                b_factors=b_factors[mid] if b_factors is not None else None,
            )
            # NOTE(review): `to_pdb` defaults add `ENDMDL`/`END` per model, so the
            # trailing `END` below yields duplicate `END` records — confirm
            # downstream consumers tolerate this before changing it
            out_string = to_pdb(protein, model=mid + 1)
            of.write(out_string)
        of.write("END")
|
||||
|
||||
|
||||
def write_conformer_sdf(
    mol: Chem.Mol,
    confs: Optional[np.ndarray] = None,  # fixed: `np.array` is a function, not a type
    out_path: str = os.path.join("test_results", "debug.sdf"),
):
    """Write a molecule with conformers to an SDF file.

    :param mol: RDKit molecule. NOTE: when `confs` is given, the molecule's
        existing conformers are removed in place (i.e., `mol` is mutated).
    :param confs: Optional conformer coordinates of shape `(num_confs, num_atoms, 3)`.
    :param out_path: Output path.
    :return: 0 on success.
    """
    # guard: `os.path.dirname` returns "" for a bare filename, and
    # `os.makedirs("")` raises `FileNotFoundError`
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if confs is None:
        # no replacement coordinates: write the molecule as-is
        w = Chem.SDWriter(out_path)
        w.write(mol)
        w.close()
        return 0

    # install the provided coordinates as fresh conformers
    mol.RemoveAllConformers()
    for conf_coords in confs:
        conf = Chem.Conformer(mol.GetNumAtoms())
        for j in range(mol.GetNumAtoms()):
            x, y, z = conf_coords[j].tolist()
            conf.SetAtomPosition(j, Point3D(x, y, z))
        mol.AddConformer(conf, assignId=True)

    w = Chem.SDWriter(out_path)
    try:
        for cid in range(len(confs)):
            w.write(mol, confId=cid)
    except Exception:
        # kekulization can fail for some molecules; previously the fallback
        # reused the same open writer and duplicated conformers already written
        # before the failure — restart the file with kekulization disabled
        w.close()
        w = Chem.SDWriter(out_path)
        w.SetKekulize(False)
        for cid in range(len(confs)):
            w.write(mol, confId=cid)
    w.close()
    return 0
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user