from typing import Union, Tuple
import numba
import numpy as np
from enum import IntEnum
# Annotation alias for a value that may be either a NumPy array or a plain integer.
_numeric = Union[np.ndarray, int]
# [docs]
class ClusteringMethod(IntEnum):
    """Linkage criteria selectable through the integer ``mode`` of ``distance_update``.

    The members are plain integers so they can be passed where an ``int``
    linkage mode is expected.
    """

    # Distance to a merged cluster is the smaller of the two old distances.
    SINGLE = 0
    # Distance to a merged cluster is the larger of the two old distances.
    COMPLETE = 1
    # Mean of the two old distances, weighted by the sizes of the merged clusters.
    AVERAGE = 2
    # Unweighted mean of the two old distances.
    WEIGHTED = 3
    # Ward update, combining squared distances weighted by cluster sizes.
    WARD = 4
# Smallest and largest members of ClusteringMethod, used to range-check an
# integer linkage mode...
_CLUSTERING_METHOD_MIN, _CLUSTERING_METHOD_MAX = (
    min(ClusteringMethod),
    max(ClusteringMethod),
)

# Explicit numba signature for distance_update: four int64 arguments followed
# by three float64 arguments, returning a float64...
dist_update_func_sig = numba.float64(
    numba.int64, numba.int64, numba.int64, numba.int64,
    numba.float64, numba.float64, numba.float64,
)
@numba.njit(dist_update_func_sig)
def distance_update(mode: int, sx: int, sy: int, si: int, dxy: float, dxi: float, dyi: float) -> float:
    """Compute the distance from cluster ``i`` to the merger of clusters ``x`` and ``y``.

    Args:
        mode: Linkage criterion, compared against the ``ClusteringMethod`` members.
            Any value that matches no member is treated like ``SINGLE``.
        sx: Size of cluster ``x``.
        sy: Size of cluster ``y``.
        si: Size of cluster ``i``.
        dxy: Distance between ``x`` and ``y`` (only used by ``WARD``).
        dxi: Distance between ``x`` and ``i``.
        dyi: Distance between ``y`` and ``i``.

    Returns:
        The updated distance between ``i`` and the merged cluster.
    """
    if mode == ClusteringMethod.COMPLETE:
        return max(dxi, dyi)
    elif mode == ClusteringMethod.AVERAGE:
        return (sx * dxi + sy * dyi) / (sx + sy)
    elif mode == ClusteringMethod.WEIGHTED:
        return 0.5 * (dxi + dyi)
    elif mode == ClusteringMethod.WARD:
        denom = 1.0 / (sx + sy + si)
        radicand = (
            ((sx + si) * denom) * np.square(dxi)
            + ((sy + si) * denom) * np.square(dyi)
            - (si * denom) * np.square(dxy)
        )
        # The subtraction can land marginally below zero through rounding (or
        # with distances that violate the triangle inequality); taking the
        # square root of that would give NaN, which no later `<` comparison
        # can ever beat. Clamp instead. A NaN radicand is left untouched...
        if radicand < 0.0:
            radicand = 0.0
        return np.sqrt(radicand)
    # SINGLE, and the fallback for any unrecognised mode...
    return min(dxi, dyi)
@numba.njit("types.UniTuple(i8, 2)(i8, i8)")
def dist_index(node1: _numeric, node2: _numeric):
    """Return the two indices ordered as a ``(smaller, larger)`` pair.

    Indexing a square matrix with the result therefore always addresses the
    same entry regardless of the order in which the two nodes are given.
    """
    if node1 <= node2:
        return node1, node2
    return node2, node1
@numba.njit("types.Tuple((i8[:, :], f8[:]))(f8[:, :], i8)")
def nn_chain(dists: np.ndarray, linkage_mode: int):
    """Agglomerate all nodes of a distance matrix using a nearest-neighbor chain.

    Args:
        dists: Square float64 distance matrix. It is copied before use, and
            only entries ``[a, b]`` with ``a <= b`` are read or updated, because
            every access goes through ``dist_index``.
        linkage_mode: Integer linkage criterion; must lie between the smallest
            and largest ``ClusteringMethod`` values.

    Returns:
        A tuple ``(merges, distances)`` with one row/entry per merge, ordered by
        ascending merge distance (stable for ties). Each row of ``merges`` is
        ``(kept_index, absorbed_index, merged_size)`` where ``kept_index`` is the
        smaller of the two cluster indices and keeps representing the merged
        cluster afterwards. An empty input yields zero merges.
    """
    assert dists.ndim == 2
    assert dists.shape[0] == dists.shape[1]
    assert _CLUSTERING_METHOD_MIN <= linkage_mode <= _CLUSTERING_METHOD_MAX

    # The linkage updates below overwrite entries, so work on a private copy...
    dists = dists.copy()
    node_count = len(dists)
    # Never request a negative array dimension for an empty matrix...
    merge_count = max(node_count - 1, 0)

    # Sizes are integers so they match the int64 parameters of distance_update
    # and the int64 merge table. A size of 0 marks a slot that was merged away...
    sizes = np.ones(node_count, dtype=np.int64)
    stack = np.zeros(node_count, dtype=np.int64)
    chain_length = 0

    nodes_and_merge = np.zeros((merge_count, 3), dtype=np.int64)
    merge_distances = np.zeros(merge_count, dtype=np.float64)

    for i in range(merge_count):
        # Chain is empty, seed it with the first cluster that is still alive...
        if chain_length == 0:
            for j in range(node_count):
                if sizes[j] != 0:
                    stack[0] = j
                    chain_length = 1
                    break

        nn_current = 0
        nn_next = 0

        # Grow the nearest neighbor chain...
        while True:
            nn_current = stack[chain_length - 1]

            if chain_length >= 2:
                # Start from the previous link; it only loses to a strictly
                # closer cluster, which prevents cycles...
                nn_next = stack[chain_length - 2]
                current_min = dists[dist_index(nn_current, nn_next)]
            else:
                # No previous link yet; nn_next == nn_current means "none found"...
                nn_next = nn_current
                current_min = np.inf

            # Find the nearest neighbor for this cluster...
            for k in range(node_count):
                # Skip merged-away slots and the cluster itself...
                if sizes[k] == 0 or k == nn_current:
                    continue

                dist = dists[dist_index(nn_current, k)]
                # The first live candidate is always accepted when there is no
                # previous link. Otherwise a cluster whose distances are all
                # non-finite would be chosen as its own neighbor and end up
                # merged with itself...
                if nn_next == nn_current or dist < current_min:
                    current_min = dist
                    nn_next = k

            # If the nearest neighbor is the previous link, the top two chain
            # entries are mutual nearest neighbors and get merged...
            if chain_length >= 2 and nn_next == stack[chain_length - 2]:
                break

            stack[chain_length] = nn_next
            chain_length += 1

        # Pop the two merged nodes off the stack...
        chain_length -= 2

        # The smaller index always survives, keeping the merge encoding consistent...
        if nn_current > nn_next:
            nn_current, nn_next = nn_next, nn_current

        n1_size = sizes[nn_current]
        n2_size = sizes[nn_next]
        merged_size = n1_size + n2_size

        # Record this merge...
        nodes_and_merge[i, 0] = nn_current
        nodes_and_merge[i, 1] = nn_next
        nodes_and_merge[i, 2] = merged_size
        merge_distances[i] = current_min

        # The surviving slot takes the combined size, the other one is retired...
        sizes[nn_current] = merged_size
        sizes[nn_next] = 0

        # Update distances from every remaining cluster to the merged cluster...
        for other in range(node_count):
            other_size = sizes[other]
            if other_size == 0 or other == nn_current:
                continue

            dists[dist_index(nn_current, other)] = distance_update(
                linkage_mode,
                n1_size,
                n2_size,
                other_size,
                current_min,
                dists[dist_index(nn_current, other)],
                dists[dist_index(nn_next, other)]
            )

    # Reorder by distance, stably...
    idx_order = np.argsort(merge_distances, kind="mergesort")
    return nodes_and_merge[idx_order], merge_distances[idx_order]
# Union-find state is a plain integer array with two rows: row 0 holds each
# node's parent index and row 1 holds the set size recorded for each node.
UnionFindType = np.ndarray
@numba.njit("i8[:, :](i8)")
def _new_union_find(size: int) -> UnionFindType:
    """Create union-find state for ``size`` nodes, each in its own set.

    Returns:
        A ``(2, size)`` int64 array whose first row is ``0..size-1`` (every
        node is its own parent) and whose second row is all ones (set sizes).
    """
    table = np.empty((2, size), dtype=np.int64)
    for node in range(size):
        table[0, node] = node
        table[1, node] = 1
    return table
@numba.njit("i8(i8[:, :], i8)")
def _uf_compress(uf: UnionFindType, n: int) -> int:
    """Return the root of ``n`` and point every node on the way directly at it.

    Args:
        uf: Union-find state; row 0 (the parent indices) is modified in place.
        n: Node whose root is looked up.

    Returns:
        The index of the root of the set containing ``n``.
    """
    parents = uf[0]

    # First pass: follow parent links until a node that is its own parent...
    root = n
    while parents[root] != root:
        root = parents[root]

    # Second pass: re-link every node visited on the path straight to the root...
    node = n
    while node != root:
        following = parents[node]
        parents[node] = root
        node = following

    return root
@numba.njit("i8(i8[:, :], i8)")
def _uf_find(uf: UnionFindType, n: int):
    """Return the root of the set containing ``n``.

    Delegates to ``_uf_compress``, so the parent row of ``uf`` may be updated.
    """
    root = _uf_compress(uf, n)
    return root
@numba.njit("i8(i8[:, :], i8, i8)")
def _uf_union(uf: UnionFindType, n1: int, n2: int) -> int:
    """Join the sets containing ``n1`` and ``n2`` and return the resulting size.

    The root with the smaller recorded size is attached beneath the other
    root; when the sizes are equal the root of ``n1`` is kept. If both nodes
    already share a root, nothing changes and that root's size is returned.

    Args:
        uf: Union-find state, modified in place.
        n1: A node of the first set.
        n2: A node of the second set.

    Returns:
        The size of the set containing both nodes after the call.
    """
    parents = uf[0]
    sizes = uf[1]

    root_a = _uf_find(uf, n1)
    root_b = _uf_find(uf, n2)
    if root_a == root_b:
        return sizes[root_a]

    size_a = sizes[root_a]
    size_b = sizes[root_b]
    if size_a < size_b:
        keep, drop = root_b, root_a
    else:
        keep, drop = root_a, root_b

    total = size_a + size_b
    parents[drop] = keep
    # Both former roots record the combined size...
    sizes[keep] = total
    sizes[drop] = total
    return total
@numba.njit("types.Tuple((i8[:], i8))(i8[:, :], f8[:], i8)")
def get_components(merge_list: np.ndarray, distances: np.ndarray, num_components: int):
    """Label every node with a component after applying the leading merges.

    Args:
        merge_list: int64 array with three columns and at least one row; the
            first two columns of each row are the node indices joined by that
            merge. The number of nodes is taken to be the row count plus one.
        distances: 1-D float64 array with one entry per merge; only its length
            is checked here.
        num_components: Desired number of components, must be positive. Values
            above the number of nodes are reduced to the number of nodes.

    Returns:
        A tuple ``(labels, num_components)`` where ``labels[j]`` is the
        component index of node ``j`` (indices are handed out from 0 in order
        of first appearance) and ``num_components`` is the possibly reduced
        request.
    """
    assert merge_list.ndim == 2
    assert distances.ndim == 1
    assert merge_list.shape[1] == 3
    merge_rows = merge_list.shape[0]
    assert merge_rows >= 1
    assert merge_rows == distances.shape[0]

    size = merge_rows + 1

    assert 0 < num_components
    if num_components > size:
        num_components = size

    labels = np.full(size, -1, dtype=np.int64)
    uf = _new_union_find(size)

    # Apply merges, in the order given, until the requested count remains...
    for step in range(size - num_components):
        _uf_union(uf, merge_list[step, 0], merge_list[step, 1])

    # Walk the nodes in order. The first time a root is encountered it receives
    # the next free label, and every node takes the label of its root...
    next_label = 0
    for node in range(size):
        root = _uf_find(uf, node)
        label = labels[root]
        if label < 0:
            label = next_label
            labels[root] = label
            next_label += 1
        labels[node] = label

    return labels, num_components