More static typing

This commit is contained in:
Thomas Holder
2024-01-21 09:23:10 +01:00
parent c8d93f6f01
commit 25ed65eed9
2 changed files with 38 additions and 28 deletions

View File

@@ -12,7 +12,7 @@ in configuration file.
""" """
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List from typing import Callable, Dict, List, Sequence, Tuple, TypeVar, Union
try: try:
# New in version 3.10, deprecated since version 3.12 # New in version 3.10, deprecated since version 3.12
@@ -39,6 +39,8 @@ class squared_property:
setattr(instance, self._name_not_squared, value**0.5) setattr(instance, self._name_not_squared, value**0.5)
T = TypeVar("T")
_T_MATRIX: TypeAlias = "InteractionMatrix" _T_MATRIX: TypeAlias = "InteractionMatrix"
_T_PAIR_WISE_MATRIX: TypeAlias = "PairwiseMatrix" _T_PAIR_WISE_MATRIX: TypeAlias = "PairwiseMatrix"
_T_NUMBER_DICTIONARY = Dict[str, float] _T_NUMBER_DICTIONARY = Dict[str, float]
@@ -155,8 +157,10 @@ class Parameters:
self.parse_to_matrix(words) self.parse_to_matrix(words)
elif typeannotation is _T_STRING_DICTIONARY: elif typeannotation is _T_STRING_DICTIONARY:
self.parse_to_string_dictionary(words) self.parse_to_string_dictionary(words)
elif typeannotation is int or typeannotation is _T_BOOL:
self.parse_parameter(words, int)
else: else:
self.parse_parameter(words) self.parse_parameter(words, float)
def parse_to_number_dictionary(self, words): def parse_to_number_dictionary(self, words):
"""Parse field to number dictionary. """Parse field to number dictionary.
@@ -219,14 +223,14 @@ class Parameters:
value = tuple(words[1:]) value = tuple(words[1:])
matrix.add(value) matrix.add(value)
def parse_parameter(self, words): def parse_parameter(self, words, typefunc: Callable[[str], T]):
"""Parse field to parameters. """Parse field to parameters.
Args: Args:
words: strings to parse words: strings to parse
""" """
assert len(words) == 2, words assert len(words) == 2, words
value = float(words[1]) value = typefunc(words[1])
setattr(self, words[0], value) setattr(self, words[0], value)
def parse_string(self, words): def parse_string(self, words):
@@ -448,37 +452,40 @@ O2
class InteractionMatrix: class InteractionMatrix:
"""Interaction matrix class.""" """Interaction matrix class."""
def __init__(self, name): def __init__(self, name: str):
"""Initialize with name of matrix. """Initialize with name of matrix.
Args: Args:
name: name of interaction matrix name: name of interaction matrix
""" """
self.name = name self.name = name
self.value = None self.ordered_keys: List[str] = []
self.ordered_keys = [] self.dictionary: Dict[str, Dict[str, Union[str, float]]] = {}
self.dictionary = {}
def add(self, words): def add(self, words: Sequence[str]):
"""Add values to matrix. """Add values to matrix.
Args: Args:
words: values to add words: values to add
""" """
len_expected = len(self.ordered_keys) + 2
if len(words) != len_expected:
raise ValueError(f"Expected {len_expected} arguments, got {words!r}")
new_group = words[0] new_group = words[0]
self.ordered_keys.append(new_group) self.ordered_keys.append(new_group)
if new_group not in self.dictionary.keys(): if new_group not in self.dictionary.keys():
self.dictionary[new_group] = {} self.dictionary[new_group] = {}
for i, group in enumerate(self.ordered_keys): for i, group in enumerate(self.ordered_keys):
if len(words) > i+1: if len(words) > i+1:
value: Union[str, float]
try: try:
self.value = float(words[i+1]) value = float(words[i+1])
except ValueError: except ValueError:
self.value = words[i+1] value = words[i+1]
self.dictionary[group][new_group] = self.value self.dictionary[group][new_group] = value
self.dictionary[new_group][group] = self.value self.dictionary[new_group][group] = value
def get_value(self, item1, item2): def get_value(self, item1: str, item2: str) -> Union[str, float, None]:
"""Get specific matrix value. """Get specific matrix value.
Args: Args:
@@ -492,7 +499,7 @@ class InteractionMatrix:
except KeyError: except KeyError:
return None return None
def __getitem__(self, group): def __getitem__(self, group: str):
"""Get specific group from matrix. """Get specific group from matrix.
Args: Args:
@@ -528,17 +535,17 @@ class InteractionMatrix:
class PairwiseMatrix: class PairwiseMatrix:
"""Pairwise interaction matrix class.""" """Pairwise interaction matrix class."""
def __init__(self, name): def __init__(self, name: str):
"""Initialize pairwise matrix. """Initialize pairwise matrix.
Args: Args:
name: name of pairwise interaction name: name of pairwise interaction
""" """
self.name = name self.name = name
self.dictionary = {} self.dictionary: Dict[str, Dict[str, Tuple[float, float]]] = {}
self.default = [0.0, 0.0] self.default = (0.0, 0.0)
def add(self, words): def add(self, words: Sequence[str]):
"""Add information to the matrix. """Add information to the matrix.
TODO - this function unnecessarily bundles arguments into a tuple TODO - this function unnecessarily bundles arguments into a tuple
@@ -548,16 +555,17 @@ class PairwiseMatrix:
""" """
# assign the default value # assign the default value
if len(words) == 3 and words[0] == 'default': if len(words) == 3 and words[0] == 'default':
self.default = [float(words[1]), float(words[2])] self.default = (float(words[1]), float(words[2]))
return return
# assign non-default values # assign non-default values
assert len(words) == 4
group1 = words[0] group1 = words[0]
group2 = words[1] group2 = words[1]
value = [float(words[2]), float(words[3])] value = (float(words[2]), float(words[3]))
self.insert(group1, group2, value) self.insert(group1, group2, value)
self.insert(group2, group1, value) self.insert(group2, group1, value)
def insert(self, key1, key2, value): def insert(self, key1: str, key2: str, value: Tuple[float, float]):
"""Insert value into matrix. """Insert value into matrix.
Args: Args:
@@ -575,7 +583,7 @@ class PairwiseMatrix:
self.dictionary[key1] = {} self.dictionary[key1] = {}
self.dictionary[key1][key2] = value self.dictionary[key1][key2] = value
def get_value(self, item1, item2): def get_value(self, item1: str, item2: str) -> Tuple[float, float]:
"""Get specified value from matrix. """Get specified value from matrix.
Args: Args:
@@ -589,7 +597,7 @@ class PairwiseMatrix:
except KeyError: except KeyError:
return self.default return self.default
def __getitem__(self, group): def __getitem__(self, group: str):
"""Get item from matrix corresponding to specific group. """Get item from matrix corresponding to specific group.
Args: Args:

View File

@@ -7,6 +7,7 @@ Contains version-specific methods and parameters.
TODO - this module unnecessarily confuses the code. Can we eliminate it? TODO - this module unnecessarily confuses the code. Can we eliminate it?
""" """
import logging import logging
from typing import Sequence, Tuple
from propka.atom import Atom from propka.atom import Atom
from propka.hydrogens import setup_bonding_and_protonation, setup_bonding from propka.hydrogens import setup_bonding_and_protonation, setup_bonding
from propka.hydrogens import setup_bonding_and_protonation_30_style from propka.hydrogens import setup_bonding_and_protonation_30_style
@@ -15,6 +16,7 @@ from propka.energy import hydrogen_bond_energy, hydrogen_bond_interaction
from propka.energy import electrostatic_interaction, check_coulomb_pair from propka.energy import electrostatic_interaction, check_coulomb_pair
from propka.energy import coulomb_energy, check_exceptions from propka.energy import coulomb_energy, check_exceptions
from propka.energy import backbone_reorganization from propka.energy import backbone_reorganization
from propka.parameters import Parameters
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -22,7 +24,7 @@ _LOGGER = logging.getLogger(__name__)
class Version: class Version:
"""Store version-specific methods and parameters.""" """Store version-specific methods and parameters."""
def __init__(self, parameters): def __init__(self, parameters: Parameters):
self.parameters = parameters self.parameters = parameters
self.desolvation_model = self.empty_function self.desolvation_model = self.empty_function
self.weight_pair_method = self.empty_function self.weight_pair_method = self.empty_function
@@ -99,7 +101,7 @@ class Version:
"""Setup bonding using assigned model.""" """Setup bonding using assigned model."""
return self.prepare_bonds(self.parameters, molecular_container) return self.prepare_bonds(self.parameters, molecular_container)
def get_hydrogen_bond_parameters(self, atom1: Atom, atom2: Atom) -> tuple: def get_hydrogen_bond_parameters(self, atom1: Atom, atom2: Atom) -> Tuple[float, Sequence[float]]:
"""Get hydrogen bond parameters for two atoms.""" """Get hydrogen bond parameters for two atoms."""
raise NotImplementedError("abstract method") raise NotImplementedError("abstract method")
@@ -136,7 +138,7 @@ class VersionA(Version):
dpka_max = self.parameters.sidechain_interaction dpka_max = self.parameters.sidechain_interaction
cutoff = self.parameters.sidechain_cutoffs.get_value( cutoff = self.parameters.sidechain_cutoffs.get_value(
atom1.group_type, atom2.group_type) atom1.group_type, atom2.group_type)
return [dpka_max, cutoff] return dpka_max, cutoff
def get_backbone_hydrogen_bond_parameters(self, backbone_atom, atom): def get_backbone_hydrogen_bond_parameters(self, backbone_atom, atom):
"""Get hydrogen bond parameters between backbone atom and other atom. """Get hydrogen bond parameters between backbone atom and other atom.
@@ -311,4 +313,4 @@ class Propka30(Version):
atom1.group_type, atom2.group_type) atom1.group_type, atom2.group_type)
cutoff = self.parameters.sidechain_cutoffs.get_value( cutoff = self.parameters.sidechain_cutoffs.get_value(
atom1.group_type, atom2.group_type) atom1.group_type, atom2.group_type)
return [dpka_max, cutoff] return dpka_max, cutoff