Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
93
rf2aa/data/protein.py
Normal file
93
rf2aa/data/protein.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
|
||||
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
|
||||
from rf2aa.data.preprocessing import make_msa
|
||||
from rf2aa.util import get_protein_bond_feats
|
||||
|
||||
|
||||
def get_templates(
    qlen,
    ffdb,
    hhr_fn,
    atab_fn,
    seqID_cut,
    n_templ,
    pick_top: bool = True,
    offset: int = 0,
    random_noise: float = 5.0,
    deterministic: bool = False,
):
    """Parse raw template hits and featurize them with ``TemplFeaturize``.

    Args:
        qlen: Forwarded to ``TemplFeaturize`` as its second positional
            argument.
        ffdb: Forwarded to ``parse_templates_raw`` as its first argument.
        hhr_fn: Forwarded to ``parse_templates_raw`` as ``hhr_fn``.
        atab_fn: Forwarded to ``parse_templates_raw`` as ``atab_fn``.
        seqID_cut: Stored under the ``"SEQID"`` key of the params dict
            handed to ``TemplFeaturize``.
        n_templ: Forwarded to ``TemplFeaturize`` as ``npick``.
        pick_top: Forwarded to ``TemplFeaturize`` unchanged.
        offset: Forwarded to ``TemplFeaturize`` unchanged.
        random_noise: Forwarded to ``TemplFeaturize`` unchanged.
        deterministic: Forwarded to ``TemplFeaturize`` unchanged.

    Returns:
        Whatever ``TemplFeaturize`` returns for the parsed templates.
    """
    # The parser must yield exactly seven items; unpacking enforces that.
    xyz, mask, qmap, f0d, f1d, seq, ids = parse_templates_raw(
        ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn
    )

    # Every parsed field except the ids goes through ``unsqueeze(0)``;
    # the ids are handed over untouched.
    template_record = dict(
        xyz=xyz.unsqueeze(0),
        mask=mask.unsqueeze(0),
        qmap=qmap.unsqueeze(0),
        f0d=f0d.unsqueeze(0),
        f1d=f1d.unsqueeze(0),
        seq=seq.unsqueeze(0),
        ids=ids,
    )
    featurize_params = dict(SEQID=seqID_cut)

    return TemplFeaturize(
        template_record,
        qlen,
        featurize_params,
        offset=offset,
        npick=n_templ,
        pick_top=pick_top,
        random_noise=random_noise,
        deterministic=deterministic,
    )
|
||||
|
||||
|
||||
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
    """Assemble a ``RawInputData`` for a protein from its MSA and templates.

    Args:
        msa_file: Passed to ``parse_a3m`` to obtain the MSA, the insertion
            array and the taxonomy ids.
        hhr_fn: Template search result file, or ``None``.
        atab_fn: Template alignment table file, or ``None``.
        model_runner: Supplies ``ffdb``, ``deterministic`` and
            ``config.loader_params`` (``seqid``, ``n_templ``) for the
            template lookup.

    Returns:
        A ``RawInputData`` built from the MSA, insertions, protein bond
        features, template features, and empty chirals / atom-frame tensors.
    """
    msa, ins, taxIDs = parse_a3m(msa_file)
    # NOTE: this next line is a bug, but is the way that
    # the code is written in the original implementation!
    ins[0] = msa[0]

    seq_len = msa.shape[1]

    # Real templates are used only when both template files are supplied;
    # otherwise a single blank template of the same length is used.
    have_templates = hhr_fn is not None and atab_fn is not None
    if have_templates:
        ffdb = model_runner.ffdb
        loader_params = model_runner.config.loader_params
        template_feats = get_templates(
            seq_len,
            ffdb,
            hhr_fn,
            atab_fn,
            seqID_cut=loader_params.seqid,
            n_templ=loader_params.n_templ,
            deterministic=model_runner.deterministic,
        )
    else:
        print("No templates provided")
        template_feats = blank_template(1, seq_len)
    xyz_t, t1d, mask_t, _ = template_feats

    bond_feats = get_protein_bond_feats(seq_len)
    # Zero-row placeholder tensors for chirals and atom frames.
    empty_chirals = torch.zeros(0, 5)
    empty_atom_frames = torch.zeros(0, 3, 2)

    msa_tensor = torch.from_numpy(msa)
    ins_tensor = torch.from_numpy(ins)
    return RawInputData(
        msa_tensor,
        ins_tensor,
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        empty_chirals,
        empty_atom_frames,
        taxids=taxIDs,
    )
|
||||
|
||||
def generate_msa_and_load_protein(fasta_file, chain, model_runner):
    """Generate the MSA/template files for one chain and load the protein.

    Calls ``make_msa`` to obtain the MSA, HHR and atab file locations, then
    forwards them to ``load_protein``.

    Args:
        fasta_file: Passed through unchanged to ``make_msa``.
        chain: Passed through unchanged to ``make_msa``.
        model_runner: Passed to both ``make_msa`` and ``load_protein``.

    Returns:
        The value returned by ``load_protein``.
    """
    msa_file, hhr_file, atab_file = make_msa(fasta_file, chain, model_runner)

    # load_protein picks its blank-template fallback with an ``is None``
    # check. ``str(None)`` is the non-None string "None", so only stringify
    # real values and let a missing template file stay None.
    # NOTE(review): make_msa is not visible here -- confirm whether it can
    # actually return None for the template files.
    hhr_fn = None if hhr_file is None else str(hhr_file)
    atab_fn = None if atab_file is None else str(atab_file)
    return load_protein(str(msa_file), hhr_fn, atab_fn, model_runner)
|
||||
Reference in New Issue
Block a user