Source code for dcor.distances

"""
Distance functions between sets of points.

This module provide functions that compute the distance between one or two
sets of points. The Scipy implementation is used when the conversion to
a double precision floating point number will not cause loss of precision.
"""

from __future__ import annotations

from typing import TypeVar

import numpy as np
import scipy.spatial as spatial

from dcor._utils import ArrayType, _sqrt, _transform_to_2d, array_namespace

from ._utils import _can_be_numpy_double

Array = TypeVar("Array", bound=ArrayType)


def _cdist_naive(x: Array, y: Array, exponent: float = 1) -> Array:
    """Pairwise distance, custom implementation."""
    xp = array_namespace(x, y)

    x = xp.asarray(x)
    y = xp.asarray(y)

    x = xp.expand_dims(x, axis=0)
    y = xp.expand_dims(y, axis=1)

    squared_norms = xp.sum(((x - y) ** 2), axis=-1)

    try:
        return squared_norms ** (exponent / 2)
    except TypeError:
        return _sqrt(squared_norms ** exponent)


def _pdist_scipy(
    x: np.typing.NDArray[float],
    exponent: float = 1,
) -> np.typing.NDArray[float]:
    """Pairwise distance between points in a set."""
    metric = 'euclidean'

    if exponent != 1:
        metric = 'sqeuclidean'

    # cdist is actually FASTER than pdist + squareform
    distances = spatial.distance.cdist(x, x, metric=metric)

    if exponent != 1:
        distances **= exponent / 2

    return distances


def _cdist_scipy(
    x: np.typing.NDArray[float],
    y: np.typing.NDArray[float],
    exponent: float = 1,
) -> np.typing.NDArray[float]:
    """Pairwise distance between the points in two sets."""
    metric = 'euclidean'

    if exponent != 1:
        metric = 'sqeuclidean'

    distances = spatial.distance.cdist(x, y, metric=metric)

    if exponent != 1:
        distances **= exponent / 2

    return distances


def _pdist(x: Array, exponent: float = 1) -> Array:
    """
    Pairwise distance between points in a set.

    As Scipy converts every value to double, this wrapper uses
    a less efficient implementation if the original dtype
    can not be converted to double.

    """
    if _can_be_numpy_double(x):
        return _pdist_scipy(x, exponent)

    return _cdist_naive(x, x, exponent)


def _cdist(x: Array, y: Array, exponent: float = 1) -> Array:
    """
    Pairwise distance between points in two sets.

    As Scipy converts every value to double, this wrapper uses
    a less efficient implementation if the original dtype
    can not be converted to double.

    """
    if _can_be_numpy_double(x) and _can_be_numpy_double(y):
        return _cdist_scipy(x, y, exponent)

    return _cdist_naive(x, y, exponent)


[docs]def pairwise_distances( x: Array, y: Array | None = None, *, exponent: float = 1, ) -> Array: r""" Pairwise distance between points. Return the pairwise distance between points in two sets, or in the same set if only one set is passed. Args: x: An :math:`n \times m` array of :math:`n` observations in a :math:`m`-dimensional space. y: An :math:`l \times m` array of :math:`l` observations in a :math:`m`-dimensional space. If None, the distances will be computed between the points in :math:`x`. exponent: Exponent of the Euclidean distance. Returns: A :math:`n \times l` matrix where the :math:`(i, j)`-th entry is the distance between :math:`x[i]` and :math:`y[j]`. Examples: >>> import numpy as np >>> import dcor >>> a = np.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12], ... [13, 14, 15, 16]]) >>> b = np.array([[16, 15, 14, 13], ... [12, 11, 10, 9], ... [8, 7, 6, 5], ... [4, 3, 2, 1]]) >>> dcor.distances.pairwise_distances(a) array([[ 0., 8., 16., 24.], [ 8., 0., 8., 16.], [16., 8., 0., 8.], [24., 16., 8., 0.]]) >>> dcor.distances.pairwise_distances(a, b) array([[24.41311123, 16.61324773, 9.16515139, 4.47213595], [16.61324773, 9.16515139, 4.47213595, 9.16515139], [ 9.16515139, 4.47213595, 9.16515139, 16.61324773], [ 4.47213595, 9.16515139, 16.61324773, 24.41311123]]) """ if y is None or y is x: x, = _transform_to_2d(x) return _pdist(x, exponent=exponent) x, y = _transform_to_2d(x, y) return _cdist(x, y, exponent=exponent)