KMeans Cluster

The KMeans implementation in stamox uses a brute-force method, so it will be slow on large datasets.
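
As a rough illustration of why a brute-force approach scales poorly, the sketch below assumes that "brute force" means computing the distance from every point to every center. The source does not spell this out, and `assign_clusters` is a hypothetical helper, not part of stamox.

```python
import jax.numpy as jnp


def assign_clusters(x, centers):
    """Hypothetical helper: label each point with its nearest center."""
    # Broadcasting builds an (n_points, n_clusters, n_features) array,
    # so time and memory both grow with n_points * n_clusters * n_features.
    diffs = x[:, None, :] - centers[None, :, :]
    sq_dists = jnp.sum(diffs ** 2, axis=-1)  # shape (n_points, n_clusters)
    return jnp.argmin(sq_dists, axis=1)      # shape (n_points,)


x = jnp.array([[0.0, 0.0], [0.1, -0.1], [5.0, 5.0], [5.2, 4.9]])
centers = jnp.array([[0.0, 0.0], [5.0, 5.0]])
labels = assign_clusters(x, centers)
```

Every iteration repeats this full pairwise computation, which is where the cost on large inputs would come from under this assumption.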

stamox.cluster.kmeans(x: ArrayLike, n_cluster: int, restarts: int = 10, max_iters: int = 100, dtype: dtype = <class 'jax.numpy.float32'>, *, key: Union[jax.Array, jax._src.prng.PRNGKeyArray] = None)

Runs the K-means clustering algorithm on a given dataset.
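
To make the roles of `restarts` and `max_iters` in the signature above more concrete, here is a minimal sketch of Lloyd's algorithm with random restarts. It is not stamox's implementation: the initialisation scheme (sampling distinct rows of `x`), the stopping rule, and the names `lloyd_once` and `kmeans_sketch` are all assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import random


def _assign(x, centers):
    # Nearest-center labels and the total within-cluster sum of squares.
    sq = jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(sq, axis=1), jnp.sum(jnp.min(sq, axis=1))


def lloyd_once(x, n_cluster, max_iters, key):
    # Assumed initialisation: pick n_cluster distinct rows of x as centers.
    idx = random.choice(key, x.shape[0], shape=(n_cluster,), replace=False)
    centers = x[idx]
    for _ in range(max_iters):
        labels, _ = _assign(x, centers)
        one_hot = (labels[:, None] == jnp.arange(n_cluster)[None, :]).astype(x.dtype)
        counts = one_hot.sum(axis=0)
        means = (one_hot.T @ x) / jnp.maximum(counts, 1)[:, None]
        # Keep the previous center if a cluster ends up empty.
        new_centers = jnp.where(counts[:, None] > 0, means, centers)
        if bool(jnp.allclose(new_centers, centers)):
            break  # centers stopped moving before max_iters was reached
        centers = new_centers
    labels, inertia = _assign(x, centers)
    return centers, labels, inertia


def kmeans_sketch(x, n_cluster, restarts=10, max_iters=100, *, key):
    # Run Lloyd's algorithm once per restart and keep the lowest-inertia run.
    best = None
    for subkey in random.split(key, restarts):
        result = lloyd_once(x, n_cluster, max_iters, subkey)
        if best is None or float(result[2]) < float(best[2]):
            best = result
    return best


x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
               [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
centers, labels, inertia = kmeans_sketch(
    x, n_cluster=2, restarts=5, max_iters=50, key=random.PRNGKey(0)
)
```

In this sketch, `max_iters` caps the work of a single run, while `restarts` trades extra runs for a lower chance of keeping a poor local optimum.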

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | ArrayLike | The dataset to be clustered. | required |
| n_cluster | int | The number of clusters to generate. | required |
| restarts | int | The number of restarts for the algorithm. Defaults to 10. | 10 |
| max_iters | int | The maximum number of iterations for the algorithm. Defaults to 100. | 100 |
| dtype | jnp.dtype | The data type of the output. Defaults to jnp.float32. | `<class 'jax.numpy.float32'>` |
| key | KeyArray | A JAX PRNG key used for random number generation. Defaults to None. | None |

Returns:

Type Description
KMeansState

An object containing the results of the clustering algorithm.

Examples:

>>> from jax import random
>>> from stamox.cluster import kmeans
>>> key = random.PRNGKey(0)
>>> x = random.normal(key, shape=(100, 2))
>>> state = kmeans(x, n_cluster=3, restarts=5, max_iters=50, key=key)
>>> state.centers
Array([[ 0.8450022 , -1.0791471 ],
       [-0.7179966 ,  0.6372063 ],
       [ 0.09818084, -0.25906876]], dtype=float32)
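
The example above passes the same `key` to both `random.normal` and `kmeans`. JAX generally recommends splitting a key so that each consumer gets its own independent random stream. A minimal variant is sketched below; the `kmeans` call is left as a comment because it requires stamox to be installed, and the names `data_key` and `kmeans_key` are illustrative only.

```python
from jax import random

key = random.PRNGKey(0)
# Derive two independent keys: one for the data, one for the clustering.
data_key, kmeans_key = random.split(key)

x = random.normal(data_key, shape=(100, 2))

# With stamox installed, the clustering call would then use the second key:
# state = kmeans(x, n_cluster=3, restarts=5, max_iters=50, key=kmeans_key)
```

Because the data are drawn with a different key here, the resulting centers would not match the values printed in the example above.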

stamox.cluster.KMeansState

KMeansState class for K-means clustering.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| n_clusters | int | Number of clusters. |
| centers | ArrayLike | Centers of the clusters. |
| cluster | ArrayLike | Cluster label for each point. |
| iters | int | Number of iterations. |
| totss | float | Total sum of squares. |
| betwss | float | Between-cluster sum of squares. |
| withinss | float | Within-cluster sum of squares. |
| tot_withinss | float | Total within-cluster sum of squares. |
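
The sum-of-squares attributes can be related through the usual decomposition `totss = betwss + tot_withinss`. The sketch below assumes these attributes follow the conventional definitions (the names resemble those in R's `kmeans` output); this page does not confirm that. `sum_of_squares` is a hypothetical helper, and it assumes every cluster contains at least one point.

```python
import jax.numpy as jnp


def sum_of_squares(x, labels, n_clusters):
    """Hypothetical helper: conventional K-means sum-of-squares quantities."""
    grand_mean = jnp.mean(x, axis=0)
    totss = jnp.sum((x - grand_mean) ** 2)

    # One-hot membership matrix, shape (n_points, n_clusters).
    one_hot = (labels[:, None] == jnp.arange(n_clusters)[None, :]).astype(x.dtype)
    counts = one_hot.sum(axis=0)
    # Cluster means; assumes no cluster is empty (counts > 0).
    centers = (one_hot.T @ x) / counts[:, None]

    # Squared distance of each point to its own cluster mean.
    sq_dev = jnp.sum((x - centers[labels]) ** 2, axis=1)
    withinss = one_hot.T @ sq_dev  # one value per cluster
    tot_withinss = withinss.sum()
    betwss = jnp.sum(counts * jnp.sum((centers - grand_mean) ** 2, axis=1))
    return totss, betwss, withinss, tot_withinss


x = jnp.array([[0.0, 0.0], [2.0, 0.0], [10.0, 0.0], [12.0, 0.0]])
labels = jnp.array([0, 0, 1, 1])
totss, betwss, withinss, tot_withinss = sum_of_squares(x, labels, 2)
# Here totss = 104, betwss = 100, tot_withinss = 4, so totss = betwss + tot_withinss.
```

A ratio such as `betwss / totss` close to 1 indicates that most of the variance is explained by the cluster assignment.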