Initial commit: Chai-1 protein structure prediction pipeline for WES
- Nextflow pipeline using chai1 Docker image from Harbor - S3-based input/output paths (s3://omic/eureka/chai-lab/) - GPU-accelerated protein folding with MSA support Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
4
tests/__init__.py
Executable file
4
tests/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
37
tests/example_inputs.py
Executable file
37
tests/example_inputs.py
Executable file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
# Example SMILES strings covering small molecules, a pyranose sugar, sulfate,
# a conjugated polyene, two large metal-coordinated macrocycles (Mg and Fe),
# and monatomic ions.
example_ligands = [
    "C",
    "O",
    "C(C1C(C(C(C(O1)O)O)O)O)O",
    "[O-]S(=O)(=O)[O-]",
    "CC1=C(C(CCC1)(C)C)/C=C/C(=C/C=C/C(=C/C=O)/C)/C",
    "CCC1=C(c2cc3c(c(c4n3[Mg]56[n+]2c1cc7n5c8c(c9[n+]6c(c4)C(C9CCC(=O)OC/C=C(\C)/CCC[C@H](C)CCC[C@H](C)CCCC(C)C)C)[C@H](C(=O)c8c7C)C(=O)OC)C)C=C)C=O",
    r"C=CC1=C(C)/C2=C/c3c(C)c(CCC(=O)O)c4n3[Fe@TB16]35<-N2=C1/C=c1/c(C)c(C=C)/c(n13)=C/C1=N->5/C(=C\4)C(CCC(=O)O)=C1C",
    # different ions
    "[Mg+2]",
    "[Na+]",
    "[Cl-]",
]

# Example protein inputs: a plain one-letter sequence, a modified sequence
# using parenthesized multi-letter residue codes, and one containing
# unknown (X) residues.
example_proteins = [
    "AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVR",
    "(KCJ)(SEP)(PPN)(B3S)(BAL)(PPN)K(NH2)",
    "XDHPX",
]

# Example RNA sequences (A/C/G/U alphabet).
example_rna = [
    "AGUGGCUA",
    "AAAAAA",
    "AGUC",
]

# Example DNA sequences (A/C/G/T alphabet).
example_dna = [
    "AGTGGCTA",
    "AAAAAA",
    "AGTC",
]
|
||||
24
tests/test_cif_utils.py
Executable file
24
tests/test_cif_utils.py
Executable file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from chai_lab.data.io.cif_utils import get_chain_letter
|
||||
|
||||
|
||||
def test_get_chain_letter():
    """Chain letters are 1-indexed: A-Z, then a-z, then two-letter codes."""
    # Index 0 is invalid; numbering starts at 1.
    with pytest.raises(AssertionError):
        get_chain_letter(0)

    # Boundaries of the single-letter range.
    single_letter_cases = {1: "A", 26: "Z", 27: "a", 52: "z"}
    for index, expected in single_letter_cases.items():
        assert get_chain_letter(index) == expected

    # Two-letter codes start immediately after the 52 single-letter codes.
    assert get_chain_letter(53) == "AA"
    assert get_chain_letter(54) == "AB"

    # For one-letter codes, there are 26 + 26 = 52 codes
    # For two-letter codes, there are 52 * 52 codes
    assert get_chain_letter(52 * 52 + 52) == "zz"
|
||||
108
tests/test_glycans.py
Executable file
108
tests/test_glycans.py
Executable file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from chai_lab.chai1 import make_all_atom_feature_context
|
||||
from chai_lab.data.parsing.glycans import _glycan_string_to_sugars_and_bonds
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ccd_code", ["MAN", "99K", "FUC"])
def test_parsing_ccd_codes(ccd_code: str):
    """Test that various single CCD codes are parsed correctly."""
    # A bare CCD code is a glycan consisting of exactly one sugar.
    sugars, _bonds = _glycan_string_to_sugars_and_bonds(ccd_code)
    assert len(sugars) == 1
|
||||
|
||||
|
||||
def test_complex_parsing():
    """Parse a branched 5-sugar glycan and check each bond's endpoints.

    Rewritten table-driven to match the style of test_complex_parsing_2;
    zip(strict=True) preserves the original 4-bond length check that the
    bond1..bond4 tuple unpacking provided.
    """
    glycan = "MAN(6-1 FUC)(4-1 MAN(6-1 MAN(6-1 MAN)))".replace(" ", "")
    sugars, bonds = _glycan_string_to_sugars_and_bonds(glycan)
    assert len(sugars) == 5

    # (src_sugar_index, dst_sugar_index, src_atom, dst_atom) per bond,
    # in parse order.
    expected_bonds = [
        (0, 1, 6, 1),
        (0, 2, 4, 1),
        (2, 3, 6, 1),
        (3, 4, 6, 1),
    ]
    for (src, dst, src_atom, dst_atom), bond in zip(
        expected_bonds, bonds, strict=True
    ):
        assert bond.src_sugar_index == src
        assert bond.dst_sugar_index == dst
        assert bond.src_atom == src_atom
        assert bond.dst_atom == dst_atom
|
||||
|
||||
|
||||
def test_complex_parsing_2():
    """Parse a doubly-branched 9-sugar glycan and check its bond topology."""
    glycan_str = "MAN(4-1 FUC(4-1 MAN)(6-1 FUC(4-1 MAN)))(6-1 MAN(6-1 MAN(4-1 MAN)(6-1 FUC)))".replace(
        " ", ""
    )
    sugars, bonds = _glycan_string_to_sugars_and_bonds(glycan_str)
    assert len(sugars) == 9

    # (src, dst) sugar indices for each glycosidic bond, in parse order.
    expected_topology = [
        (0, 1),
        (1, 2),
        (1, 3),
        (3, 4),
        (0, 5),
        (5, 6),
        (6, 7),
        (6, 8),
    ]
    for (want_src, want_dst), bond in zip(expected_topology, bonds, strict=True):
        assert (bond.src_sugar_index, bond.dst_sugar_index) == (want_src, want_dst)
|
||||
|
||||
|
||||
def test_glycan_tokenization_with_bond():
    """Test that tokenization works, and that atoms are dropped as expected."""
    # Two NAG sugars joined by a single 4-1 glycosidic bond.
    glycan = ">glycan|foo\nNAG(4-1 NAG)\n"
    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        fasta_file = tmp_path / "input.fasta"
        fasta_file.write_text(glycan)

        output_dir = tmp_path / "out"

        feature_context = make_all_atom_feature_context(
            fasta_file,
            output_dir=output_dir,
            use_esm_embeddings=False,  # Just a test; no need
        )

        # Each NAG component is C8 H15 N O6 -> 8 + 1 + 6 = 15 heavy atoms
        # The bond between them displaces one oxygen, leaving 2 * 15 - 1 = 29 atoms
        assert feature_context.structure_context.atom_exists_mask.sum() == 29
        # We originally constructed all atoms and dropped the atoms that leave
        assert feature_context.structure_context.atom_exists_mask.numel() == 30
        # Tally elements over the atoms that still exist after the drop.
        elements = Counter(
            feature_context.structure_context.atom_ref_element[
                feature_context.structure_context.atom_exists_mask
            ].tolist()
        )
        assert elements[6] == 16  # 6 = Carbon
        assert elements[7] == 2  # 7 = Nitrogen
        assert elements[8] == 11  # 8 = Oxygen

        # Single bond feature between O and C
        left, right = feature_context.structure_context.atom_covalent_bond_indices
        assert left.numel() == right.numel() == 1
        bond_elements = set(
            [
                feature_context.structure_context.atom_ref_element[left].item(),
                feature_context.structure_context.atom_ref_element[right].item(),
            ]
        )
        assert bond_elements == {8, 6}
|
||||
106
tests/test_inference_dataset.py
Executable file
106
tests/test_inference_dataset.py
Executable file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
"""
|
||||
Tests for inference dataset.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
|
||||
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
|
||||
AllAtomResidueTokenizer,
|
||||
)
|
||||
from chai_lab.data.dataset.structure.all_atom_structure_context import (
|
||||
AllAtomStructureContext,
|
||||
)
|
||||
from chai_lab.data.dataset.structure.chain import Chain
|
||||
from chai_lab.data.parsing.structure.entity_type import EntityType
|
||||
from chai_lab.data.sources.rdkit import RefConformerGenerator
|
||||
|
||||
|
||||
@pytest.fixture
def tokenizer() -> AllAtomResidueTokenizer:
    """All-atom residue tokenizer backed by an RDKit reference-conformer generator."""
    return AllAtomResidueTokenizer(RefConformerGenerator())
|
||||
|
||||
|
||||
def test_malformed_smiles(tokenizer: AllAtomResidueTokenizer):
    """Malformed SMILES should be dropped."""
    # Zn ligand is malformed (should be [Zn+2])
    raw_inputs = [
        Input("RKDESES", entity_type=EntityType.PROTEIN.value, entity_name="foo"),
        Input("Zn", entity_type=EntityType.LIGAND.value, entity_name="bar"),
        Input("RKEEE", entity_type=EntityType.PROTEIN.value, entity_name="baz"),
        Input("EEEEEEEEEEEE", entity_type=EntityType.PROTEIN.value, entity_name="boz"),
    ]
    chains = load_chains_from_raw(
        raw_inputs,
        identifier="test",
        tokenizer=tokenizer,
    )
    # Only the malformed ligand is dropped; the three proteins survive.
    assert len(chains) == 3
    for surviving_chain in chains:
        # NOTE this check is only valid because there are no residues that are tokenized per-atom
        # Ensures that the entity data and the structure context in each chain are paired correctly
        num_residues = len(surviving_chain.entity_data.full_sequence)
        assert surviving_chain.structure_context.num_tokens == num_residues
|
||||
|
||||
|
||||
def test_ions_parsing(tokenizer: AllAtomResidueTokenizer):
    """Ions as SMILES strings should carry the correct charge."""
    ion_inputs = [Input("[Mg+2]", entity_type=EntityType.LIGAND.value, entity_name="foo")]
    chains = load_chains_from_raw(ion_inputs, identifier="foo", tokenizer=tokenizer)
    assert len(chains) == 1
    (chain,) = chains
    ctx = chain.structure_context
    # A lone Mg2+ ion: one atom, charge +2, atomic number 12.
    assert ctx.num_atoms == 1
    assert ctx.atom_ref_charge == 2
    assert ctx.atom_ref_element.item() == 12
|
||||
|
||||
|
||||
def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
    """Complex with multiple duplicated protein chains and SMILES ligands."""
    # Based on https://www.rcsb.org/structure/1AFS
    # Two copies each of: the protein chain, the NAP ligand, the TES ligand.
    seq = "MDSISLRVALNDGNFIPVLGFGTTVPEKVAKDEVIKATKIAIDNGFRHFDSAYLYEVEEEVGQAIRSKIEDGTVKREDIFYTSKLWSTFHRPELVRTCLEKTLKSTQLDYVDLYIIHFPMALQPGDIFFPRDEHGKLLFETVDICDTWEAMEKCKDAGLAKSIGVSNFNCRQLERILNKPGLKYKPVCNQVECHLYLNQSKMLDYCKSKDIILVSYCTLGSSRDKTWVDQKSPVLLDDPVLCAIAKKYKQTPALVALRYQLQRGVVPLIRSFNAKRIKELTQVFEFQLASEDMKALDGLNRNFRYNNAKYFDDHPNHPFTDEN"
    nap = "NC(=O)c1ccc[n+](c1)[CH]2O[CH](CO[P]([O-])(=O)O[P](O)(=O)OC[CH]3O[CH]([CH](O[P](O)(O)=O)[CH]3O)n4cnc5c(N)ncnc45)[CH](O)[CH]2O"
    tes = "O=C4C=C3C(C2CCC1(C(CCC1O)C2CC3)C)(C)CC4"
    inputs = [
        Input(seq, EntityType.PROTEIN.value, entity_name="A"),
        Input(seq, EntityType.PROTEIN.value, entity_name="B"),
        Input(nap, EntityType.LIGAND.value, entity_name="C"),
        Input(nap, EntityType.LIGAND.value, entity_name="D"),
        Input(tes, EntityType.LIGAND.value, entity_name="E"),
        Input(tes, EntityType.LIGAND.value, entity_name="F"),
    ]
    chains: list[Chain] = load_chains_from_raw(inputs, tokenizer=tokenizer)
    assert len(chains) == len(inputs)

    # Merge all per-chain contexts into a single structure context.
    example = AllAtomStructureContext.merge(
        [chain.structure_context for chain in chains]
    )

    # 1 protein entity + 2 distinct ligand entities -> 3 unique entity ids,
    # while each of the 6 physical chains gets its own asym id.
    assert example.token_entity_id.unique().numel() == 3
    assert example.token_asym_id.unique().numel() == 6

    # Check protein chains: one shared entity id, two sym ids (two copies).
    prot_entity_ids = example.token_entity_id[
        example.token_entity_type == EntityType.PROTEIN.value
    ]
    assert torch.unique(prot_entity_ids).numel() == 1
    prot_sym_ids = example.token_sym_id[
        example.token_entity_type == EntityType.PROTEIN.value
    ]
    assert torch.unique(prot_sym_ids).numel() == 2  # Two copies of this chain

    # Check ligand chains: two entity ids (NAP, TES), each duplicated once.
    lig_entity_ids = example.token_entity_id[
        example.token_entity_type == EntityType.LIGAND.value
    ]
    assert torch.unique(lig_entity_ids).numel() == 2
    lig_sym_ids = example.token_sym_id[
        example.token_entity_type == EntityType.LIGAND.value
    ]
    assert torch.unique(lig_sym_ids).numel() == 2  # Two copies of each ligand
|
||||
36
tests/test_msa_a3m_tokenization.py
Executable file
36
tests/test_msa_a3m_tokenization.py
Executable file
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
"""
|
||||
Test for tokenization
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from chai_lab.data.parsing.msas.a3m import tokenize_sequences_to_arrays
|
||||
from chai_lab.data.residue_constants import residue_types_with_nucleotides_order
|
||||
|
||||
|
||||
def test_tokenization_basic():
    """A single plain sequence tokenizes to the expected residue indices."""
    query = "RKDES"

    tokens, deletions = tokenize_sequences_to_arrays([query])
    assert tokens.shape == deletions.shape == (1, 5)
    # Each residue maps to its index in the shared residue-type vocabulary.
    expected = np.array([residue_types_with_nucleotides_order[aa] for aa in query])
    assert np.array_equal(tokens, expected)
|
||||
|
||||
|
||||
def test_tokenization_with_insertion():
    """Insertions (lower case) should be ignored."""
    plain = "RKDES"
    with_insertion = "RKrkdesDES"

    tokens, deletions = tokenize_sequences_to_arrays([plain, with_insertion])
    # Both rows align to the 5-column query frame.
    assert tokens.shape == deletions.shape == (2, 5)
    # The lowercase run does not change the aligned tokens...
    assert np.array_equal(tokens[0], tokens[1])
    # ...but its 5 characters are all recorded at column 2 of row 1.
    assert deletions.sum() == 5
    assert deletions[1, 2] == 5
|
||||
25
tests/test_msa_preprocess.py
Executable file
25
tests/test_msa_preprocess.py
Executable file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
import torch
|
||||
|
||||
from chai_lab.data.dataset.msas.msa_context import NO_PAIRING_KEY
|
||||
from chai_lab.data.dataset.msas.preprocess import _UKEY_FOR_QUERY, prepair_ukey
|
||||
|
||||
|
||||
def test_prepair_ukey():
    """Keys map to (key, rank-by-edit-distance) pairs; NO_PAIRING_KEY rows are skipped."""
    keys = torch.tensor([1, 1, 2, 1, NO_PAIRING_KEY, 2, 3])
    edit_dists = torch.arange(len(keys))

    def pairable_indices() -> set[int]:
        # Every index except those carrying the sentinel NO_PAIRING_KEY.
        return {i for i, key in enumerate(keys.tolist()) if key != NO_PAIRING_KEY}

    paired = prepair_ukey(keys, edit_dists)
    assert list(paired) == [_UKEY_FOR_QUERY, (1, 0), (2, 0), (1, 1), (2, 1), (3, 0)]
    assert set(paired.values()) == pairable_indices()

    # Reverse the edit distances: ranks among duplicate keys flip accordingly.
    reversed_dists = torch.tensor(edit_dists.tolist()[::-1])
    paired = prepair_ukey(keys, reversed_dists)
    assert list(paired) == [_UKEY_FOR_QUERY, (1, 1), (2, 1), (1, 0), (2, 0), (3, 0)]
    assert set(paired.values()) == pairable_indices()
|
||||
79
tests/test_parsing.py
Executable file
79
tests/test_parsing.py
Executable file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from chai_lab.data.parsing.fasta import read_fasta
|
||||
from chai_lab.data.parsing.input_validation import (
|
||||
constituents_of_modified_fasta,
|
||||
identify_potential_entity_types,
|
||||
)
|
||||
from chai_lab.data.parsing.structure.entity_type import EntityType
|
||||
|
||||
from .example_inputs import example_dna, example_ligands, example_proteins, example_rna
|
||||
|
||||
|
||||
def test_simple_protein_fasta():
    """A plain one-letter protein sequence parses one constituent per letter.

    Uses zip(strict=True) (already used elsewhere in this test suite) so a
    too-short or too-long parse fails instead of being silently truncated
    by a bare zip.
    """
    parts = constituents_of_modified_fasta("RKDES")
    assert parts is not None
    assert all(x == y for x, y in zip(parts, ["R", "K", "D", "E", "S"], strict=True))
|
||||
|
||||
|
||||
def test_modified_protein_fasta():
    """Parenthesized modified residues parse as multi-letter codes.

    zip(strict=True) ensures the constituent count matches exactly; a bare
    zip would silently ignore extra or missing constituents.
    """
    parts = constituents_of_modified_fasta("(KCJ)(SEP)(PPN)(B3S)(BAL)(PPN)KX(NH2)")
    assert parts is not None
    expected = ["KCJ", "SEP", "PPN", "B3S", "BAL", "PPN", "K", "X", "NH2"]
    assert all(x == y for x, y in zip(parts, expected, strict=True))
|
||||
|
||||
|
||||
def test_rna_fasta():
    """An RNA sequence parses one constituent per nucleotide.

    zip(strict=True) guards the length as well as the values; a bare zip
    would pass on a truncated parse.
    """
    seq = "ACUGACG"
    parts = constituents_of_modified_fasta(seq)
    assert parts is not None
    assert all(x == y for x, y in zip(parts, seq, strict=True))
|
||||
|
||||
|
||||
def test_dna_fasta():
    """A DNA sequence parses one constituent per nucleotide.

    zip(strict=True) guards the length as well as the values; a bare zip
    would pass on a truncated parse.
    """
    seq = "ACGACTAGCAT"
    parts = constituents_of_modified_fasta(seq)
    assert parts is not None
    assert all(x == y for x, y in zip(parts, seq, strict=True))
|
||||
|
||||
|
||||
def test_parsing():
    """Each example input is recognized as a candidate of its own entity type."""
    cases = [
        (example_ligands, EntityType.LIGAND),
        (example_proteins, EntityType.PROTEIN),
        (example_dna, EntityType.DNA),
        (example_rna, EntityType.RNA),
    ]
    for examples, expected_type in cases:
        for sequence in examples:
            assert expected_type in identify_potential_entity_types(sequence)
|
||||
|
||||
|
||||
def test_fasta_parsing():
    """A two-record FASTA string round-trips through read_fasta."""
    test_string = """>foo\nRKDES\n>bar\nKEDESRRR"""
    with TemporaryDirectory() as tmpdir:
        fa_file = Path(tmpdir) / "test.fasta"
        fa_file.write_text(test_string)
        records = read_fasta(fa_file)

        assert len(records) == 2
        expected = [("foo", "RKDES"), ("bar", "KEDESRRR")]
        for record, (header, sequence) in zip(records, expected, strict=True):
            assert record.header == header
            assert record.sequence == sequence
|
||||
|
||||
|
||||
def test_smiles_parsing():
    """A SMILES entry in a FASTA file parses as a single record."""
    smiles = ">smiles\nCc1cc2nc3c(=O)[nH]c(=O)nc-3n(C[C@H](O)[C@H](O)[C@H](O)CO)c2cc1C"
    with TemporaryDirectory() as tmpdir:
        fasta_path = Path(tmpdir) / "test.fasta"
        fasta_path.write_text(smiles)
        parsed_records = read_fasta(fasta_path)
        assert len(parsed_records) == 1
|
||||
24
tests/test_rdkit.py
Executable file
24
tests/test_rdkit.py
Executable file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) 2024 Chai Discovery, Inc.
|
||||
# Licensed under the Apache License, Version 2.0.
|
||||
# See the LICENSE file for details.
|
||||
|
||||
from chai_lab.data.sources.rdkit import RefConformerGenerator
|
||||
|
||||
|
||||
def test_ref_conformer_from_smiles():
    """Test ref conformer generation from SMILES."""
    riboflavin_like = "Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(C[C@H](O)[C@H](O)[C@H](O)CO)c2cc1C"
    generator = RefConformerGenerator()

    conformer = generator.generate(riboflavin_like)

    # Every atom in the conformer must get a unique name.
    assert len(set(conformer.atom_names)) == conformer.num_atoms
|
||||
|
||||
|
||||
def test_ref_conformer_glycan_ccd():
    """Ref conformer from CCD code for a sugar ring."""
    generator = RefConformerGenerator()
    mannose = generator.get("MAN")
    assert mannose is not None

    # Every atom in the conformer must get a unique name.
    assert len(set(mannose.atom_names)) == mannose.num_atoms
|
||||
Reference in New Issue
Block a user