官方資訊
Github: https://github.com/Dao-AILab/flash-attention
論文文檔: https://arxiv.org/abs/2205.14135
此存儲庫提供了以下論文中 FlashAttention 和 FlashAttention-2 的官方實現。可讓我們在建模時有更快的注意力,更好的並行度和工作分區。下面為一張概念示意圖:
官方所做的效能提升試驗結果如下:
甚麼是Flash Attention
Flash Attention 是一種注意力算法,旨在提高基於 Transformer 的模型的效率,使其能夠處理更長的序列長度並更快地進行訓練和推理。它通過減少計算量和內存使用來實現這一點。Flash Attention 是一種快速且內存高效的精確注意力機制,其設計考慮了IO(輸入輸出)的特性。
這項技術的關鍵點
快速 (Fast)
- 訓練BERT-large(序列長度512)比MLPerf 1.1中的訓練速度記錄快15%。
- 訓練GPT-2(序列長度1K)比HuggingFace和Megatron-LM的基準實現快3倍。
- 在long-range arena(序列長度1K-4K)中,比基準速度快2.4倍。
高效內存使用 (Memory-efficient)
- 傳統的注意力機制內存訪問量是O(N²),而Flash Attention的內存訪問量是亞二次方/線性的。
精確 (Exact)
- 這不是近似算法(例如稀疏或低秩矩陣方法),其結果與原始方法完全相同。
IO感知 (IO-aware)
- 與原始的注意力計算方法相比,Flash Attention考慮了硬件(特別是GPU)的特性,而不是將其當作黑盒來處理。
使用Flash Attention
可以通過以下兩種方式來實現:
- 切片和重新計算:Flash Attention 將序列分成較小的塊,並在每個塊上計算注意力。這可以減少計算量,因為每個塊的注意力矩陣都小得多。此外,Flash Attention 還會重新利用中間計算結果,以進一步減少計算量。
- 稀疏表示:Flash Attention 使用稀疏表示來表示注意力矩陣。這意味著只存儲非零元素,從而減少內存使用量。
安裝/使用Flash Attention
系統要求:
- CUDA 11.6 及更高版本。
- PyTorch 1.12 及更高版本。
- Linux系統。此功能有可能於 v2.3.2版本之後開始支持 Windows,但 Windows 編譯仍然需要更多的測試。
我們推薦 Nvidia 的 Pytorch 容器,它具有安裝 FlashAttention 所需的所有工具。
在使用Flash Attention要先安裝:
- PyTorch。
- 安裝
pip install packaging
- 確保已安裝並且
ninja
工作正常(例如,ninja --version
然後echo $?
應返回退出代碼 0)。如果不是(有時ninja --version
然後echo $?
返回非零退出代碼),請卸載然後重新安裝ninja
(pip uninstall -y ninja && pip install ninja
)。如果沒有ninja
,編譯可能需要很長時間(2 小時),因為它不使用多個 CPU 內核。ninja
在 3 核機器上編譯需要 5-64 分鐘。 - 然後:
pip install flash-attn --no-build-isolation
如果您的電腦的 RAM 小於 96GB 且 CPU 內核眾多, ninja
則可能會運行過多的並行編譯作業,從而耗盡 RAM 量。要限制並行編譯作業的數量,可以設置環境變數 MAX_JOBS
:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
使用範例
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""