Skip to content

Attention API

FlashAttention

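Flash Attention 2 implementation that automatically falls back to another attention backend when Flash Attention is unavailable (see `has_flash_attn` and `get_attention_backend` below).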

from legionheto.attention import FlashAttention

attn = FlashAttention(
    embed_dim=4096,
    num_heads=32,
    num_key_value_heads=8,
    dropout=0.0,
)

output, past_kv = attn(hidden_states, attention_mask)

DeepSeekMLA

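DeepSeek Multi-head Latent Attention (MLA). Queries and keys/values are projected through low-rank latent spaces whose sizes are set by `q_lora_rank` and `kv_lora_rank`. All settings are passed in through an `MLAConfig` object.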

from legionheto.attention import DeepSeekMLA, MLAConfig

config = MLAConfig(
    embed_dim=4096,
    num_heads=32,
    kv_lora_rank=512,
    q_lora_rank=1536,
)

mla = DeepSeekMLA(config)
output, past_kv = mla(hidden_states)

MemoryEfficientAttention

Attention computed in chunks, which reduces peak memory usage on long sequences.

from legionheto.attention import MemoryEfficientAttention

attn = MemoryEfficientAttention(
    embed_dim=4096,
    num_heads=32,
    max_seq_len=8192,
)

Utility Functions

has_flash_attn

Check whether Flash Attention is available.

from legionheto.attention import has_flash_attn

if has_flash_attn():
    print("Flash Attention available")

get_attention_backend

Return the name of the attention backend currently in use.

from legionheto.attention import get_attention_backend

backend = get_attention_backend()  # "flash_attn", "sdpa", or "eager"