Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
90
rf2aa/setup_model.py
Normal file
90
rf2aa/setup_model.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
import hydra
|
||||
import os
|
||||
|
||||
from rf2aa.training.EMA import EMA
|
||||
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||||
from rf2aa.util_module import XYZConverter
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
#TODO: control environment variables from config
|
||||
# limit thread counts
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||
|
||||
## To reproduce errors
|
||||
import random
|
||||
|
||||
def seed_all(seed=0):
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
torch.set_num_threads(4)
|
||||
#torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, config) -> None:
|
||||
self.config = config
|
||||
|
||||
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
|
||||
if self.config.experiment.output_dir is not None:
|
||||
self.output_dir = self.config.experiment.output_dir
|
||||
else:
|
||||
self.output_dir = "models/"
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
|
||||
def move_constants_to_device(self, gpu):
|
||||
self.fi_dev = ChemData().frame_indices.to(gpu)
|
||||
self.xyz_converter = XYZConverter().to(gpu)
|
||||
|
||||
self.l2a = ChemData().long2alt.to(gpu)
|
||||
self.aamask = ChemData().allatom_mask.to(gpu)
|
||||
self.num_bonds = ChemData().num_bonds.to(gpu)
|
||||
self.atom_type_index = ChemData().atom_type_index.to(gpu)
|
||||
self.ljlk_parameters = ChemData().ljlk_parameters.to(gpu)
|
||||
self.lj_correction_parameters = ChemData().lj_correction_parameters.to(gpu)
|
||||
self.hbtypes = ChemData().hbtypes.to(gpu)
|
||||
self.hbbaseatoms = ChemData().hbbaseatoms.to(gpu)
|
||||
self.hbpolys = ChemData().hbpolys.to(gpu)
|
||||
self.cb_len = ChemData().cb_length_t.to(gpu)
|
||||
self.cb_ang = ChemData().cb_angle_t.to(gpu)
|
||||
self.cb_tor = ChemData().cb_torsion_t.to(gpu)
|
||||
|
||||
class LegacyTrainer(Trainer):
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
def construct_model(self, device="cpu"):
|
||||
self.model = RoseTTAFoldModule(
|
||||
**self.config.legacy_model_param,
|
||||
aamask = ChemData().allatom_mask.to(device),
|
||||
atom_type_index = ChemData().atom_type_index.to(device),
|
||||
ljlk_parameters = ChemData().ljlk_parameters.to(device),
|
||||
lj_correction_parameters = ChemData().lj_correction_parameters.to(device),
|
||||
num_bonds = ChemData().num_bonds.to(device),
|
||||
cb_len = ChemData().cb_length_t.to(device),
|
||||
cb_ang = ChemData().cb_angle_t.to(device),
|
||||
cb_tor = ChemData().cb_torsion_t.to(device),
|
||||
|
||||
).to(device)
|
||||
if self.config.training_params.EMA is not None:
|
||||
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/train')
|
||||
def main(config):
|
||||
seed_all()
|
||||
trainer = trainer_factory[config.experiment.trainer](config=config)
|
||||
|
||||
trainer_factory = {
|
||||
"legacy": LegacyTrainer,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user