Posted on

FlashAttention介紹

官方資訊

Github: https://github.com/Dao-AILab/flash-attention

論文文檔: https://arxiv.org/abs/2205.14135

此存儲庫提供了以下論文中 FlashAttention 和 FlashAttention-2 的官方實現。可讓我們在建模時有更快的注意力,更好的並行度和工作分區。下面為一張概念示意圖:

官方所做的效能提升試驗結果如下:

甚麼是Flash Attention

Flash Attention 是一種注意力算法,旨在提高基於 Transformer 的模型的效率,使其能夠處理更長的序列長度並更快地進行訓練和推理。它通過減少計算量和內存使用來實現這一點。Flash Attention 是一種快速且內存高效的精確注意力機制,其設計考慮了IO(輸入輸出)的特性。

這項技術的關鍵點

快速 (Fast)

  • 訓練BERT-large(序列長度512)比MLPerf 1.1中的訓練速度記錄快15%。
  • 訓練GPT-2(序列長度1K)比HuggingFace和Megatron-LM的基準實現快3倍。
  • 在long-range arena(序列長度1K-4K)中,比基準速度快2.4倍。

高效內存使用 (Memory-efficient)

  • 傳統的注意力機制內存訪問量是O(N²),而Flash Attention的內存訪問量是亞二次方/線性的。

精確 (Exact)

  • 這不是近似算法(例如稀疏或低秩矩陣方法),其結果與原始方法完全相同。

IO感知 (IO-aware)

  • 與原始的注意力計算方法相比,Flash Attention考慮了硬件(特別是GPU)的特性,而不是將其當作黑盒來處理。

使用Flash Attention

可以通過以下兩種方式來實現:

  • 切片和重新計算:Flash Attention 將序列分成較小的塊,並在每個塊上計算注意力。這可以減少計算量,因為每個塊的注意力矩陣都小得多。此外,Flash Attention 還會重新利用中間計算結果,以進一步減少計算量。
  • 稀疏表示:Flash Attention 使用稀疏表示來表示注意力矩陣。這意味著只存儲非零元素,從而減少內存使用量。

安裝/使用Flash Attention

系統要求:

  • CUDA 11.6 及更高版本。
  • PyTorch 1.12 及更高版本。
  • Linux系統。此功能有可能於 v2.3.2版本之後開始支持 Windows,但 Windows 編譯仍然需要更多的測試。

我們推薦 Nvidia 的 Pytorch 容器,它具有安裝 FlashAttention 所需的所有工具。

在使用Flash Attention要先安裝:

  1. PyTorch。
  2. 安裝  pip install packaging 
  3. 確保已安裝並且 ninja 工作正常(例如, ninja --version 然後 echo $? 應返回退出代碼 0)。如果不是(有時 ninja --version 然後 echo $? 返回非零退出代碼),請卸載然後重新安裝 ninja ( pip uninstall -y ninja && pip install ninja )。如果沒有 ninja ,編譯可能需要很長時間(2 小時),因為它不使用多個 CPU 內核。 ninja 在 3 核機器上編譯需要 5-64 分鐘。
  4. 然後:
pip install flash-attn --no-build-isolation

如果您的電腦的 RAM 小於 96GB 且 CPU 內核眾多, ninja 則可能會運行過多的並行編譯作業,從而耗盡 RAM 量。要限制並行編譯作業的數量,可以設置環境變數 MAX_JOBS :

MAX_JOBS=4 pip install flash-attn --no-build-isolation

使用範例

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
        the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""