Source code for diplomat.utils.clustering

"""
Optimized routines for performing agglomerative clustering on weighted graphs.
"""

from typing import Union, Tuple
import numba
import numpy as np
from enum import IntEnum


_numeric = Union[np.ndarray, int]


class ClusteringMethod(IntEnum):
    """
    The set of agglomerate clustering methods supported by this module.
    Determines how weights are adjusted after a merge.
    """

    SINGLE = 0
    """ Single linkage clustering, the distance between two clusters is the distance between the two closest nodes to each. """
    COMPLETE = 1
    """ Complete linkage clustering, the distance between two clusters is the distance between the two farthest points in each. """
    AVERAGE = 2
    """ Average clustering. The distance between two clusters is the distance between their points averaged. """
    WEIGHTED = 3
    """ Weighted average clustering. Like average, but ignores cluster size when determining the average point for a new cluster. """
    WARD = 4
    """ Ward clustering. Minimizes the total within-cluster variance. """


_CLUSTERING_METHOD_MAX = max(ClusteringMethod)
_CLUSTERING_METHOD_MIN = min(ClusteringMethod)


dist_update_func_sig = numba.float64(
    numba.int64,
    numba.int64,
    numba.int64,
    numba.int64,
    numba.float64,
    numba.float64,
    numba.float64,
)


@numba.njit(dist_update_func_sig)
def _distance_update(
    mode: int, sx: int, sy: int, si: int, dxy: float, dxi: float, dyi: float
) -> float:
    if mode == ClusteringMethod.SINGLE:
        return min(dxi, dyi)
    elif mode == ClusteringMethod.COMPLETE:
        return max(dxi, dyi)
    elif mode == ClusteringMethod.AVERAGE:
        return (sx * dxi + sy * dyi) / (sx + sy)
    elif mode == ClusteringMethod.WEIGHTED:
        return 0.5 * (dxi + dyi)
    elif mode == ClusteringMethod.WARD:
        denom = 1.0 / (sx + sy + si)
        return np.sqrt(
            ((sx + si) * denom) * np.square(dxi)
            + ((sy + si) * denom) * np.square(dyi)
            - (si * denom) * np.square(dxy)
        )
    else:
        return min(dxi, dyi)


@numba.njit("types.UniTuple(i8, 2)(i8, i8)")
def _dist_index(node1: _numeric, node2: _numeric):
    if node1 > node2:
        node1, node2 = node2, node1
    return node1, node2


[docs] @numba.njit("types.Tuple((i8[:, :], f8[:]))(f8[:, :], i8)") def nn_chain(dists: np.ndarray, linkage_mode: int) -> Tuple[np.ndarray, np.ndarray]: """ Use the nearest neighbor chain algorithm to perform hierarchical clustering of a graph. :param dists: A NxN 2d numpy array of floats, the pairwise distances between each node in the graph. :param linkage_mode: The distance merging rule to use, see :py:class:`~diplomat.utils.clustering.ClusteringMethod` for available options. Impacts the results hierarchical clustering returns. :returns: Two numpy arrays. The first is an array of shape Nx2 of ints, being which nodes to merge in order to produce the clustering. The second array is an length N array of floats, being the distances of the edges merged at each step. """ assert dists.ndim == 2 assert dists.shape[0] == dists.shape[1] assert _CLUSTERING_METHOD_MIN <= linkage_mode <= _CLUSTERING_METHOD_MAX dists = dists.copy() node_count = len(dists) sizes = np.ones(node_count) stack = np.zeros(node_count, dtype=np.int64) chain_length: int = 0 nodes_and_merge = np.zeros((node_count - 1, 3), dtype=np.int64) merge_distances = np.zeros(node_count - 1, dtype=np.float64) for i in range(node_count - 1): # Chain is empty, add the first next valid cluster... if chain_length == 0: for j in range(node_count): if sizes[j] != 0: stack[0] = j chain_length = 1 break nn_current: int = 0 nn_next: int = 0 # Grow the nearest neighbor chain... while True: nn_current = stack[chain_length - 1] # Set to the past link if chain is long enough, this prevents cycles... nn_next = stack[chain_length - 2] if (chain_length >= 2) else nn_current current_min = ( dists[_dist_index(nn_current, nn_next)] if (chain_length >= 2) else np.inf ) # Find the nearest neighbor for this cluster... for k in range(node_count): # Check if a valid cluster index... if sizes[k] == 0 or k == nn_current: continue # If distance is closer, update next nearest neighbor... dist = dists[_dist_index(nn_current, k)] if dist < current_min: current_min = dist nn_next = k # If the next nearest neighbor is second one back on the stack, we've found the next set of # clusters to merge... if chain_length >= 2 and nn_next == stack[chain_length - 2]: break stack[chain_length] = nn_next chain_length += 1 # Pop next 2 nodes off the stack... chain_length -= 2 # Node 1 should be the smaller of the two, this is to make merging encoding consistent... if nn_current > nn_next: nn_current, nn_next = nn_next, nn_current n1_size = sizes[nn_current] n2_size = sizes[nn_next] # Write next cluster merge (solution) for this step... nodes_and_merge[i, 0] = nn_current nodes_and_merge[i, 1] = nn_next merge_distances[i] = current_min nodes_and_merge[i, 2] = n1_size + n2_size # Merge the two clusters by updating sizes... sizes[nn_current] = n1_size + n2_size sizes[nn_next] = 0 # Update distances for the new merged cluster... for l in range(node_count): l_size = sizes[l] if l_size == 0 or l == nn_current: continue dists[_dist_index(nn_current, l)] = _distance_update( linkage_mode, n1_size, n2_size, l_size, current_min, dists[_dist_index(nn_current, l)], dists[_dist_index(nn_next, l)], ) # Reorder by distance, stably... idx_order = np.argsort(merge_distances, kind="mergesort") return nodes_and_merge[idx_order], merge_distances[idx_order]
UnionFindType = np.ndarray @numba.njit("i8[:, :](i8)") def _new_union_find(size: int) -> UnionFindType: res = np.ones((2, size), dtype=np.int64) res[0, :] = np.arange(size, dtype=np.int64) return res @numba.njit("i8(i8[:, :], i8)") def _uf_compress(uf: UnionFindType, n: int) -> int: parents = uf[0] if n == parents[n]: return n else: root = _uf_compress(uf, parents[n]) parents[n] = root return root @numba.njit("i8(i8[:, :], i8)") def _uf_find(uf: UnionFindType, n: int): return _uf_compress(uf, n) @numba.njit("i8(i8[:, :], i8, i8)") def _uf_union(uf: UnionFindType, n1: int, n2: int) -> int: parents, sizes = uf n1 = _uf_find(uf, n1) n2 = _uf_find(uf, n2) if n1 == n2: return sizes[n1] n1_size = sizes[n1] n2_size = sizes[n2] if n1_size < n2_size: n1, n2 = n2, n1 merged_size = n1_size + n2_size sizes[n1] = merged_size sizes[n2] = merged_size parents[n2] = n1 return merged_size
[docs] @numba.njit("types.Tuple((i8[:], i8))(i8[:, :], f8[:], i8)") def get_components(merge_list: np.ndarray, distances: np.ndarray, num_components: int) -> Tuple[np.ndarray, int]: """ Get the components or clusters of set of nodes after performing hierarchical clustering, given a specific number of desired components to be returned. Returns clustering solution at that level. :param merge_list: Nodes to merge at each step of the clustering, first output of :py:func:`~diplomat.utils.clustering.nn_chain`. :param distances: Distances of each edge that was merged at each step of clustering, second output of :py:func:`~diplomat.utils.clustering.nn_chain`. :param num_components: The number of components, or cluster, desired. Stops merging nodes once it reaches this many clusters. :returns: Two values, an array of integers giving the component each node belongs to, and an integer giving the total number of components in the final result. """ assert merge_list.ndim == 2 assert distances.ndim == 1 assert merge_list.shape[1] == 3 assert merge_list.shape[0] >= 1 and merge_list.shape[0] == distances.shape[0] size = merge_list.shape[0] + 1 assert 0 < num_components num_components = min(num_components, size) components = np.full(size, fill_value=-1, dtype=np.int64) iters = size - num_components uf = _new_union_find(size) # Merge nodes until number of components matches desired amount... for i in range(iters): _uf_union(uf, merge_list[i, 0], merge_list[i, 1]) # For every node, find it's root. If it's root # has not been assigned a component index, assign it one # and set the component to the same as the root. component_index = 0 for j in range(size): root = _uf_find(uf, j) if components[root] < 0: components[root] = component_index component_index += 1 components[j] = components[root] return components, num_components