
FlashInfer

Kernel Library for LLM Serving

| Blog | Documentation | Slack | Discussion Forum |


FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, SparseAttention, PagedAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.

Check out our v0.2 release blog post for new features!

The core features of FlashInfer include:

  1. Efficient Sparse/Dense Attention Kernels: Efficient single/batch attention for sparse (paged) and dense KV-storage on CUDA Cores and Tensor Cores (both FA2 & FA3 templates). The vector-sparse attention can achieve 90% of the bandwidth of dense kernels with the same problem size.
  2. Load-Balanced Scheduling: FlashInfer decouples the plan and run stages of attention computation. The computation of variable-length inputs is scheduled in the plan stage to alleviate load imbalance.
  3. Memory Efficiency: FlashInfer offers Cascade Attention for hierarchical KV-Cache, implements Head-Query fusion to accelerate Grouped-Query Attention, and provides efficient kernels for low-precision attention and fused-RoPE attention for compressed KV-Cache.
  4. Customizable Attention: Bring your own attention variants through JIT-compilation.
  5. CUDAGraph and torch.compile Compatibility: FlashInfer kernels can be captured by CUDAGraphs and torch.compile for low-latency inference.
  6. Efficient LLM-specific Operators: High-performance fused kernels for Top-P, Top-K, and Min-P sampling without the need for sorting.
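
Item 6 mentions sampling without sorting. One way to achieve that is pivot-based rejection sampling, and the following is a minimal CPU sketch of that idea in plain Python. It is an illustration only, not FlashInfer's GPU kernel: the function name `top_p_sample` and its signature are hypothetical, and `probs` is assumed to be a list of non-negative probabilities that sums to 1.

```python
import random


def top_p_sample(probs, top_p, rng=random):
    """Draw one token id from the top-p nucleus of `probs` without sorting.

    Hypothetical helper for illustration; not part of the FlashInfer API.
    """
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1]")
    pivot = 0.0  # tokens with probability <= pivot are excluded from sampling
    while True:
        # Inverse-transform sampling restricted to tokens above the pivot.
        candidates = [(i, p) for i, p in enumerate(probs) if p > pivot]
        u = rng.random() * sum(p for _, p in candidates)
        acc = 0.0
        choice = candidates[-1][0]  # fallback guards against rounding error
        for i, p in candidates:
            acc += p
            if u < acc:
                choice = i
                break
        # Accept when the tokens strictly more likely than `choice` hold
        # less than top_p of the mass, i.e. `choice` lies inside the nucleus.
        mass_above = sum(p for p in probs if p > probs[choice])
        if mass_above < top_p:
            return choice
        # Otherwise reject and raise the pivot, which shrinks the candidates.
        pivot = probs[choice]


token = top_p_sample([0.5, 0.3, 0.1, 0.05, 0.05], top_p=0.7)
```

Each pass needs only masked sums over the vocabulary, which map naturally onto parallel reductions, whereas a sort-based implementation has to order the whole vocabulary first.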

FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily integrated into existing projects.

News

  • [Mar 10, 2025] Blog post: Sorting-Free GPU Kernels for LLM Sampling, which explains the design of sampling kernels in FlashInfer.
  • [Mar 1, 2025] Check out FlashInfer's intra-kernel profiler for visualizing the timeline of each threadblock in GPU kernels.
  • [Dec 16, 2024] Blog post: FlashInfer 0.2 - Efficient and Customizable Kernels for LLM Inference Serving
  • [Sept 2024] We've launched a Slack workspace for FlashInfer users and developers. Join us for timely support, discussions, updates, and knowledge sharing!
  • [Jan 31, 2024] Blog post: Cascade Inference: Memory-Efficient Shared Prefix Batch Decoding
  • [Jan 31, 2024] Blog post: Accelerating Self-Attentions for LLM Serving with FlashInfer

Getting Started

Using our PyTorch API is the easiest way to get started:

Install from PyPI

FlashInfer is available as a Python package for Linux. Install the core package with:

pip install flashinfer-python

Package Options:

  • flashinfer-python: Core package that compiles/downloads kernels on first use
  • flashinfer-cubin: Pre-compiled kernel binaries for all supported GPU architectures
  • flashinfer-jit-cache: Pre-built kernel cache for specific CUDA versions

For faster initialization and offline usage, install the optional packages to have most kernels pre-compiled:

pip install flashinfer-python flashinfer-cubin
# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130)
pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/cu129

This eliminates compilation and downloading overhead at runtime.
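
The index URL for flashinfer-jit-cache depends on the CUDA version. As a small sketch, the tag can be derived from a version string such as the one reported by `torch.version.cuda` (for example "12.9"). The helper name `jit_cache_index_url` is hypothetical, and the list of accepted tags simply mirrors the three versions named above.

```python
# Tags named in the install instructions above; other versions are not assumed to exist.
SUPPORTED_TAGS = ("cu128", "cu129", "cu130")


def jit_cache_index_url(cuda_version, base="https://flashinfer.ai/whl"):
    """Map a CUDA version string like "12.9" to the matching wheel index URL.

    Hypothetical helper for illustration; not part of FlashInfer.
    """
    major, minor = cuda_version.split(".")[:2]
    tag = f"cu{major}{minor}"
    if tag not in SUPPORTED_TAGS:
        raise ValueError(f"no jit-cache index listed for {tag}")
    return f"{base}/{tag}"


# With PyTorch installed, the argument could come from torch.version.cuda.
print(jit_cache_index_url("12.9"))
```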

Install from Source

Build the core package from source:

git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer
python -m pip install -v .

For development, install in editable mode:

python -m pip install --no-build-isolation -e . -v

Build optional packages:

flashinfer-cubin:

cd flashinfer-cubin
python -m build --no-isolation --wheel
python -m pip install dist/*.whl

flashinfer-jit-cache (customize FLASHINFER_CUDA_ARCH_LIST for your target GPUs):

export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
cd flashinfer-jit-cache
python -m build --no-isolation --wheel
python -m pip install dist/*.whl

For more details, see the Install from Source documentation.
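
When customizing FLASHINFER_CUDA_ARCH_LIST, the value is a space-separated list of compute capabilities, some of which carry an "a" suffix in the example above. The sketch below formats such a string from (major, minor) pairs, such as those returned by `torch.cuda.get_device_capability()`. The helper `format_arch_list` is hypothetical, and which architectures need the "a" suffix is left to the caller rather than guessed.

```python
def format_arch_list(capabilities, arch_specific=()):
    """Build a FLASHINFER_CUDA_ARCH_LIST value from (major, minor) pairs.

    `arch_specific` lists the pairs that should receive the "a" suffix.
    Hypothetical helper for illustration; not part of FlashInfer.
    """
    entries = []
    for major, minor in sorted(set(capabilities)):  # de-duplicate and order
        suffix = "a" if (major, minor) in arch_specific else ""
        entries.append(f"{major}.{minor}{suffix}")
    return " ".join(entries)


# Reproduces part of the example value shown above.
print(format_arch_list([(8, 0), (7, 5), (10, 0)], arch_specific={(10, 0)}))
```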

Install Nightly Build

Nightly builds are available for testing the latest features:

# Core and cubin packages
pip install -U --pre flashinfer-python --index-url https://flashinfer.ai/whl/nightly/ --no-deps # Install the nightly package from custom index, without installing dependencies
pip install flashinfer-python  # Install flashinfer-python's dependencies from PyPI
pip install -U --pre flashinfer-cubin --index-url https://flashinfer.ai/whl/nightly/
# JIT cache package (replace cu129 with your CUDA version: cu128, cu129, or cu130)
pip install -U --pre flashinfer-jit-cache --index-url https://flashinfer.ai/whl/nightly/cu129

Verify Installation

After installation, verify that FlashInfer is correctly installed and configured:

flashinfer show-config

This command displays:

  • FlashInfer version and installed packages (flashinfer-python, flashinfer-cubin, flashinfer-jit-cache)
  • PyTorch and CUDA version information
  • Environment variables and artifact paths
  • Downloaded cubin status and module compilation status
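
As a complement to the command above, the installed package versions can also be inspected from Python with the standard library's `importlib.metadata`. This is a minimal sketch; the function name `installed_versions` is hypothetical, and it only reports what pip has recorded, not whether kernels are compiled or downloaded.

```python
from importlib.metadata import PackageNotFoundError, version

# The three distributions named in the Package Options list.
PACKAGES = ("flashinfer-python", "flashinfer-cubin", "flashinfer-jit-cache")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```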

Trying it out

Below is a minimal example of using FlashInfer's single-request decode/append/prefill attention kernels:

import torch
import flashinfer

kv_len = 2048
num_kv_heads = 32
head_dim = 128

k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# decode attention

num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0) # prefill attention
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask
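
To make clear what the decode call computes, here is a slow reference sketch in plain Python using nested lists. It assumes the same layout as the example above (q is [num_qo_heads][head_dim], k and v are [kv_len][num_kv_heads][head_dim]), a softmax scale of 1/sqrt(head_dim), no positional encoding, and that query head h reads KV head h // (num_qo_heads // num_kv_heads) when the head counts differ. The name `reference_decode` is hypothetical and useful only for checking small cases by hand.

```python
import math


def reference_decode(q, k, v):
    """Single-request decode attention: softmax(q . K^T / sqrt(d)) . V per head.

    Hypothetical reference for illustration; not a FlashInfer function.
    """
    num_qo_heads, head_dim = len(q), len(q[0])
    num_kv_heads = len(k[0])
    group = num_qo_heads // num_kv_heads  # query heads sharing one KV head
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for h in range(num_qo_heads):
        kv_h = h // group
        # Scaled dot product of the query with every cached key.
        scores = [
            scale * sum(a * b for a, b in zip(q[h], k[t][kv_h]))
            for t in range(len(k))
        ]
        # Numerically stable softmax over the KV positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        # Weighted average of the cached values.
        out.append([
            sum(w * v[t][kv_h][d] for t, w in enumerate(weights)) / denom
            for d in range(head_dim)
        ])
    return out
```

With all-zero keys the softmax weights are uniform, so each output head is the mean of the cached values, which makes a convenient sanity check.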

Check out the documentation for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels.

Custom Attention Variants

Starting from FlashInfer v0.2, users can customize their own attention variants with additional parameters. For more details, refer to our JIT examples.

GPU Support

FlashInfer currently supports NVIDIA SM architectures 75 and higher, with beta support for 103, 110, 120, and 121.
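
The support statement above can be turned into a quick check against a device's compute capability, for instance the (major, minor) pair returned by `torch.cuda.get_device_capability()`. The sketch below encodes one reading of that statement, in which the four listed architectures count as beta and everything else at SM 75 or above counts as supported. The function `support_level` is hypothetical.

```python
MIN_SM = 75                      # lowest supported SM architecture per the text
BETA_SM = {103, 110, 120, 121}   # architectures listed as beta per the text


def support_level(major, minor):
    """Classify a compute capability as "supported", "beta", or "unsupported".

    Hypothetical helper for illustration; not part of FlashInfer.
    """
    sm = major * 10 + minor
    if sm in BETA_SM:
        return "beta"
    if sm >= MIN_SM:
        return "supported"
    return "unsupported"


print(support_level(8, 0))
```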

Adoption

We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects.

Acknowledgement

FlashInfer is inspired by the FlashAttention 1&2, vLLM, stream-K, CUTLASS, and AITemplate projects.

Citation

If you find FlashInfer helpful in your project or research, please consider citing our paper:

@article{ye2025flashinfer,
    title = {FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving},
    author = {
      Ye, Zihao and
      Chen, Lequn and
      Lai, Ruihang and
      Lin, Wuwei and
      Zhang, Yineng and
      Wang, Stephanie and
      Chen, Tianqi and
      Kasikci, Baris and
      Grover, Vinod and
      Krishnamurthy, Arvind and
      Ceze, Luis
    },
    journal = {arXiv preprint arXiv:2501.01005},
    year = {2025},
    url = {https://arxiv.org/abs/2501.01005}
}