Attention API
FlashAttention
Flash Attention 2 implementation with automatic fallback.
from legionheto.attention import FlashAttention
attn = FlashAttention(
    embed_dim=4096,
    num_heads=32,
    num_key_value_heads=8,
    dropout=0.0,
)
output, past_kv = attn(hidden_states, attention_mask)
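The constructor above takes fewer key/value heads than query heads (8 versus 32). Assuming this follows the usual grouped-query attention convention, which this page does not confirm, each key/value head would be shared by a group of consecutive query heads. The sketch below only works out that arithmetic; `kv_head_for_query_head` is a hypothetical helper, not part of `legionheto`.

```python
def kv_head_for_query_head(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the index of the KV head a given query head would read from.

    Assumes the common grouped-query layout: consecutive query heads
    share one KV head. This is an illustration, not library code.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per shared KV head
    return query_head // group_size


# Values taken from the FlashAttention example above.
embed_dim, num_heads, num_kv_heads = 4096, 32, 8
head_dim = embed_dim // num_heads        # width of each attention head
group_size = num_heads // num_kv_heads   # query heads per KV head
mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
```

Under that assumption, the example configuration gives 128-dimensional heads, and the key/value cache holds 8 heads instead of 32.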
DeepSeekMLA
DeepSeek Multi-head Latent Attention.
from legionheto.attention import DeepSeekMLA, MLAConfig
config = MLAConfig(
    embed_dim=4096,
    num_heads=32,
    kv_lora_rank=512,
    q_lora_rank=1536,
)
mla = DeepSeekMLA(config)
output, past_kv = mla(hidden_states)
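To give a feel for what `kv_lora_rank` buys, the sketch below compares per-token cache sizes for standard multi-head attention and for a latent cache of rank `kv_lora_rank`. It assumes the cache stores only the compressed latent vector. `rope_dim` is a hypothetical parameter standing in for the small positional key that the DeepSeek design also caches; the corresponding `MLAConfig` field, if any, is not shown on this page.

```python
def kv_cache_elements_per_token(embed_dim: int, num_heads: int,
                                kv_lora_rank: int, rope_dim: int = 0):
    """Return (standard, latent) cache sizes in elements, per token and per layer.

    standard: full keys and values for every head.
    latent:   one compressed KV vector plus an optional positional key
              (rope_dim is an assumption, not a documented config field).
    """
    head_dim = embed_dim // num_heads
    standard = 2 * num_heads * head_dim  # keys + values across all heads
    latent = kv_lora_rank + rope_dim
    return standard, latent


# Values taken from the MLAConfig example above.
standard, latent = kv_cache_elements_per_token(4096, 32, 512)
ratio = standard / latent
```

With the example values this is 8192 versus 512 elements per token, a 16x reduction before any positional component is added.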
MemoryEfficientAttention
Chunked attention for long sequences.
from legionheto.attention import MemoryEfficientAttention
attn = MemoryEfficientAttention(
    embed_dim=4096,
    num_heads=32,
    max_seq_len=8192,
)
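This page does not say how `MemoryEfficientAttention` chunks its work. One common approach, sketched here in plain Python for a single query vector, walks over the keys in fixed-size chunks and keeps a running softmax maximum, denominator, and weighted sum, so the full score row never has to be held at once. The function names are illustrative only and are not the library's API.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) weighted sum of values, one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def chunked_attention(q, keys, values, chunk_size):
    """Same result as naive_attention, computed one key chunk at a time."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")  # largest score seen so far
    denom = 0.0                  # softmax denominator, relative to running_max
    acc = [0.0] * dim            # unnormalised output, relative to running_max
    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial sums to the new maximum (0.0 on the first chunk).
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        denom = denom * correction + sum(weights)
        acc = [a * correction + sum(w * v[i] for w, v in zip(weights, v_chunk))
               for i, a in enumerate(acc)]
        running_max = new_max
    return [a / denom for a in acc]
```

Only `chunk_size` scores are live at any time, which is what keeps memory flat as the sequence grows; the rescaling step keeps the result equal to the unchunked softmax up to floating-point rounding.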
Utility Functions
has_flash_attn
Check whether Flash Attention is available.
from legionheto.attention import has_flash_attn
if has_flash_attn():
    print("Flash Attention available")
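How `has_flash_attn` decides is not documented here. A minimal sketch of one way such a check could work is to ask Python whether the `flash_attn` package is importable, without actually importing it. `module_available` is a hypothetical helper, not part of `legionheto`.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be found on the import path.

    This only checks that the package is installed; it does not verify
    that the GPU or driver can actually run its kernels.
    """
    return importlib.util.find_spec(name) is not None


# flash_attn is the import name of the flash-attn package.
flash_available = module_available("flash_attn")
```

A real availability check would likely also need to confirm a compatible GPU and data type, which this sketch leaves out.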
get_attention_backend
Get the current attention backend.