KMeans Cluster
Because stamox implements K-means with a brute-force method, it can be slow on large datasets.
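Assuming "brute force" here means that every iteration computes the distance from every point to every center (with no spatial index or triangle-inequality pruning), the cost of the assignment step can be sketched in NumPy as below. `assign_brute_force` and `worst_case_distance_evals` are hypothetical helpers for illustration, not part of stamox.

```python
import numpy as np


def assign_brute_force(x, centers):
    """Label each point with the index of its nearest center.

    Every point is compared with every center, so this builds an (n, k, d)
    array of differences: time and memory both grow as n * k * d.
    """
    diffs = x[:, None, :] - centers[None, :, :]  # shape (n, k, d)
    sq_dists = np.sum(diffs ** 2, axis=-1)       # shape (n, k)
    return np.argmin(sq_dists, axis=1)           # shape (n,)


def worst_case_distance_evals(n, k, restarts=10, max_iters=100):
    """Point-to-center distances computed if every restart runs all iterations."""
    return restarts * max_iters * n * k


x = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 4.9]])
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
labels = assign_brute_force(x, centers)
print(labels.tolist())                               # [0, 0, 1, 1]
print(worst_case_distance_evals(n=1_000_000, k=10))  # 10000000000
```

Under this assumption, the defaults `restarts=10` and `max_iters=100` mean that one million points and ten clusters can require up to ten billion distance computations.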
stamox.cluster.kmeans(x: ArrayLike, n_cluster: int, restarts: int = 10, max_iters: int = 100, dtype: dtype = <class 'jax.numpy.float32'>, *, key: Union[jax.Array, jax._src.prng.PRNGKeyArray] = None)
Runs the K-means clustering algorithm on a given dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ArrayLike | The dataset to be clustered. | required |
| n_cluster | int | The number of clusters to generate. | required |
| restarts | int | The number of times the algorithm is restarted. | 10 |
| max_iters | int | The maximum number of iterations of the algorithm. | 100 |
| dtype | jnp.dtype | The data type of the output. | jnp.float32 |
| key | KeyArray | A JAX PRNG key that seeds the random parts of the algorithm, such as the choice of initial centers. | None |

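The restarts, max_iters, and key parameters fit the usual scheme of Lloyd's algorithm with random restarts. The NumPy sketch below illustrates that scheme under stated assumptions; it is not stamox's implementation. `lloyd_kmeans` is hypothetical: it uses an integer `seed` with NumPy's generator in place of a JAX PRNG key, initializes each restart from randomly chosen data points, stops when the centers stop moving, and keeps the restart with the smallest total within-cluster sum of squares (a common choice, assumed here).

```python
import numpy as np


def lloyd_kmeans(x, n_cluster, restarts=10, max_iters=100, seed=0):
    """Minimal K-means (Lloyd's algorithm) with random restarts.

    Hypothetical sketch, not stamox's implementation. Requires max_iters >= 1.
    Returns (centers, labels, iters, tot_withinss) for the best restart.
    """
    rng = np.random.default_rng(seed)  # stands in for the JAX PRNG `key`
    x = np.asarray(x, dtype=np.float64)
    best = None
    for _ in range(restarts):
        # Initialize with n_cluster distinct data points chosen at random.
        centers = x[rng.choice(len(x), size=n_cluster, replace=False)]
        for iters in range(1, max_iters + 1):
            # Assignment step: brute-force distances to every center.
            sq = np.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
            labels = np.argmin(sq, axis=1)
            # Update step: move each center to the mean of its points;
            # an empty cluster keeps its previous center.
            new_centers = np.array([
                x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                for j in range(n_cluster)
            ])
            converged = np.allclose(new_centers, centers)
            centers = new_centers
            if converged:
                break
        # Relabel against the final centers and score this restart.
        sq = np.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        labels = np.argmin(sq, axis=1)
        tot_withinss = float(np.min(sq, axis=1).sum())
        if best is None or tot_withinss < best[3]:
            best = (centers, labels, iters, tot_withinss)
    return best


x = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
centers, labels, iters, tot_withinss = lloyd_kmeans(x, n_cluster=2, restarts=5)
print(labels.tolist())          # first three points share one label, last three the other
print(round(tot_withinss, 4))   # 2.6667 (each blob contributes 4/3)
```

Restarts matter because Lloyd's algorithm only reaches a local minimum; running it from several random starts and keeping the best result reduces the chance of a poor clustering.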
Returns:
| Type | Description |
|---|---|
| KMeansState | An object containing the results of the clustering algorithm. |

Examples:
>>> from jax import random
>>> from stamox.functions import kmeans
>>> key = random.PRNGKey(0)
>>> x = random.normal(key, shape=(100, 2))
>>> state = kmeans(x, n_cluster=3, restarts=5, max_iters=50, key=key)
>>> state.centers
Array([[ 0.8450022 , -1.0791471 ],
       [-0.7179966 ,  0.6372063 ],
       [ 0.09818084, -0.25906876]], dtype=float32)

stamox.cluster.KMeansState
Holds the results of a K-means clustering run, as returned by kmeans.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_clusters | int | Number of clusters. |
| centers | ArrayLike | Centers of the clusters. |
| cluster | ArrayLike | Cluster label for each point. |
| iters | int | Number of iterations performed. |
| totss | float | Total sum of squares. |
| betwss | float | Between-cluster sum of squares. |
| withinss | float | Within-cluster sum of squares. |
| tot_withinss | float | Total within-cluster sum of squares. |
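The sums of squares in KMeansState follow the standard decomposition: when each center is the mean of its cluster, totss = tot_withinss + betwss. The NumPy sketch below computes these quantities from data and labels; `sum_of_squares` is a hypothetical helper, not a stamox function. It returns withinss per cluster (as R's kmeans does), whereas the attribute table lists withinss as a float, so stamox may report it differently.

```python
import numpy as np


def sum_of_squares(x, labels, n_clusters):
    """Total, within-cluster and between-cluster sums of squares.

    totss       = sum_i ||x_i - mean(x)||^2
    withinss[j] = sum_{i in cluster j} ||x_i - mu_j||^2
    betwss      = sum_j n_j * ||mu_j - mean(x)||^2
    With mu_j the mean of cluster j, totss == tot_withinss + betwss.
    """
    x = np.asarray(x, dtype=np.float64)
    labels = np.asarray(labels)
    grand_mean = x.mean(axis=0)
    totss = float(np.sum((x - grand_mean) ** 2))
    withinss = np.zeros(n_clusters)
    betwss = 0.0
    for j in range(n_clusters):
        members = x[labels == j]
        if len(members) == 0:
            continue  # an empty cluster contributes nothing
        mu = members.mean(axis=0)
        withinss[j] = np.sum((members - mu) ** 2)
        betwss += len(members) * float(np.sum((mu - grand_mean) ** 2))
    return totss, withinss, float(withinss.sum()), betwss


x = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
labels = np.array([0, 0, 0, 1, 1, 1])
totss, withinss, tot_withinss, betwss = sum_of_squares(x, labels, n_clusters=2)
print(round(tot_withinss, 4), round(betwss, 4), round(totss, 4))  # 2.6667 300.0 302.6667
```

A large betwss relative to totss indicates well-separated clusters; here almost all of the variation lies between the two groups.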