Initial commit: FlowDock pipeline configured for WES execution
Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled

This commit is contained in:
2026-03-16 15:23:29 +01:00
commit a3ffec6a07
116 changed files with 16139 additions and 0 deletions

1
configs/__init__.py Normal file
View File

@@ -0,0 +1 @@
# this file is needed here to include configs when building project as a package

View File

@@ -0,0 +1,21 @@
# Default callback collection, composed from the individual callback configs
# in this directory via Hydra's defaults list.
defaults:
  - ema
  - last_model_checkpoint
  - learning_rate_monitor
  - model_checkpoint
  - model_summary
  - rich_progress_bar
  - _self_

# overrides applied on top of `last_model_checkpoint.yaml`
last_model_checkpoint:
  dirpath: ${paths.output_dir}/checkpoints
  filename: "last"
  monitor: null # no metric is monitored; this checkpoint simply tracks the most recent weights
  verbose: True
  auto_insert_metric_name: False
  every_n_epochs: 1 # save at the end of every epoch
  save_on_train_epoch_end: True
  enable_version_counter: False # always overwrite the last checkpoint instead of versioning it

# overrides applied on top of `model_summary.yaml`
model_summary:
  max_depth: -1 # -1 removes the nesting-depth limit, summarizing all layers

View File

@@ -0,0 +1,15 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
early_stopping:
_target_: lightning.pytorch.callbacks.EarlyStopping
monitor: ??? # quantity to be monitored, must be specified !!!
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
patience: 3 # number of checks with no improvement after which training will be stopped
verbose: False # verbosity mode
mode: "min" # "max" means higher metric value is better, can be also "min"
strict: True # whether to crash the training if monitor is not found in the validation metrics
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
# log_rank_zero_only: False # this keyword argument isn't available in stable version

View File

@@ -0,0 +1,10 @@
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
# Maintains an exponential moving average (EMA) of model weights.
# Look at the above link for more detailed information regarding the original implementation.
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
# Maintains an exponential moving average (EMA) of model weights.
# Look at the above link for more detailed information regarding the original implementation.
ema:
  _target_: flowdock.models.components.callbacks.ema.EMA
  decay: 0.999 # smoothing factor: ema = decay * ema + (1 - decay) * current weights
  validate_original_weights: false # presumably validation then runs with the EMA weights — see linked NeMo implementation to confirm
  every_n_steps: 4 # update the moving average once every 4 training steps
  cpu_offload: false # keep EMA weights on the training device rather than offloading them to CPU

View File

@@ -0,0 +1,21 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
last_model_checkpoint:
  # NOTE: this is a direct copy of `model_checkpoint`, which is
  # necessary to work around the key-duplication limitations
  # of YAML config files
  _target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
  dirpath: null # directory to save the model file
  filename: null # checkpoint filename
  monitor: null # name of the logged metric which determines when model is improving
  verbose: False # verbosity mode
  save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # save k best models (determined by above metric)
  mode: "min" # "max" means higher metric value is better, can be also "min"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model's weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: null # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
  enable_version_counter: True # enables versioning for checkpoint names

View File

@@ -0,0 +1,7 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html
learning_rate_monitor:
_target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: null # set to `epoch` or `step` to log learning rate of all optimizers at the same interval, or set to `null` to log at individual interval according to the interval key of each scheduler
log_momentum: false # whether to also log the momentum values of the optimizer, if the optimizer has the `momentum` or `betas` attribute
log_weight_decay: false # whether to also log the weight decay values of the optimizer, if the optimizer has the `weight_decay` attribute

View File

@@ -0,0 +1,18 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
model_checkpoint:
  _target_: flowdock.models.components.callbacks.ema.EMAModelCheckpoint
  dirpath: null # directory to save the model file
  filename: "best" # checkpoint filename
  monitor: val_sampling/ligand_hit_score_2A_epoch # name of the logged metric which determines when model is improving
  verbose: True # verbosity mode
  save_last: False # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # save k best models (determined by above metric)
  mode: "max" # "max" means higher metric value is better, can be also "min"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model's weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: null # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
  enable_version_counter: False # enables versioning for checkpoint names

View File

@@ -0,0 +1,5 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
model_summary:
_target_: lightning.pytorch.callbacks.RichModelSummary
max_depth: 1 # the maximum depth of layer nesting that the summary will include

View File

View File

@@ -0,0 +1,4 @@
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
rich_progress_bar:
_target_: lightning.pytorch.callbacks.RichProgressBar

View File

@@ -0,0 +1,35 @@
# @package _global_
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
# disable callbacks and loggers during debugging
callbacks: null
logger: null
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use this to also set hydra loggers to 'DEBUG'
# verbose: True
trainer:
max_epochs: 1
accelerator: cpu # debuggers don't like gpus
devices: 1 # debuggers don't like multiprocessing
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
data:
num_workers: 0 # debuggers don't like multiprocessing
pin_memory: False # disable gpu memory pin

9
configs/debug/fdr.yaml Normal file
View File

@@ -0,0 +1,9 @@
# @package _global_
# runs 1 train, 1 validation and 1 test step
defaults:
- default
trainer:
fast_dev_run: true

12
configs/debug/limit.yaml Normal file
View File

@@ -0,0 +1,12 @@
# @package _global_
# uses only 1% of the training data and 5% of validation/test data
defaults:
- default
trainer:
max_epochs: 3
limit_train_batches: 0.01
limit_val_batches: 0.05
limit_test_batches: 0.05

View File

@@ -0,0 +1,13 @@
# @package _global_
# overfits to 3 batches
defaults:
- default
trainer:
max_epochs: 20
overfit_batches: 3
# model ckpt and early stopping need to be disabled during overfitting
callbacks: null

View File

@@ -0,0 +1,12 @@
# @package _global_
# runs with execution time profiling
defaults:
- default
trainer:
max_epochs: 1
profiler: "simple"
# profiler: "advanced"
# profiler: "pytorch"

View File

@@ -0,0 +1,2 @@
defaults:
- _self_

View File

@@ -0,0 +1 @@
_target_: lightning.fabric.plugins.environments.LightningEnvironment

View File

@@ -0,0 +1,3 @@
_target_: lightning.fabric.plugins.environments.SLURMEnvironment
auto_requeue: true
requeue_signal: null

49
configs/eval.yaml Normal file
View File

@@ -0,0 +1,49 @@
# @package _global_
defaults:
- data: combined # choose datamodule with `test_dataloader()` for evaluation
- model: flowdock_fm
- logger: null
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
- _self_
task_name: "eval"
tags: ["eval", "combined", "flowdock_fm"]
# passing checkpoint path is necessary for evaluation
ckpt_path: ???
# seed for random number generators in pytorch, numpy and python.random
seed: null
# model arguments
model:
cfg:
mol_encoder:
from_pretrained: false
protein_encoder:
from_pretrained: false
relational_reasoning:
from_pretrained: false
contact_predictor:
from_pretrained: false
score_head:
from_pretrained: false
confidence:
from_pretrained: false
affinity:
from_pretrained: false
task:
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: false
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false

View File

@@ -0,0 +1,35 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=flowdock_fm
defaults:
- override /data: combined
- override /model: flowdock_fm
- override /callbacks: default
- override /trainer: default
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
tags: ["flowdock_fm", "combined_dataset"]
seed: 496
trainer:
max_epochs: 300
check_val_every_n_epoch: 5 # NOTE: we increase this since validation steps involve full model sampling and evaluation
reload_dataloaders_every_n_epochs: 1
model:
optimizer:
lr: 2e-4
compile: false
data:
batch_size: 16
logger:
wandb:
tags: ${tags}
group: "FlowDock-FM"

View File

@@ -0,0 +1,8 @@
# disable python warnings if they annoy you
ignore_warnings: False
# ask user for tags if none are provided in the config
enforce_tags: True
# pretty print config tree at the start of the run using Rich library
print_config: True

View File

@@ -0,0 +1,50 @@
# @package _global_
# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example
defaults:
- override /hydra/sweeper: optuna
# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/loss"
# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
# storage URL to persist optimization results
# for example, you can use SQLite if you set 'sqlite:///example.db'
storage: null
# name of the study to persist optimization results
study_name: null
# number of parallel workers
n_jobs: 1
# 'minimize' or 'maximize' the objective
direction: minimize
# total number of runs that will be executed
n_trials: 20
# choose Optuna hyperparameter sampler
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
sampler:
_target_: optuna.samplers.TPESampler
seed: 1234
n_startup_trials: 10 # number of random sampling runs before optimization starts
    # define hyperparameter search space
    # NOTE(review): these entries look copied from a template example — verify that
    # `model.net.hidden_dim` actually exists in the `flowdock_fm` model config
    # before launching a sweep, otherwise Hydra will fail to override it
    params:
      model.optimizer.lr: interval(0.0001, 0.1)
      data.batch_size: choice(2, 4, 8, 16)
      model.net.hidden_dim: choice(64, 128, 256)

View File

@@ -0,0 +1,19 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging
defaults:
- override hydra_logging: colorlog
- override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
sweep:
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
subdir: ${hydra.job.num}
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/${task_name}.log

0
configs/local/.gitkeep Normal file
View File

28
configs/logger/aim.yaml Normal file
View File

@@ -0,0 +1,28 @@
# https://aimstack.io/
# example usage in lightning module:
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
# `aim up`
aim:
_target_: aim.pytorch_lightning.AimLogger
repo: ${paths.root_dir} # .aim folder will be created here
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
# aim allows to group runs under experiment name
experiment: null # any string, set to "default" if not specified
train_metric_prefix: "train/"
val_metric_prefix: "val/"
test_metric_prefix: "test/"
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
system_tracking_interval: 10 # set to null to disable system metrics tracking
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
log_system_params: true
# enable/disable tracking console logs (default value is true)
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550

12
configs/logger/comet.yaml Normal file
View File

@@ -0,0 +1,12 @@
# https://www.comet.ml
comet:
_target_: lightning.pytorch.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
save_dir: "${paths.output_dir}"
project_name: "FlowDock_FM"
rest_api_key: null
# experiment_name: ""
experiment_key: null # set to resume experiment
offline: False
prefix: ""

7
configs/logger/csv.yaml Normal file
View File

@@ -0,0 +1,7 @@
# csv logger built in lightning
csv:
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "${paths.output_dir}"
name: "csv/"
prefix: ""

View File

@@ -0,0 +1,9 @@
# train with many loggers at once
defaults:
# - comet
- csv
# - mlflow
# - neptune
- tensorboard
- wandb

View File

@@ -0,0 +1,12 @@
# https://mlflow.org
mlflow:
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
# experiment_name: ""
# run_name: ""
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
tags: null
# save_dir: "./mlruns"
prefix: ""
artifact_location: null
# run_id: ""

View File

@@ -0,0 +1,9 @@
# https://neptune.ai
neptune:
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
project: username/FlowDock_FM
# name: ""
log_model_checkpoints: True
prefix: ""

View File

@@ -0,0 +1,10 @@
# https://www.tensorflow.org/tensorboard/
tensorboard:
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
save_dir: "${paths.output_dir}/tensorboard/"
name: null
log_graph: False
default_hp_metric: True
prefix: ""
# version: ""

16
configs/logger/wandb.yaml Normal file
View File

@@ -0,0 +1,16 @@
# https://wandb.ai
wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "FlowDock_FM"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
entity: "bml-lab" # set to name of your wandb team
group: ""
tags: []
job_type: ""

View File

@@ -0,0 +1,148 @@
_target_: flowdock.models.flowdock_fm_module.FlowDockFMLitModule
net:
_target_: flowdock.models.components.flowdock.FlowDock
_partial_: true
optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 2e-4
weight_decay: 0.0
scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
_partial_: true
T_0: ${int_divide:${trainer.max_epochs},15}
T_mult: 2
eta_min: 1e-8
verbose: true
# compile model for faster training with pytorch 2.0
compile: false
# model arguments
cfg:
mol_encoder:
node_channels: 512
pair_channels: 64
n_atom_encodings: 23
n_bond_encodings: 4
n_atom_pos_encodings: 6
n_stereo_encodings: 14
n_attention_heads: 8
attention_head_dim: 8
hidden_dim: 2048
max_path_integral_length: 6
n_transformer_stacks: 8
n_heads: 8
n_patches: ${data.n_lig_patches}
checkpoint_file: ${oc.env:PROJECT_ROOT}/checkpoints/neuralplexermodels_downstream_datasets_predictions/models/complex_structure_prediction.ckpt
megamolbart: null
from_pretrained: true
protein_encoder:
use_esm_embedding: true
esm_version: esm2_t33_650M_UR50D
esm_repr_layer: 33
residue_dim: 512
plm_embed_dim: 1280
n_aa_types: 21
atom_padding_dim: 37
n_atom_types: 4 # [C, N, O, S]
n_patches: ${data.n_protein_patches}
n_attention_heads: 8
scalar_dim: 16
point_dim: 4
pair_dim: 64
n_heads: 8
head_dim: 8
max_residue_degree: 32
n_encoder_stacks: 2
from_pretrained: true
relational_reasoning:
from_pretrained: true
contact_predictor:
n_stacks: 4
dropout: 0.01
from_pretrained: true
score_head:
fiber_dim: 64
hidden_dim: 512
n_stacks: 4
max_atom_degree: 8
from_pretrained: true
confidence:
enabled: true # whether the confidence prediction head is to be used e.g., during inference
fiber_dim: ${..score_head.fiber_dim}
hidden_dim: ${..score_head.hidden_dim}
n_stacks: ${..score_head.n_stacks}
from_pretrained: true
affinity:
enabled: true # whether the affinity prediction head is to be used e.g., during inference
fiber_dim: ${..score_head.fiber_dim}
hidden_dim: ${..score_head.hidden_dim}
n_stacks: ${..score_head.n_stacks}
ligand_pooling: sum # NOTE: must be a value in (`sum`, `mean`)
dropout: 0.01
from_pretrained: false
latent_model: default
prior_type: esmfold # NOTE: must be a value in (`gaussian`, `harmonic`, `esmfold`)
task:
pretrained: null
ligands: true
epoch_frac: ${data.epoch_frac}
label_name: null
sequence_crop_size: 1600
edge_crop_size: ${data.edge_crop_size} # NOTE: for dynamic batching via `max_n_edges`
max_masking_rate: 0.0
n_modes: 8
dropout: 0.01
# pretraining: true
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: true
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false
use_template: true
use_plddt: false
block_contact_decoding_scheme: "beam"
frozen_ligand_backbone: false
frozen_protein_backbone: false
single_protein_batch: true
contact_loss_weight: 0.2
global_score_loss_weight: 0.2
ligand_score_loss_weight: 0.1
clash_loss_weight: 10.0
local_distgeom_loss_weight: 10.0
drmsd_loss_weight: 2.0
distogram_loss_weight: 0.05
plddt_loss_weight: 1.0
affinity_loss_weight: 0.1
aux_batch_freq: 10 # NOTE: e.g., `10` means that auxiliary estimation losses will be calculated every 10th batch
global_max_sigma: 5.0
internal_max_sigma: 2.0
detect_covalent: true
# runtime configs
float32_matmul_precision: highest
# sampling configs
constrained_inpainting: false
visualize_generated_samples: true
# testing configs
loss_mode: auxiliary_estimation # NOTE: must be one of (`structure_prediction`, `auxiliary_estimation`, `auxiliary_estimation_without_structure_prediction`)
num_steps: 20
sampler: VDODE # NOTE: must be one of (`ODE`, `VDODE`)
sampler_eta: 1.0 # NOTE: this corresponds to the variance diminishing factor for the `VDODE` sampler, which offers a trade-off between exploration (1.0) and exploitation (> 1.0)
start_time: 1.0
eval_structure_prediction: false # whether to evaluate structure prediction performance (`true`) or instead only binding affinity performance (`false`) when running a test epoch
# overfitting configs
overfitting_example_name: ${data.overfitting_example_name}

View File

@@ -0,0 +1,21 @@
# path to root directory
# this requires PROJECT_ROOT environment variable to exist
# you can replace it with "." if you want the root to be the current working directory
root_dir: ${oc.env:PROJECT_ROOT}
# path to data directory
data_dir: ${paths.root_dir}/data/
# path to logging directory
log_dir: ${paths.root_dir}/logs/
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# path to working directory
work_dir: ${hydra:runtime.cwd}
# path to the directory containing the model checkpoints
ckpt_dir: ${paths.root_dir}/checkpoints/

78
configs/sample.yaml Normal file
View File

@@ -0,0 +1,78 @@
# @package _global_
defaults:
- data: combined # NOTE: this will not be referenced during sampling
- model: flowdock_fm
- logger: null
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
- _self_
task_name: "sample"
tags: ["sample", "combined", "flowdock_fm"]
# passing checkpoint path is necessary for sampling
ckpt_path: ???
# seed for random number generators in pytorch, numpy and python.random
seed: null
# sampling arguments
sampling_task: batched_structure_sampling # NOTE: must be one of (`batched_structure_sampling`)
sample_id: null # optional identifier for the sampling run
input_receptor: null # NOTE: must be either a protein sequence string (with chains separated by `|`) or a path to a PDB file (from which protein chain sequences will be parsed)
input_ligand: null # NOTE: must be either a ligand SMILES string (with chains/fragments separated by `|`) or a path to a ligand SDF file (from which ligand SMILES will be parsed)
input_template: null # path to a protein PDB file to use as a starting protein template for sampling (with an ESMFold prior model)
out_path: ??? # path to which to save the output PDB and SDF files
n_samples: 5 # number of structures to sample
chunk_size: 5 # number of structures to concurrently sample within each batch segment - NOTE: `n_samples` should be evenly divisible by `chunk_size` to produce the expected number of outputs
num_steps: 40 # number of sampling steps to perform
latent_model: null # if provided, the type of latent model to use
sampler: VDODE # sampling algorithm to use - NOTE: must be one of (`ODE`, `VDODE`)
sampler_eta: 1.0 # the variance diminishing factor for the `VDODE` sampler - NOTE: offers a trade-off between exploration (1.0) and exploitation (> 1.0)
start_time: "1.0" # time at which to start sampling
max_chain_encoding_k: -1 # maximum number of chains to encode in the chain encoding
exact_prior: false # whether to use the "ground-truth" binding site for sampling, if available
prior_type: esmfold # the type of prior to use for sampling - NOTE: must be one of (`gaussian`, `harmonic`, `esmfold`)
discard_ligand: false # whether to discard a given input ligand during sampling
discard_sdf_coords: true # whether to discard the input ligand's 3D structure during sampling, if available
detect_covalent: false # whether to detect covalent bonds between the input receptor and ligand
use_template: true # whether to use the input protein template for sampling if one is provided
separate_pdb: true # whether to save separate PDB files for each sampled structure instead of simply a single PDB file
rank_outputs_by_confidence: true # whether to rank the sampled structures by estimated confidence
plddt_ranking_type: ligand # the type of plDDT ranking to apply to generated samples - NOTE: must be one of (`protein`, `ligand`, `protein_ligand`)
visualize_sample_trajectories: false # whether to visualize the generated samples' trajectories
auxiliary_estimation_only: false # whether to only estimate auxiliary outputs (e.g., confidence, affinity) for the input (generated) samples (potentially derived from external sources)
csv_path: null # if provided, the CSV file (with columns `id`, `input_receptor`, `input_ligand`, and `input_template`) from which to parse input receptors and ligands for sampling, overriding the `input_receptor` and `input_ligand` arguments in the process and ignoring the `input_template` for now
esmfold_chunk_size: null # chunks axial attention computation to reduce memory usage from O(L^2) to O(L); equivalent to running a for loop over chunks of each dimension; lower values will result in lower memory usage at the cost of speed; recommended values: 128, 64, 32
# model arguments
model:
cfg:
mol_encoder:
from_pretrained: false
protein_encoder:
from_pretrained: false
relational_reasoning:
from_pretrained: false
contact_predictor:
from_pretrained: false
score_head:
from_pretrained: false
confidence:
from_pretrained: false
affinity:
from_pretrained: false
task:
freeze_mol_encoder: true
freeze_protein_encoder: false
freeze_relational_reasoning: false
freeze_contact_predictor: false
freeze_score_head: false
freeze_confidence: true
freeze_affinity: false

View File

@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: false
gradient_as_bucket_view: false
find_unused_parameters: true

View File

@@ -0,0 +1,5 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: false
gradient_as_bucket_view: false
find_unused_parameters: true
start_method: spawn

View File

@@ -0,0 +1,5 @@
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
stage: 2
offload_optimizer: false
allgather_bucket_size: 200_000_000
reduce_bucket_size: 200_000_000

View File

@@ -0,0 +1,2 @@
defaults:
- _self_

View File

@@ -0,0 +1,12 @@
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
cpu_offload: null
activation_checkpointing: null
mixed_precision:
_target_: torch.distributed.fsdp.MixedPrecision
param_dtype: null
reduce_dtype: null
buffer_dtype: null
keep_low_precision_grads: false
cast_forward_inputs: false
cast_root_forward_inputs: true

View File

@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: true
gradient_as_bucket_view: true
find_unused_parameters: false

51
configs/train.yaml Normal file
View File

@@ -0,0 +1,51 @@
# @package _global_
# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: combined
- model: flowdock_fm
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# config for hyperparameter optimization
- hparams_search: null
# optional local config for machine/user specific settings
# it's optional since it doesn't need to exist and is excluded from version control
- optional local: default
  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null
# task name, determines output directory path
task_name: "train"
# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["train", "combined", "flowdock_fm"]
# set False to skip model training
train: True
# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: False
# simply provide checkpoint path to resume training
ckpt_path: null
# seed for random number generators in pytorch, numpy and python.random
seed: null

5
configs/trainer/cpu.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: cpu
devices: 1

9
configs/trainer/ddp.yaml Normal file
View File

@@ -0,0 +1,9 @@
defaults:
- default
strategy: ddp
accelerator: gpu
devices: 4
num_nodes: 1
sync_batchnorm: True

View File

@@ -0,0 +1,7 @@
defaults:
- default
# simulate DDP on CPU, useful for debugging
accelerator: cpu
devices: 2
strategy: ddp_spawn

View File

@@ -0,0 +1,9 @@
defaults:
- default
strategy: ddp_spawn
accelerator: gpu
devices: 4
num_nodes: 1
sync_batchnorm: True

View File

@@ -0,0 +1,29 @@
_target_: lightning.pytorch.trainer.Trainer
default_root_dir: ${paths.output_dir}
min_epochs: 1 # prevents early stopping
max_epochs: 10
accelerator: cpu
devices: 1
# mixed precision for extra speed-up
# precision: 16
# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
# set True to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
# determine the frequency of how often to reload the dataloaders
reload_dataloaders_every_n_epochs: 1
# if `gradient_clip_val` is not `null`, gradients will be norm-clipped during training
gradient_clip_algorithm: norm
gradient_clip_val: 1.0
# if `num_sanity_val_steps` is > 0, then specifically that many validation steps will be run during the first call to `trainer.fit`
num_sanity_val_steps: 0

5
configs/trainer/gpu.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: gpu
devices: 1

5
configs/trainer/mps.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- default
accelerator: mps
devices: 1