For decades, k-means has been treated as a one-shot, offline preprocessing step — run it once, store the results, and move forward. A collaborative team from UC Berkeley and UT Austin has introduced Flash-KMeans, an open-source library designed for a fundamentally different use case. In today’s AI workflows, k-means gets invoked repeatedly inside both training and inference loops. When that happens, the latency of each individual call becomes far more critical than raw floating-point throughput.
Flash-KMeans is an IO-aware re-implementation of classic Lloyd’s k-means. The underlying mathematics remain untouched, and no approximations are introduced. The gains come entirely from restructuring how data flows through GPU memory. On an NVIDIA H200, the team observed end-to-end speedups reaching 17.9× compared to the strongest existing baselines. Against NVIDIA cuML, the figure climbs to 33×. Against FAISS, it exceeds 200×.
What Is Flash-KMeans?
Flash-KMeans is a batched k-means library implemented as Triton GPU kernels. It is released under the Apache 2.0 license and can be installed with pip install flash-kmeans.
The results are mathematically indistinguishable from standard Lloyd’s k-means. The performance improvement stems from how data moves at the kernel level, not from skipping any computation. This sets it apart from algorithmic shortcuts such as triangle-inequality pruning or coreset-based sampling.
A typical Lloyd iteration consists of two phases. During the assignment phase, the distance from every data point to each centroid is computed, and the closest centroid is selected. During the update phase, the points belonging to each cluster are averaged to produce new centroids. Both phases involve straightforward arithmetic. On GPUs, however, the bottleneck in both cases is memory bandwidth rather than computation.
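For reference, one Lloyd iteration fits in a few lines of plain Python. This is a hedged CPU sketch of the math only, with hypothetical names (`sq_dist`, `lloyd_step`). It says nothing about GPU dataflow. Keeping an empty cluster's old centroid is our own assumption, not documented Flash-KMeans behavior.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd_step(points: List[Point], centroids: List[Point]) -> Tuple[List[int], List[Point]]:
    """One Lloyd iteration: assignment phase, then update phase."""
    # Assignment phase: index of the nearest centroid for every point.
    labels = [min(range(len(centroids)), key=lambda k: sq_dist(p, centroids[k]))
              for p in points]

    # Update phase: average the points assigned to each cluster.
    d = len(points[0])
    sums = [[0.0] * d for _ in centroids]
    counts = [0] * len(centroids)
    for p, k in zip(points, labels):
        counts[k] += 1
        for j in range(d):
            sums[k][j] += p[j]

    new_centroids: List[Point] = []
    for k, c in enumerate(centroids):
        if counts[k] == 0:
            new_centroids.append(c)  # assumption: an empty cluster keeps its old centroid
        else:
            new_centroids.append(tuple(s / counts[k] for s in sums[k]))
    return labels, new_centroids
```

For example, `lloyd_step([(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)], [(0.0, 0.0), (10.0, 10.0)])` returns labels `[0, 0, 1, 1]` and centroids `[(0.0, 0.5), (10.0, 10.5)]`.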
The Two Bottlenecks It Attacks
The first bottleneck lies in the assignment phase. Conventional implementations construct a full distance matrix D of dimensions N×K in High Bandwidth Memory (HBM). The matrix is written out and then read back to perform the argmin. For N=65536, K=1024, d=128 and a batch size of B=32, the distance calculations themselves take about 2.6 ms, while writing and subsequently reading D consumes roughly 23 ms. The matrix traffic, not the arithmetic, is the real cost.
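A quick back-of-the-envelope check shows why D dominates. The article does not say whether D is stored in FP16 or FP32, so this hypothetical helper (`tensor_bytes`) computes both. It counts bytes only and makes no claim about the measured timings.

```python
GIB = 2 ** 30


def tensor_bytes(*shape: int, bytes_per_elem: int) -> int:
    """Bytes occupied by a dense tensor with the given shape."""
    total = bytes_per_elem
    for dim in shape:
        total *= dim
    return total


B, N, K, d = 32, 65536, 1024, 128  # workload quoted in the article

x_fp16 = tensor_bytes(B, N, d, bytes_per_elem=2)  # input points X in FP16
d_fp16 = tensor_bytes(B, N, K, bytes_per_elem=2)  # distance matrix D in FP16
d_fp32 = tensor_bytes(B, N, K, bytes_per_elem=4)  # distance matrix D in FP32

# D is written once and read back once, so its HBM traffic is twice its size.
print(x_fp16 / GIB, d_fp16 / GIB, 2 * d_fp16 / GIB, 2 * d_fp32 / GIB)
# 0.5 4.0 8.0 16.0
```

Even in FP16, D is K/d = 8 times larger than the input X, so every byte of it that never reaches HBM is traffic saved.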
Flash-KMeans addresses this with FlashAssign, a design inspired by FlashAttention. FlashAssign streams tiles of points and centroids from HBM into on-chip SRAM, fusing distance computation with an online argmin. The full N×K matrix is never materialized in memory. This reduces the dominant IO complexity from O(NK) to O(Nd + Kd). At the kernel level, FlashAssign achieves up to 21.2× speedup. In one test, it slashed assignment time from 122.5ms to 5.8ms.
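The loop structure behind a fused, tiled assignment can be sketched in plain Python. This is a minimal CPU illustration under our own assumptions, not the Triton kernel: `assign_naive`, `assign_tiled`, `tile_n`, and `tile_k` are hypothetical names, and Python has no notion of SRAM, so the "tiles" only mark what a GPU block would stage on-chip. The tiled version keeps just a running best distance and index per point, so nothing N×K-sized is ever stored, yet it returns the same assignments as the materialize-then-argmin version.

```python
import math
from typing import List, Sequence


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def assign_naive(points, centroids) -> List[int]:
    # Materialize the full N x K distance matrix, then take a row-wise argmin.
    D = [[sq_dist(p, c) for c in centroids] for p in points]
    return [min(range(len(row)), key=row.__getitem__) for row in D]


def assign_tiled(points, centroids, tile_n: int = 2, tile_k: int = 2) -> List[int]:
    """Stream point/centroid tiles, keeping only a running (best_dist, best_idx) per point."""
    n, k_total = len(points), len(centroids)
    best_dist = [math.inf] * n
    best_idx = [-1] * n
    for i0 in range(0, n, tile_n):  # one point tile (think: one thread block)
        for k0 in range(0, k_total, tile_k):  # stream centroid tiles past it
            for i in range(i0, min(i0 + tile_n, n)):
                for k in range(k0, min(k0 + tile_k, k_total)):
                    dist = sq_dist(points[i], centroids[k])
                    # Online argmin: strict "<" keeps the lowest index on ties,
                    # matching min() in assign_naive.
                    if dist < best_dist[i]:
                        best_dist[i] = dist
                        best_idx[i] = k
    return best_idx
```

Memory for `assign_tiled` grows with N (two scalars per point) rather than with N×K, which is the IO argument behind the O(Nd + Kd) figure.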
The second bottleneck is the centroid update phase. Conventional approaches rely on scatter-style atomic additions. Each thread contributes its data point into a shared sum buffer indexed by cluster ID. When many threads simultaneously target the same “hot” cluster, atomic contention arises, causing hardware serialization. The team measured only 50 GB/s of effective bandwidth on an H200 during this phase.
Flash-KMeans replaces this with a technique called Sort-Inverse Update. It first argsorts the 1D assignment vector by cluster ID, so identical cluster IDs appear as contiguous segments. Each thread block reduces a segment on-chip and issues a single atomic add per segment instead of one per point. The large point matrix is never physically rearranged; only the small index vector is sorted. Atomic operations drop from O(N·d) to O((K + N/B)·d), where B here is the number of sorted points each thread block handles, not the batch size. At the kernel level, the update reaches up to 6.3× speedup.
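The contrast between scatter-style and sort-then-segment updates can be sketched on the CPU, counting each add into the shared buffer as one "atomic". The names (`scatter_update`, `sort_inverse_update`, `block`) are hypothetical and this is not the library's kernel. Note that the sorted version changes summation order, so with arbitrary floats the sums can differ in the last bits; the test uses integer-valued data to compare exactly.

```python
def scatter_update(points, labels, k):
    """Scatter-style baseline: one 'atomic' add per point per dimension."""
    d = len(points[0])
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    atomics = 0
    for p, c in zip(points, labels):
        for j in range(d):
            sums[c][j] += p[j]  # on a GPU: atomicAdd into a shared buffer
            atomics += 1
        counts[c] += 1
    return sums, counts, atomics


def sort_inverse_update(points, labels, k, block=4):
    """Argsort only the 1-D labels, then reduce contiguous segments per block."""
    d = len(points[0])
    order = sorted(range(len(labels)), key=labels.__getitem__)  # argsort of cluster IDs
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    atomics = 0
    for b0 in range(0, len(order), block):  # one "thread block" per chunk of sorted indices
        chunk = order[b0:b0 + block]
        start = 0
        while start < len(chunk):
            c = labels[chunk[start]]
            local = [0.0] * d  # on-chip partial sum for this segment
            end = start
            while end < len(chunk) and labels[chunk[end]] == c:
                p = points[chunk[end]]  # gather via the permutation; points never move
                for j in range(d):
                    local[j] += p[j]
                end += 1
            for j in range(d):
                sums[c][j] += local[j]  # one 'atomic' add per segment per dimension
            counts[c] += end - start
            atomics += d
            start = end
    return sums, counts, atomics
```

With 12 two-dimensional points spread over 3 clusters and `block=4`, the scatter version issues 24 adds while the sorted version issues 6, and both produce the same sums and counts. With a single hot cluster, the scatter version would funnel every one of its adds to the same addresses, which is where GPU contention comes from.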
Benchmark
The team evaluated performance on an H200 with CUDA 12.8, using FP16 data and d=128. They varied N, K, and batch size B across experiments. Four optimized baselines were used for comparison: fast_pytorch_kmeans, fastkmeans, cuML, and FAISS.
| Comparison | Reported speedup | Workload context |
|---|---|---|
| End-to-end vs best baseline | up to 17.9× | N=8M, K=1024 (large N, small K) |
| vs NVIDIA cuML | 33× | industry library |
| vs FAISS | over 200× | industry library |
| FlashAssign kernel | up to 21.2× | N=1M, K=8192 (assignment) |
| Sort-Inverse Update kernel | up to 6.3× | N=33M, K=4096 (update) |
| Out-of-core, large scale | up to 10.5× | N=400M, K=16384 vs fastkmeans |
One notable limitation of existing approaches is worth highlighting: standard PyTorch implementations run out of memory when K is large, because they cannot allocate the full N×K distance matrix. The FAISS comparison is significant because FAISS is the industry-standard library underlying many production vector-search systems.
The library also supports out-of-core execution. On a dataset of one billion points (K=32768, d=128), it completes a single iteration in 41.4 seconds, compared to 261.8 seconds for the baseline. It leverages chunked stream overlap to hide PCIe transfer latency behind active computation. A cache-aware compile heuristic further reduces tuning overhead by up to 175×, achieving within 0.3% of fully tuned performance.
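The chunked, overlapped pattern can be sketched with a background loader thread that prefetches the next chunk while the current one is processed. This is an assumption-laden illustration, not the library's implementation: `load_chunk` is a hypothetical stand-in that merely slices a list where a real system would issue a host-to-device copy on a separate CUDA stream, and Python's GIL makes the overlap structural rather than a real speedup. What it does show is that accumulating per-cluster sums chunk by chunk yields the same centroids as a single in-memory pass.

```python
from concurrent.futures import ThreadPoolExecutor


def load_chunk(data, start, size):
    """Hypothetical stand-in for a host-to-device copy of one chunk."""
    return data[start:start + size]


def accumulate(chunk, centroids, sums, counts):
    """Assign each point in the chunk and add it into per-cluster running sums."""
    for p in chunk:
        k = min(range(len(centroids)),
                key=lambda i: sum((a - b) ** 2 for a, b in zip(p, centroids[i])))
        counts[k] += 1
        for j, v in enumerate(p):
            sums[k][j] += v


def out_of_core_iteration(data, centroids, chunk_size):
    """One Lloyd iteration over data that is visited one chunk at a time."""
    d = len(centroids[0])
    sums = [[0.0] * d for _ in centroids]
    counts = [0] * len(centroids)
    with ThreadPoolExecutor(max_workers=1) as loader:
        pending = loader.submit(load_chunk, data, 0, chunk_size)
        for start in range(0, len(data), chunk_size):
            chunk = pending.result()  # wait for the current chunk
            nxt = start + chunk_size
            if nxt < len(data):
                # Prefetch the next chunk while this one is being processed.
                pending = loader.submit(load_chunk, data, nxt, chunk_size)
            accumulate(chunk, centroids, sums, counts)
    # Assumption: an empty cluster keeps its previous centroid.
    return [tuple(s / c for s in row) if c else cen
            for row, c, cen in zip(sums, counts, centroids)]
```

Because chunks are consumed in their original order, the floating-point sums match the single-pass result exactly, whatever `chunk_size` is.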
Use Cases
Accelerating exact k-means opens the door to real-time applications, not just batch processing.
- Vector search indexing: FAISS relies on k-means to build its search indices. With faster k-means, you can refresh indexes dynamically as your data evolves, rather than waiting for overnight rebuilds.
- Sparse attention routing: Models like Routing Transformers and Tactic use token clustering to direct attention. Millisecond-level k-means makes this feasible within the inference pipeline itself.
- KV-cache compression: ClusterKV groups tokens semantically to shrink the KV cache. More efficient clustering enables per-layer, per-step compression without significant overhead.
- Low-bit KV quantization: Modern techniques repeatedly cluster KV entries into codebooks. Speeding up clustering dramatically reduces preprocessing time.
- Diffusion Transformers: Sparse VideoGen2 applies batched k-means during forward passes. It reorders tokens based on semantic similarity to maximize sparsity benefits.
Using It
The API is designed to be familiar, matching conventions from faiss and sklearn. The example below clusters a batched tensor of shape (B, N, d).
```python
import torch
from flash_kmeans import batch_kmeans_Euclid

# A batch of 32 point sets, each with 75,600 points of dimension 128, on the GPU in FP16
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)
```

There's also a scikit-learn-compatible interface.

```python
import torch
from flash_kmeans import FlashKMeans

# Placeholder data: replace with your own (N, 128) tensor held in CPU memory
large_cpu_tensor = torch.randn(1_000_000, 128)

km = FlashKMeans(d=128, k=8192, niter=100)
labels = km.fit_predict(large_cpu_tensor)  # device=None uses all visible GPUs
```

The kernel automatically selects an execution path based on tensor shape and data type. A specialized small-D path handles dimensions up to 512. For higher dimensions, a split-D approach avoids creating the full distance matrix. When working with large datasets in CPU memory, multi-GPU execution kicks in automatically.
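The article does not describe the split-D path in detail. One plausible reading, sketched here purely as an assumption with hypothetical names (`sq_dist_split_d`, `nearest_centroid_split_d`, `d_chunk`), is that the feature dimension is processed in slices and each point-to-centroid distance is accumulated slice by slice. Only a running scalar per pair and a running best per point need to persist, never a full distance matrix.

```python
import math


def sq_dist_split_d(x, c, d_chunk):
    """Accumulate ||x - c||^2 over slices of the feature dimension."""
    total = 0.0
    for j0 in range(0, len(x), d_chunk):
        # Each slice is small enough to stage on-chip; only the running scalar persists.
        xs, cs = x[j0:j0 + d_chunk], c[j0:j0 + d_chunk]
        total += sum((a - b) ** 2 for a, b in zip(xs, cs))
    return total


def nearest_centroid_split_d(x, centroids, d_chunk=64):
    """Online argmin over centroids, using split-D distances."""
    best_k, best = -1, math.inf
    for k, c in enumerate(centroids):
        dist = sq_dist_split_d(x, c, d_chunk)
        if dist < best:
            best_k, best = k, dist
    return best_k
```

In real arithmetic the slice size does not change the result; in floating point it only changes summation order.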
Key Takeaways
- Flash-KMeans delivers exact results, not approximations — it uses the same Lloyd’s algorithm, accelerated entirely through optimized GPU dataflow.
- FlashAssign combines distance computation with online argmin, reducing assignment I/O from O(NK) to O(Nd+Kd) — achieving up to 21.2× improvement.
- Sort-Inverse Update organizes cluster IDs into contiguous segments, replacing scattered atomic operations — delivering up to 6.3× speedup.
- Achieves up to 17.9× end-to-end acceleration, 33× faster than cuML, and over 200× faster than FAISS on an H200 GPU.
- Supports out-of-core processing for datasets up to one billion points and reduces tuning time by as much as 175×.
Explore the Paper and Repo.



