Source code for dcor._utils
"""Utility functions."""
from __future__ import annotations
import enum
import warnings
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, Union
import numpy as np
from array_api_compat import (
array_namespace as array_namespace_compat,
numpy as numpy_namespace,
)
# TODO: Change in the future
if TYPE_CHECKING:
ArrayType = np.typing.NDArray[np.number[Any]]
else:
ArrayType = np.ndarray
T = TypeVar("T", bound=ArrayType)
RandomLike = Union[
np.random.RandomState,
np.random.Generator,
int,
None,
]
[docs]class CompileMode(enum.Enum):
"""Compilation mode of the algorithm."""
AUTO = enum.auto()
"""
Try to use the fastest available method.
"""
NO_COMPILE = enum.auto()
"""
Use a pure Python implementation.
"""
COMPILE_CPU = enum.auto()
"""
Compile for execution in one CPU.
"""
COMPILE_PARALLEL = enum.auto()
"""
Compile for execution in multicore CPUs.
"""
[docs]class RowwiseMode(enum.Enum):
"""Rowwise mode of the algorithm."""
AUTO = enum.auto()
"""
Try to use the fastest available method.
"""
NAIVE = enum.auto()
"""
Use naive (list comprehension/map) computation.
"""
OPTIMIZED = enum.auto()
"""
Use optimized version, or fail if there is none.
"""
# TODO: Change the return type in the future
def array_namespace(*xs: Any) -> Any:
# `xs` contains one or more arrays, or possibly Python scalars (accepting
# those is a matter of taste, but doesn't seem unreasonable).
try:
return array_namespace_compat(*xs)
except TypeError:
warnings.warn(
"Passing non-array objects to functions in the 'dcor' "
"package is deprecated and will be removed in a future version",
DeprecationWarning,
stacklevel=3, # TODO: Use skip_file_prefixes in Python 3.12?
)
return numpy_namespace
def _sqrt(x: T) -> T:
"""
Return square root of an array.
This sqrt function for ndarrays tries to use the exponentiation operator
if the objects stored do not supply a sqrt method.
Args:
x: Input array.
Returns:
Square root of the input array.
"""
# Replace negative values with 0
xp = array_namespace(x)
x_copy = xp.asarray(x + 0)
x_copy[x_copy < 0] = 0
try:
return xp.sqrt(x_copy)
except (AttributeError, TypeError):
return x_copy**0.5
def _transform_to_1d(*args: T) -> Iterable[T]:
"""Convert column matrices to vectors, to always have a 1d shape."""
xp = array_namespace(*args)
for array in args:
array = xp.asarray(array)
dim = len(array.shape)
assert dim <= 2
if dim == 2:
assert array.shape[1] == 1
array = xp.reshape(array, -1)
yield array
def _transform_to_2d(*args: T) -> Iterable[T]:
"""Convert vectors to column matrices, to always have a 2d shape."""
xp = array_namespace(*args)
for array in args:
array = xp.asarray(array)
dim = len(array.shape)
assert dim <= 2
if dim < 2:
array = xp.expand_dims(array, axis=1)
yield array
def _can_be_numpy_double(x: ArrayType) -> bool:
"""
Return if the array can be safely converted to double.
That happens when the dtype is a float with the same size of
a double or narrower, or when is an integer that can be safely
converted to double (if the roundtrip conversion works).
"""
if array_namespace(x) != np:
return False
return (
(
np.issubdtype(x.dtype, np.floating)
and x.dtype.itemsize <= np.dtype(float).itemsize
) or (
np.issubdtype(x.dtype, np.signedinteger)
and np.can_cast(x, float)
)
)
def _random_state_init(
random_state: RandomLike,
) -> np.random.RandomState | np.random.Generator:
"""
Initialize a RandomState object.
If the object is a RandomState, or cannot be used to
initialize one, it will be assumed that is a similar object
and returned.
"""
if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
return random_state
return np.random.RandomState(random_state)