Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
46
rf2aa/data/nucleic_acid.py
Normal file
46
rf2aa/data/nucleic_acid.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
|
||||
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.util import get_protein_bond_feats
|
||||
|
||||
def load_nucleic_acid(fasta_fn, input_type, model_runner):
    """Build the raw model inputs for a single DNA or RNA chain.

    Args:
        fasta_fn: Path handed to ``parse_multichain_fasta``; it must describe
            exactly one chain.
        input_type: Either ``"dna"`` or ``"rna"`` (case-sensitive). Selects
            which alphabet flag is passed to the parser.
        model_runner: Object exposing ``config.loader_params``, a mapping
            that provides at least the ``"MAXSEQ"`` and ``"n_templ"`` keys.

    Returns:
        A ``RawInputData`` built from the (possibly subsampled) MSA and
        insertion arrays, bond features, blank-template tensors, zero-row
        ``chirals`` / ``atom_frames`` tensors, and ``taxids=None``.

    Raises:
        ValueError: If ``input_type`` is not ``"dna"`` or ``"rna"``, or if
            the parsed file contains more than one chain.
    """
    if input_type not in ["dna", "rna"]:
        raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
    dna_alphabet = input_type == "dna"
    rna_alphabet = input_type == "rna"

    loader_params = model_runner.config.loader_params
    msa, ins, L = parse_multichain_fasta(
        fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet
    )

    # Reject multi-chain input before doing any sampling work.
    # presumably L holds one length per parsed chain -- confirm against
    # parse_multichain_fasta.
    if len(L) > 1:
        raise ValueError("Please provide separate fasta files for each nucleic acid chain")
    L = L[0]

    max_seq = loader_params["MAXSEQ"]
    n_seq = msa.shape[0]
    if n_seq > max_seq:
        # Always keep row 0 and draw the remaining rows from 1..n_seq-1
        # without replacement. Permuting all rows and then overwriting slot 0
        # with index 0 could select row 0 twice and drop a sampled sequence.
        n_extra = max(max_seq - 1, 0)
        sampled = 1 + np.random.permutation(n_seq - 1)[:n_extra]
        idxs_tokeep = np.concatenate(([0], sampled))
        msa = msa[idxs_tokeep]
        ins = ins[idxs_tokeep]

    # The fourth value returned by blank_template is not used here.
    xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)

    # NOTE(review): the protein bond-feature helper is applied to a nucleic
    # acid chain here -- confirm this is the intended connectivity.
    bond_feats = get_protein_bond_feats(L)

    # Zero-row tensors: this loader supplies no chiral or atom-frame entries.
    chirals = torch.zeros(0, 5)
    atom_frames = torch.zeros(0, 3, 2)

    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None,
    )
|
||||
Reference in New Issue
Block a user