flashkda-delta-attention
High-performance Kimi Delta Attention CUDA kernels built on CUTLASS for efficient recurrent state inference
Copy the command below and paste it into a terminal (Mac/Linux) or PowerShell (Windows). It downloads, extracts, and installs the skill automatically.
mkdir -p ~/.claude/skills && cd ~/.claude/skills && curl -L -o flashkda-delta-attention.zip https://jpskill.com/download/22945.zip && unzip -o flashkda-delta-attention.zip && rm flashkda-delta-attention.zip
$d = "$env:USERPROFILE\.claude\skills"; ni -Force -ItemType Directory $d | Out-Null; iwr https://jpskill.com/download/22945.zip -OutFile "$d\flashkda-delta-attention.zip"; Expand-Archive "$d\flashkda-delta-attention.zip" -DestinationPath $d -Force; ri "$d\flashkda-delta-attention.zip"
When it finishes, restart Claude Code. After that, simply make a relevant request in plain language and the skill activates automatically.
💾 Manual download (if the commands are difficult)
- 1. Press the blue button below to download flashkda-delta-attention.zip
- 2. Double-click the ZIP file to extract it; this creates a flashkda-delta-attention folder
- 3. Move that folder to C:\Users\<your name>\.claude\skills\ (Windows) or ~/.claude/skills/ (Mac)
- 4. Restart Claude Code
⚠️ Download and use at your own risk. This site accepts no responsibility for the content, behavior, or safety of the skill.
🎯 What this Skill can do
Read the description below to see what this Skill does for you. It activates automatically when you ask Claude for something in this area.
📦 How to install (3 steps)
- 1. Press the "Download" button above to get the .skill file
- 2. Change the file extension from .skill to .zip and extract it (macOS can extract it automatically)
- 3. Put the extracted folder into .claude/skills/ in your home folder
  - macOS / Linux: ~/.claude/skills/
  - Windows: %USERPROFILE%\.claude\skills\
Restart Claude Code to finish. The skill is invoked automatically for related requests, even if you do not say "use this Skill".
- Last updated: 2026-05-18
- Retrieved: 2026-05-18
- Bundled files: 1
📖 Original SKILL.md as read by Claude
The text below is the original (English or Chinese) that the AI (Claude) reads. Japanese translations are being added over time.
FlashKDA Delta Attention Skill
Skill by ara.so — Daily 2026 Skills collection.
FlashKDA provides high-performance CUDA kernels for Kimi Delta Attention (KDA) built on CUTLASS. It targets SM90+ GPUs (H100/H20 class) and integrates as a drop-in backend for flash-linear-attention's chunk_kda operation.
Requirements
- GPU: SM90+ (H100, H20, or newer)
- CUDA 12.9+
- PyTorch 2.4+
- Python 3.8+
Installation
git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda
cd flash-kda
git submodule update --init --recursive
pip install -v .
Install the FLA integration (optional but recommended):
pip install -U flash-linear-attention # >= 0.5.0
Core Kernel API
flash_kda.fwd
The primary low-level kernel call:
import torch
import flash_kda
flash_kda.fwd(
q, k, v, g, beta, scale, out,
A_log, dt_bias, lower_bound,
initial_state=None,
final_state=None,
cu_seqlens=None
)
Tensor shapes and dtypes:
| Parameter | Dtype | Shape | Notes |
|---|---|---|---|
| q | bf16 | [B, T, H, K] | Query; K must be 128 |
| k | bf16 | [B, T, H, K] | Key; K must be 128 |
| v | bf16 | [B, T, H, V] | Value; V must be 128 |
| g | bf16 | [B, T, H, K] | Gate logits (sigmoid/activation applied internally) |
| beta | bf16 | [B, T, H] | Beta logits (sigmoid applied internally) |
| scale | float | scalar | Attention scale factor |
| out | bf16 | [B, T, H, V] | Pre-allocated output tensor |
| A_log | fp32 | [H] | Per-head log-gate parameter |
| dt_bias | fp32 | [H, K] | Per-head gate bias |
| lower_bound | float | scalar | Gate lower bound, range [-5.0, 0] |
| initial_state | bf16/fp32/None | [B, H, V, K] or [N, H, V, K] | Optional initial recurrent state |
| final_state | bf16/fp32/None | [B, H, V, K] or [N, H, V, K] | Optional output final state |
| cu_seqlens | int64 | [N+1] | Optional cumulative sequence lengths for varlen mode |
Constraints:
- K == V == 128 required
- When cu_seqlens is provided, B must be 1 and T is the total number of tokens across all sequences
- initial_state and final_state dtypes must match when both are provided
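The constraints above can be checked on the host before the kernel is launched. The following is a minimal sketch and not part of FlashKDA: `check_kda_shapes` is a hypothetical helper that takes plain shape tuples (for example `tuple(q.shape)`), so it runs without a GPU. The rule that `cu_seqlens` starts at 0 and ends at T is an assumption based on the usual cumulative-length convention.

```python
from typing import List, Optional, Sequence, Tuple

HEAD_DIM = 128  # K and V must both equal 128


def check_kda_shapes(
    q_shape: Tuple[int, ...],
    k_shape: Tuple[int, ...],
    v_shape: Tuple[int, ...],
    cu_seqlens: Optional[Sequence[int]] = None,
) -> List[str]:
    """Return constraint violations as messages; an empty list means OK."""
    if not all(len(s) == 4 for s in (q_shape, k_shape, v_shape)):
        return ["q, k and v must be 4-D: [B, T, H, K] or [B, T, H, V]"]

    problems: List[str] = []
    if q_shape != k_shape:
        problems.append(f"q {q_shape} and k {k_shape} must share one shape")
    if q_shape[:3] != v_shape[:3]:
        problems.append("v must match q in B, T and H")
    if q_shape[-1] != HEAD_DIM:
        problems.append(f"K must be {HEAD_DIM}, got {q_shape[-1]}")
    if v_shape[-1] != HEAD_DIM:
        problems.append(f"V must be {HEAD_DIM}, got {v_shape[-1]}")

    if cu_seqlens is not None:
        batch, total_tokens = q_shape[0], q_shape[1]
        if batch != 1:
            problems.append(f"B must be 1 with cu_seqlens, got {batch}")
        # Assumed convention: N+1 entries, first is 0, last is T.
        if len(cu_seqlens) < 2 or cu_seqlens[0] != 0:
            problems.append("cu_seqlens must have N+1 entries and start at 0")
        elif cu_seqlens[-1] != total_tokens:
            problems.append("cu_seqlens[-1] must equal T (total tokens)")
    return problems


print(check_kda_shapes((1, 512, 8, 64), (1, 512, 8, 64), (1, 512, 8, 128)))
# ['K must be 128, got 64']
```

Returning a list of messages rather than raising on the first failure makes it easier to report every problem at once; raising an exception instead would be an equally reasonable choice.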
Usage via flash-linear-attention Backend (Recommended)
FlashKDA auto-dispatches from FLA's chunk_kda when installed:
import torch
import logging
from fla.ops.kda import chunk_kda
# Optional: see dispatch decisions
logging.basicConfig(level=logging.INFO)
B, T, H, K, V = 2, 2048, 16, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
g = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(B, T, H, dtype=torch.bfloat16, device='cuda')
A_log = torch.randn(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
h0 = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
scale = K ** -0.5
lower_bound = -5.0
with torch.inference_mode():
    out, final_state = chunk_kda(
        q=q, k=k, v=v, g=g, beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        use_gate_in_kernel=True,
        use_qk_l2norm_in_kernel=True,
        use_beta_sigmoid_in_kernel=True,
        safe_gate=True,
        A_log=A_log,
        dt_bias=dt_bias,
        lower_bound=lower_bound,
        transpose_state_layout=True,
    )
# out: [B, T, H, V], final_state: [B, H, V, K]
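Because the recurrent state has a fixed shape of [B, H, V, K] regardless of T, its memory cost is easy to estimate up front. The sketch below is plain arithmetic, not a FlashKDA API; `state_bytes` is a hypothetical helper, and the byte sizes are the standard widths of bf16 and fp32.

```python
# Bytes per element for the two state dtypes listed in the parameter table.
DTYPE_BYTES = {"bf16": 2, "fp32": 4}


def state_bytes(n_states: int, heads: int, v_dim: int = 128,
                k_dim: int = 128, dtype: str = "fp32") -> int:
    """Size in bytes of a recurrent state of shape [n_states, heads, V, K]."""
    return n_states * heads * v_dim * k_dim * DTYPE_BYTES[dtype]


# Same configuration as the example above: B=2, H=16, V=K=128, fp32 state.
size = state_bytes(n_states=2, heads=16)
print(f"{size} bytes = {size / 2**20:.1f} MiB")  # 2097152 bytes = 2.0 MiB
```

The same helper applies in varlen mode by passing N (the number of packed sequences) as `n_states`.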
Direct Low-Level Kernel Usage
import torch
import flash_kda
def run_flash_kda(
    q, k, v, g, beta,
    A_log, dt_bias,
    lower_bound=-5.0,
    initial_state=None,
):
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5
    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device=q.device)
    final_state = torch.zeros(B, H, V, K, dtype=torch.float32, device=q.device)
    flash_kda.fwd(
        q, k, v, g, beta,
        scale, out,
        A_log, dt_bias, lower_bound,
        initial_state=initial_state,
        final_state=final_state,
        cu_seqlens=None,
    )
    return out, final_state
B, T, H, K = 1, 4096, 8, 128
device = 'cuda'
dtype = torch.bfloat16
q = torch.randn(B, T, H, K, device=device, dtype=dtype)
k = torch.randn(B, T, H, K, device=device, dtype=dtype)
v = torch.randn(B, T, H, K, device=device, dtype=dtype) # V==K==128
g = torch.randn(B, T, H, K, device=device, dtype=dtype)
beta = torch.randn(B, T, H, device=device, dtype=dtype)
A_log = torch.full((H,), -0.1, device=device, dtype=torch.float32)
dt_bias = torch.zeros(H, K, device=device, dtype=torch.float32)
with torch.inference_mode():
    out, state = run_flash_kda(q, k, v, g, beta, A_log, dt_bias)
print(out.shape)    # [1, 4096, 8, 128]
print(state.shape)  # [1, 8, 128, 128]
Variable-Length (Packed) Batching
Use cu_seqlens for variable-length sequences packed into a single batch dimension:
import torch
import flash_kda
# Two sequences of lengths 512 and 768, packed together
seq_lens = [512, 768]
T_total = sum(seq_lens)
N = len(seq_lens)
H, K, V = 16, 128, 128
cu_seqlens = torch.tensor([0, 512, 1280], dtype=torch.int64, device='cuda')
# B must be 1 for varlen mode
q = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')
g = torch.randn(1, T_total, H, K, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(1, T_total, H, dtype=torch.bfloat16, device='cuda')
A_log = torch.zeros(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
out = torch.empty(1, T_total, H, V, dtype=torch.bfloat16, device='cuda')
# State shape is [N, H, V, K] in varlen mode
final_state = torch.zeros(N, H, V, K, dtype=torch.float32, device='cuda')
scale = K ** -0.5
with torch.inference_mode():
    flash_kda.fwd(
        q, k, v, g, beta,
        scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=None,
        final_state=final_state,
        cu_seqlens=cu_seqlens,
    )
print(out.shape)          # [1, 1280, 16, 128]
print(final_state.shape)  # [2, 16, 128, 128]
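The `cu_seqlens` values in the example above are a running sum of the sequence lengths with a leading zero. A small sketch of deriving them instead of hard-coding them follows; `build_cu_seqlens` and `sequence_bounds` are hypothetical helper names, and rejecting non-positive lengths is an assumption, since the source does not say whether empty sequences are accepted.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def build_cu_seqlens(seq_lens: Sequence[int]) -> List[int]:
    """Cumulative offsets [0, l0, l0+l1, ...] with N+1 entries."""
    if any(n <= 0 for n in seq_lens):
        raise ValueError("sequence lengths must be positive")
    return [0, *accumulate(seq_lens)]


def sequence_bounds(cu_seqlens: Sequence[int]) -> List[Tuple[int, int]]:
    """(start, end) token offsets of each packed sequence."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))


offsets = build_cu_seqlens([512, 768])
print(offsets)                   # [0, 512, 1280]
print(sequence_bounds(offsets))  # [(0, 512), (512, 1280)]
```

The list can then be wrapped as `torch.tensor(offsets, dtype=torch.int64, device='cuda')`, and each pair from `sequence_bounds` slices one sequence back out of the packed output, for example `out[0, start:end]`.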
Stateful Inference (Multi-turn / Streaming)
Pass initial_state from a previous call to maintain recurrent state across chunks:
import torch
import flash_kda
H, K, V = 16, 128, 128
B = 2
scale = K ** -0.5
def inference_step(q, k, v, g, beta, A_log, dt_bias, state=None):
    T = q.shape[1]
    out = torch.empty(B, T, H, V, dtype=torch.bfloat16, device='cuda')
    new_state = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
    flash_kda.fwd(
        q, k, v, g, beta, scale, out,
        A_log, dt_bias, lower_bound=-5.0,
        initial_state=state,
        final_state=new_state,
        cu_seqlens=None,
    )
    return out, new_state
A_log = torch.zeros(H, dtype=torch.float32, device='cuda')
dt_bias = torch.zeros(H, K, dtype=torch.float32, device='cuda')
state = None
for chunk_idx in range(4):
    q = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(B, 256, H, V, dtype=torch.bfloat16, device='cuda')
    g = torch.randn(B, 256, H, K, dtype=torch.bfloat16, device='cuda')
    beta = torch.randn(B, 256, H, dtype=torch.bfloat16, device='cuda')
    with torch.inference_mode():
        out, state = inference_step(q, k, v, g, beta, A_log, dt_bias, state)
    print(f"Chunk {chunk_idx}: out={out.shape}, state={state.shape}")
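The loop above generates a fresh random chunk each step. When streaming over one long pre-existing sequence, the chunk boundaries can be computed first and used to slice the inputs. This is a minimal sketch: `chunk_spans` is a hypothetical helper, and it assumes the kernel accepts a shorter final chunk, which the source does not state.

```python
from typing import Iterator, Tuple


def chunk_spans(total_tokens: int, chunk_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) offsets covering [0, total_tokens) in order."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, total_tokens, chunk_size):
        # The last span is truncated when total_tokens is not a multiple.
        yield start, min(start + chunk_size, total_tokens)


print(list(chunk_spans(1000, 256)))
# [(0, 256), (256, 512), (512, 768), (768, 1000)]
```

Each span would then drive one call such as `inference_step(q[:, s:e], k[:, s:e], ...)` with the state carried forward. Time slices of a batched tensor are not contiguous in memory, so a `.contiguous()` call may be needed if the kernel requires contiguous inputs; that requirement is an assumption to verify.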
Configuration & Environment Variables
| Variable | Values | Effect |
|---|---|---|
| FLA_FLASH_KDA | 0 / 1 | Set to 0 to force Triton fallback in FLA |
# Disable FlashKDA, use Triton path
FLA_FLASH_KDA=0 python your_script.py
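The same switch can be inspected or set from Python. The sketch below only mirrors the documented behaviour ("set to 0 to force the Triton fallback"); `flashkda_allowed` is a hypothetical helper, and treating an unset variable as enabled is an assumption about FLA's default. When FLA actually reads the variable is also not stated in the source, so setting it before importing `fla` is the cautious choice.

```python
import os
from typing import Mapping


def flashkda_allowed(environ: Mapping[str, str] = os.environ) -> bool:
    """True unless FLA_FLASH_KDA is explicitly set to "0".

    Assumption: an unset variable means FlashKDA stays enabled.
    """
    return environ.get("FLA_FLASH_KDA", "1").strip() != "0"


# To force the Triton path from inside a script, set the variable early,
# before `import fla` (import-time vs call-time lookup is not documented):
#   os.environ["FLA_FLASH_KDA"] = "0"

print(flashkda_allowed({"FLA_FLASH_KDA": "0"}))  # False
print(flashkda_allowed({}))                      # True
```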
Running Tests
bash tests/test.sh
tests/test_fwd.py: correctness tests against a PyTorch reference and flash-linear-attention
Common Patterns & Troubleshooting
Check dispatch logging
import logging
logging.basicConfig(level=logging.INFO)
# Successful: [FLA Backend] kda.chunk_kda -> flashkda
# Rejected: [FLA Backend] kda.chunk_kda rejected: <reason>
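To act on the dispatch decision programmatically (for example, to fail a benchmark run that silently fell back), the two message shapes shown above can be matched with regular expressions. This is a sketch that assumes the log text matches those two comment lines exactly; the real format may differ between FLA versions, and `classify_dispatch` is a hypothetical name.

```python
import re
from typing import Optional, Tuple

# Patterns derived from the two example messages above (assumed format).
_DISPATCHED = re.compile(r"\[FLA Backend\]\s+(?P<op>\S+)\s+->\s+(?P<backend>\S+)")
_REJECTED = re.compile(r"\[FLA Backend\]\s+(?P<op>\S+)\s+rejected:\s*(?P<reason>.*)")


def classify_dispatch(line: str) -> Optional[Tuple[str, str, str]]:
    """Return (status, op, detail) for a dispatch log line, else None."""
    match = _DISPATCHED.search(line)
    if match:
        return ("dispatched", match["op"], match["backend"])
    match = _REJECTED.search(line)
    if match:
        return ("rejected", match["op"], match["reason"].strip())
    return None


print(classify_dispatch("[FLA Backend] kda.chunk_kda -> flashkda"))
# ('dispatched', 'kda.chunk_kda', 'flashkda')
```

A custom `logging.Handler` could feed each record's message through this function; lines that match neither pattern return `None` and can be ignored.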
Verify GPU compatibility
import torch
cap = torch.cuda.get_device_capability()
assert cap >= (9, 0), f"FlashKDA requires SM90+, got SM{cap[0]}{cap[1]}"
K and V must be 128
# WRONG: will error
q = torch.randn(1, 512, 8, 64, dtype=torch.bfloat16, device='cuda')   # K=64 not supported
# CORRECT
q = torch.randn(1, 512, 8, 128, dtype=torch.bfloat16, device='cuda')  # K=128 required
Use torch.inference_mode() not torch.no_grad()
# FlashKDA requires inference_mode for FLA dispatch
with torch.inference_mode():
    out, state = chunk_kda(...)
State dtype consistency
# initial_state and final_state must have matching dtypes
initial = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda')
final = torch.zeros(B, H, V, K, dtype=torch.float32, device='cuda') # must match
# bf16 initial + fp32 final → error
lower_bound valid range
lower_bound = -5.0 # valid: range is [-5.0, 0]
lower_bound = -2.5 # valid
lower_bound = 0.0 # valid boundary
lower_bound = -10.0 # out of spec — use -5.0 as safe minimum
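When `lower_bound` comes from a config file or user input, it can be normalised into the documented range before the call. A minimal sketch, assuming that clamping is acceptable for your use case; raising an error on out-of-range values is the stricter alternative. `clamp_lower_bound` is a hypothetical helper, not a FlashKDA function.

```python
LOWER_BOUND_MIN = -5.0  # documented minimum
LOWER_BOUND_MAX = 0.0   # documented maximum


def clamp_lower_bound(value: float) -> float:
    """Clamp a requested gate lower bound into [-5.0, 0.0]."""
    return max(LOWER_BOUND_MIN, min(LOWER_BOUND_MAX, float(value)))


print(clamp_lower_bound(-10.0))  # -5.0
print(clamp_lower_bound(-2.5))   # -2.5
```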
IntelliSense / clangd setup for development
bash setup_clangd.sh
# Generates .clangd with correct include paths for CUDA/CUTLASS sources