KMeans Cluster
Because the KMeans implementation in stamox uses a brute-force method, it will be slow on large datasets.
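To illustrate why a brute-force approach gets slow, here is a minimal sketch of a brute-force assignment step. It is written in NumPy rather than JAX, the name `assign_brute_force` is hypothetical, and it is not taken from the stamox source. It only shows the general idea: every point is compared against every center, so the work per iteration grows with the number of points times the number of clusters times the number of features.

```python
import numpy as np


def assign_brute_force(x, centers):
    """Label each point with the index of its nearest center.

    Hypothetical helper: builds the full (n_points, n_clusters) matrix of
    squared Euclidean distances, which is what makes the approach costly.
    """
    diff = x[:, None, :] - centers[None, :, :]  # shape (n_points, n_clusters, n_features)
    sq_dist = np.sum(diff ** 2, axis=-1)        # shape (n_points, n_clusters)
    return np.argmin(sq_dist, axis=1)           # nearest center per point


x = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 4.9]])
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
labels = assign_brute_force(x, centers)
```

For 100,000 points and 50 clusters, the distance matrix alone holds five million entries, and it is rebuilt on every iteration of every restart.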
stamox.cluster.kmeans(x: ArrayLike, n_cluster: int, restarts: int = 10, max_iters: int = 100, dtype: dtype = <class 'jax.numpy.float32'>, *, key: Union[jax.Array, jax._src.prng.PRNGKeyArray] = None)
Runs the K-means clustering algorithm on a given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | ArrayLike | The dataset to be clustered. | required |
n_cluster | int | The number of clusters to generate. | required |
restarts | int | The number of restarts for the algorithm. Defaults to 10. | 10 |
max_iters | int | The maximum number of iterations for the algorithm. Defaults to 100. | 100 |
dtype | jnp.dtype | The data type of the output. Defaults to jnp.float32. | jnp.float32 |
key | KeyArray | A JAX PRNG key that seeds the algorithm's random number generation. Defaults to None. | None |
Returns:
Type | Description |
---|---|
KMeansState | An object containing the results of the clustering algorithm. |
Examples:
>>> from jax import random
>>> from stamox.functions import kmeans
>>> key = random.PRNGKey(0)
>>> x = random.normal(key, shape=(100, 2))
>>> state = kmeans(x, n_cluster=3, restarts=5, max_iters=50, key=key)
>>> state.centers
Array([[ 0.8450022 , -1.0791471 ],
       [-0.7179966 ,  0.6372063 ],
       [ 0.09818084, -0.25906876]], dtype=float32)
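The roles of `restarts` and `max_iters` can be pictured with the sketch below. It assumes the common technique of running Lloyd's algorithm from several random initializations and keeping the run with the lowest total within-cluster sum of squares. Whether stamox does exactly this is not confirmed by this page. The sketch uses NumPy with an integer seed in place of a JAX key, and `lloyd` and `kmeans_with_restarts` are hypothetical names.

```python
import numpy as np


def lloyd(x, n_cluster, max_iters, rng):
    """One k-means run (Lloyd's algorithm) from a random initialization."""
    # Start from n_cluster distinct data points chosen at random.
    centers = x[rng.choice(len(x), size=n_cluster, replace=False)]
    for _ in range(max_iters):
        sq_dist = np.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        labels = np.argmin(sq_dist, axis=1)
        # Move each center to the mean of its points; keep it if it has none.
        new_centers = np.array([
            x[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_cluster)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Recompute labels and cost against the final centers.
    sq_dist = np.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    labels = np.argmin(sq_dist, axis=1)
    cost = float(np.sum(np.min(sq_dist, axis=1)))
    return centers, labels, cost


def kmeans_with_restarts(x, n_cluster, restarts=10, max_iters=100, seed=0):
    """Run `lloyd` several times and keep the lowest-cost result."""
    rng = np.random.default_rng(seed)
    best = None
    for _ in range(restarts):
        result = lloyd(x, n_cluster, max_iters, rng)
        if best is None or result[2] < best[2]:
            best = result
    return best


# Two well-separated square blobs of four points each.
x = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0], [11.0, 11.0]])
centers, labels, cost = kmeans_with_restarts(x, n_cluster=2, restarts=5, max_iters=50)
```

More restarts reduce the chance of ending in a poor local optimum, and the running time grows roughly linearly with the number of restarts.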
stamox.cluster.KMeansState
KMeansState class for K-means clustering.
Attributes:
Name | Type | Description |
---|---|---|
n_clusters | int | Number of clusters. |
centers | ArrayLike | Centers of the clusters. |
cluster | ArrayLike | Cluster labels for each point. |
iters | int | Number of iterations. |
totss | float | Total sum of squares. |
betwss | float | Between-cluster sum of squares. |
withinss | float | Within-cluster sum of squares. |
tot_withinss | float | Total within-cluster sum of squares. |
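The sum-of-squares attributes follow the usual decomposition of variance, which the sketch below works through in NumPy. It assumes the conventional definitions (as in R's `kmeans`): `withinss` holds one value per cluster, `tot_withinss` is their sum, and each center is the mean of its cluster. The function name `sums_of_squares` is hypothetical, and the code is not taken from stamox.

```python
import numpy as np


def sums_of_squares(x, labels, n_clusters):
    """Return (totss, betwss, withinss, tot_withinss) for a given labelling."""
    grand_mean = x.mean(axis=0)
    # Total sum of squares: spread of all points around the grand mean.
    totss = float(np.sum((x - grand_mean) ** 2))
    withinss = np.zeros(n_clusters)
    betwss = 0.0
    for k in range(n_clusters):
        members = x[labels == k]
        if len(members) == 0:
            continue  # an empty cluster contributes nothing
        center = members.mean(axis=0)
        # Spread of the cluster's points around its own center.
        withinss[k] = np.sum((members - center) ** 2)
        # Spread of the center around the grand mean, weighted by cluster size.
        betwss += len(members) * float(np.sum((center - grand_mean) ** 2))
    return totss, betwss, withinss, float(withinss.sum())


x = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
labels = np.array([0, 0, 1, 1])
totss, betwss, withinss, tot_withinss = sums_of_squares(x, labels, n_clusters=2)
```

Under these assumptions `totss` equals `betwss + tot_withinss` (here 20 = 16 + 4), so the ratio `betwss / totss` indicates how much of the total variation the clustering explains.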