Deprecate Vector operator overloads

This commit is contained in:
Thomas Holder
2023-12-13 08:12:42 +01:00
parent 57ad5d8384
commit 723609cc33
4 changed files with 30 additions and 24 deletions

View File

@@ -319,11 +319,11 @@ def are_atoms_planar(atoms):
return False return False
vec1 = Vector(atom1=atoms[0], atom2=atoms[1]) vec1 = Vector(atom1=atoms[0], atom2=atoms[1])
vec2 = Vector(atom1=atoms[0], atom2=atoms[2]) vec2 = Vector(atom1=atoms[0], atom2=atoms[2])
norm = (vec1**vec2).rescale(1.0) norm = vec1.cross(vec2).rescale(1.0)
margin = PLANARITY_MARGIN margin = PLANARITY_MARGIN
for atom in atoms[3:]: for atom in atoms[3:]:
vec = Vector(atom1=atoms[0], atom2=atom).rescale(1.0) vec = Vector(atom1=atoms[0], atom2=atom).rescale(1.0)
if abs(vec*norm) > margin: if abs(vec.dot(norm)) > margin:
return False return False
return True return True

View File

@@ -274,15 +274,15 @@ class Protonate:
vec2 = Vector(atom1=atom.bonded_atoms[0], vec2 = Vector(atom1=atom.bonded_atoms[0],
atom2=atom.bonded_atoms[0] atom2=atom.bonded_atoms[0]
.bonded_atoms[other_atom_indices[0]]) .bonded_atoms[other_atom_indices[0]])
axis = vec1**vec2 axis = vec1.cross(vec2)
# this is a trick to make sure that the order of atoms doesn't # this is a trick to make sure that the order of atoms doesn't
# influence the final postions of added protons # influence the final postions of added protons
if len(other_atom_indices) > 1: if len(other_atom_indices) > 1:
vec3 = Vector(atom1=atom.bonded_atoms[0], vec3 = Vector(atom1=atom.bonded_atoms[0],
atom2=atom.bonded_atoms[0] atom2=atom.bonded_atoms[0]
.bonded_atoms[other_atom_indices[1]]) .bonded_atoms[other_atom_indices[1]])
axis2 = vec1**vec3 axis2 = vec1.cross(vec3)
if axis*axis2 > 0: if axis.dot(axis2) > 0:
axis = axis+axis2 axis = axis+axis2
else: else:
axis = axis-axis2 axis = axis-axis2

View File

@@ -7,6 +7,7 @@ Vector algebra for PROPKA.
import logging import logging
import math import math
from typing import Optional, Protocol, overload from typing import Optional, Protocol, overload
import warnings
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -87,10 +88,11 @@ class Vector:
def __mul__(self, other): def __mul__(self, other):
"""Dot product, scalar and matrix multiplication.""" """Dot product, scalar and matrix multiplication."""
if isinstance(other, Vector): if isinstance(other, Vector):
# TODO deprecate in favor of self.dot() warnings.warn("Use Vector.dot() instead of operator.mul()", DeprecationWarning, stacklevel=2)
return self.dot(other) return self.dot(other)
if isinstance(other, Matrix4x4): if isinstance(other, Matrix4x4):
# TODO deprecate in favor of matmul operator warnings.warn("Use M @ v (operator.matmul()) instead of M * v (operator.mul())",
DeprecationWarning, stacklevel=2)
return other @ self return other @ self
if isinstance(other, (int, float)): if isinstance(other, (int, float)):
return Vector(self.x * other, self.y * other, self.z * other) return Vector(self.x * other, self.y * other, self.z * other)
@@ -100,7 +102,7 @@ class Vector:
return self.__mul__(other) return self.__mul__(other)
def __pow__(self, other: _XYZ): def __pow__(self, other: _XYZ):
# TODO deprecate in favor of self.cross() warnings.warn("Use Vector.cross() instead of operator.pow()", DeprecationWarning, stacklevel=2)
return self.cross(other) return self.cross(other)
def cross(self, other: _XYZ): def cross(self, other: _XYZ):
@@ -195,7 +197,7 @@ def angle(avec: Vector, bvec: Vector) -> float:
Returns: Returns:
angle in radians angle in radians
""" """
dot = avec * bvec dot = avec.dot(bvec)
return math.acos(dot / (avec.length() * bvec.length())) return math.acos(dot / (avec.length() * bvec.length()))
@@ -221,10 +223,10 @@ def signed_angle_around_axis(avec: Vector, bvec: Vector, axis: Vector) -> float:
Returns: Returns:
angle in radians angle in radians
""" """
norma = avec**axis norma = avec.cross(axis)
normb = bvec**axis normb = bvec.cross(axis)
ang = angle(norma, normb) ang = angle(norma, normb)
dot_ = bvec*(avec**axis) dot_ = bvec.dot(avec.cross(axis))
if dot_ < 0: if dot_ < 0:
ang = -ang ang = -ang
return ang return ang
@@ -248,21 +250,21 @@ def rotate_vector_around_an_axis(theta: float, axis: Vector, vec: Vector) -> Vec
else: else:
gamma = math.pi/2.0 gamma = math.pi/2.0
rot_z = rotate_atoms_around_z_axis(gamma) rot_z = rotate_atoms_around_z_axis(gamma)
vec = rot_z * vec vec = rot_z @ vec
axis = rot_z * axis axis = rot_z @ axis
beta = 0.0 beta = 0.0
if axis.x != 0: if axis.x != 0:
beta = -axis.x/abs(axis.x)*math.acos( beta = -axis.x/abs(axis.x)*math.acos(
axis.z/math.sqrt(axis.x*axis.x + axis.z*axis.z)) axis.z/math.sqrt(axis.x*axis.x + axis.z*axis.z))
rot_y = rotate_atoms_around_y_axis(beta) rot_y = rotate_atoms_around_y_axis(beta)
vec = rot_y * vec vec = rot_y @ vec
axis = rot_y * axis axis = rot_y @ axis
rot_z = rotate_atoms_around_z_axis(theta) rot_z = rotate_atoms_around_z_axis(theta)
vec = rot_z * vec vec = rot_z @ vec
rot_y = rotate_atoms_around_y_axis(-beta) rot_y = rotate_atoms_around_y_axis(-beta)
vec = rot_y * vec vec = rot_y @ vec
rot_z = rotate_atoms_around_z_axis(-gamma) rot_z = rotate_atoms_around_z_axis(-gamma)
vec = rot_z * vec vec = rot_z @ vec
return vec return vec

View File

@@ -72,7 +72,8 @@ def test_Vector__mul__number():
def test_Vector__mul__Vector(): def test_Vector__mul__Vector():
v1 = m.Vector(1, 2, 3) v1 = m.Vector(1, 2, 3)
v2 = m.Vector(4, 5, 6) v2 = m.Vector(4, 5, 6)
assert v1 * v2 == 32 with pytest.deprecated_call():
assert v1 * v2 == 32
assert v1.dot(v2) == 32 assert v1.dot(v2) == 32
with pytest.raises(TypeError): with pytest.raises(TypeError):
v1 @ v2 # type: ignore v1 @ v2 # type: ignore
@@ -80,10 +81,12 @@ def test_Vector__mul__Vector():
def test_Vector__mul__Matrix4x4(): def test_Vector__mul__Matrix4x4():
v1 = m.Vector(1, 2, 3) v1 = m.Vector(1, 2, 3)
assert_vector_equal(v1 * m.Matrix4x4(), m.Vector()) assert_vector_equal(m.Matrix4x4() @ v1, m.Vector())
m2 = m.Matrix4x4(0, 1, 0, 0, 20, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 1) m2 = m.Matrix4x4(0, 1, 0, 0, 20, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 1)
assert_vector_equal(v1 * m2, m.Vector(2, 20, 900)) with pytest.deprecated_call():
assert_vector_equal(m2 * v1, m.Vector(2, 20, 900)) assert_vector_equal(v1 * m2, m.Vector(2, 20, 900))
with pytest.deprecated_call():
assert_vector_equal(m2 * v1, m.Vector(2, 20, 900))
assert_vector_equal(m2 @ v1, m.Vector(2, 20, 900)) assert_vector_equal(m2 @ v1, m.Vector(2, 20, 900))
with pytest.raises(TypeError): with pytest.raises(TypeError):
v1 @ m2 # type: ignore v1 @ m2 # type: ignore
@@ -92,7 +95,8 @@ def test_Vector__mul__Matrix4x4():
def test_Vector__cross(): def test_Vector__cross():
v1 = m.Vector(1, 2, 3) v1 = m.Vector(1, 2, 3)
v2 = m.Vector(4, 5, 6) v2 = m.Vector(4, 5, 6)
assert_vector_equal(v1**v2, m.Vector(-3, 6, -3)) # TODO deprecate with pytest.deprecated_call():
assert_vector_equal(v1**v2, m.Vector(-3, 6, -3))
assert_vector_equal(v1.cross(v2), m.Vector(-3, 6, -3)) assert_vector_equal(v1.cross(v2), m.Vector(-3, 6, -3))
assert_vector_equal(v2.cross(v1), m.Vector(3, -6, 3)) assert_vector_equal(v2.cross(v1), m.Vector(3, -6, 3))