4比特量化三倍加速不掉點!清華即插即用的SageAttention迎來升級

論文共同第一作者張金濤、黃浩峰分別來自清華大學計算機系和交叉信息研究院,論文通訊作者陳鍵飛副教授及其他合作作者均來自清華大學計算機系。

大模型中,線性層的低比特量化已經逐步落地。然而,對於注意力模塊,目前幾乎各個模型都還在用高精度(例如 FP16 或 FP32)的注意力運算進行訓練和推理。並且,隨着大型模型需要處理的序列長度不斷增加,Attention(注意力運算)的時間開銷逐漸成爲主要開銷。

此前,清華大學陳鍵飛團隊提出的 8-Bit 的即插即用 Attention(SageAttention),將 Attention 中的 QK^T 量化至 INT8,將 PV 保持爲 FP16 精度並使用 FP16 精度的矩陣乘法累加器,同時提出 Smooth K 技術保持了量化 Attention 的精度,實現了 2 倍加速於 FlashAttention2,且在各類大模型上均保持了端到端的精度表現。

目前,SageAttention 已經被業界及社區廣泛地使用於各種開源及商業大模型中,比如 CogvideoX、Mochi、Flux、Llama3、Qwen 等。

近日,陳鍵飛團隊進一步提出了 4-Bit 的即插即用 Attention(SageAttention2),相較於 FlashAttention2 和 xformers 分別實現了3倍以及4.5倍的即插即用的推理加速,且在視頻、圖像、文本生成等大模型上均保持了端到端的精度表現。

即插即用舉例

SageAttention2 實現了高效的 Attention 算子,可以實現即插即用的推理加速。輸入任意 Q, K, V 矩陣,SageAttention2 可以快速返回 Attention Output (O)。

具體來說,SageAttention2 使用起來很方便,克隆倉庫(git clone https://github.com/thu-ml/SageAttention)並執行 python setup.py install 後,只需一行代碼便可以得到 Attention 的輸出,可以使用該接口方便地替換任意模型中的 Attention 函數:

效果上,以開源視頻生成模型 CogvideoX-1.5-5B 爲例,使用 SageAttention2 可以端到端加速1.8倍,且生成的視頻無損。

更重要的是,SageAttention2 提供了比 SageAttention 更廣泛的硬件支持。除了在 RTX 4090 上可以 3 倍加速於 FlashAttention 外,在 L20、L40、L40S 可以實現 2 倍的加速,在 A100、A800、A6000 上可以實現 1.45-1.6 倍的加速(基於 SageAttention)。

接下來,研究團隊將從前言、挑戰、方法以及實驗效果四個方面介紹 SageAttention2(總體流程圖如下圖)。

前言

隨着大模型需要處理的序列長度越來越長,Attention 的速度優化變得越來越重要。下圖展示了一個標準的 Transformer 模型中各運算的時間佔比隨序列長度的變化:

爲了方便指代注意力運算中的矩陣,我們先回顧一下注意力的計算公式:

儘管 SageAttention 提出將 Q,K 量化至 INT8,將 P,V 保持 FP16 精度且採用 FP16 的矩陣乘法累加器來加快 Attention 的速度。然而,這樣做的缺點是:1)INT8 的矩陣乘法只達到了一半的 INT4 矩陣乘法的速度,2)使用 FP16 的乘法累加器的 FP16 的矩陣乘法的加速只在 RTX4090 和 RTX3090 顯卡上有效。

爲了克服上述缺點,SageAttention2 提出將 Q, K 量化至 INT4,並將 P, V 量化至 FP8 來加速 Attention。然而,這樣做的挑戰是很大的。

4-Bit 注意力量化有什麼問題?

研究團隊發現直接將注意力運算中的 Q, K 量化爲 INT4 後將會導致在幾乎所有模型和任務上都會得到極差的結果,例如,在 CogVideoX 文生視頻模型中,會得到完全模糊的視頻;Llama2-7B 進行四選一選擇題任務上得到 25% 的準確率。

經過仔細分析後,研究團隊發現主要是兩個原因導致了量化注意力的不準確:

(1)INT4 的數值範圍相比 INT8 非常小,導致其量化誤差在 Q,K 矩陣中出現一些異常值時會變得十分明顯,恰好大多模型都在 Q, K 中表現出來了較大的通道維度的異常值。這極大削減了 QK^⊤矩陣乘法的精度。

(2)研究團隊發現 Nvidia 的顯卡上,FP8 的矩陣乘法指令 (mma.f32.f8.f8.f32) 的乘法累加器並不是官方宣稱的 FP32 精度,而是隻有 FP22 精度,這導致了 PV 矩陣乘法出現較大的累加誤差。

技術方案

爲了解決上述的兩個挑戰,研究團隊提出了對應的解決辦法。

(1)保留 SageAttention 中對 K 進行平滑處理的同時,提出對 Q 進行平滑處理:Q – mean (Q)。其中 mean (Q) 是沿着通道維度的平均值向量。完成該平滑操作後需要在 Attention 計算過程中將 mean (Q) 和 K^T 的向量與矩陣乘法的結果補償到 S 中。

這使得相比直接量化 Q, K 至 INT4 的準確度有質的改變,如下表展示了對比了該方法和直接量化 Q, K 至 INT4 在 Cogvideo 和 Llama3.1 上的端到端表現。

矩陣 Q 平滑前後的數據分佈可視化的結果如下,可以發現平滑後的 Q 對 INT4 數據範圍的利用度更高:

(2)對 Q, K 進行 Per-thread 量化。對於矩陣 Q, K,SageAttention2 採用了根據 mma 指令對矩陣內存排布的要求,對 Q,K 中的 Token 按照 GPU 線程進行分組,使量化粒度比 SageAttention 中的 per-block 細化 16 倍,極大提高了 4Bit 的 QK^⊤乘法準確度的同時不引入任何額外開銷。

具體來說,在 SageAttention 中,每個 Q 的塊將被劃分爲 c_w 個段,由 GPU 流處理器(SM)中的 c_w 個 GPU warp 處理。然後,每個包含 32 個線程的 warp 會使用 NVIDIA 的 mma.m16n8k64 PTX 指令來執行 QK^⊤運算。根據這一指令的佈局要求,研究團隊發現一個 warp 內的 Q [8×(n%8)] 可以共用一個量化縮放參數,而一個 warp 內的 K [8×(n%8)] 和 K [8×(n%8+1)] 也可以共用一個量化縮放參數,其中 n 是 token 索引。

這種量化方法更爲細緻且不增加額外開銷。這是因爲它根據 MMA 指令的佈局將不同的 GPU 線程分配到不同的量化 Token 組,每個線程只對應一個量化縮放參數進行反量化。而非 Per-token 量化那樣,每個線程對應多個量化縮放參數。

如下表所示,可以發現 per-thread 量化的準確度比 SageAttention 中採用的 per-block 量化高得多,準確度和 per-token 量化幾乎沒有差別。

(3)對 FP8 的 PV 矩陣乘法採用 FP32 的寄存器將每次 FlashAttention 分塊粒度的 PV 的 FP22 的乘法結果累加起來。這種做法可以有效地避免 FP22 的乘法累加器沿着序列長度累積過多的誤差,將 FP22 累加器帶來的誤差控制在 FlashAttention 分塊的粒度中,提高了 FP8 的 PV 乘法的準確度。

(4)針對 P 和 V,研究團隊對比了多種量化的數據類型,對比發現使用 E4M3 數據格式的 FP8 精度最準確,基本接近了 FP16 的準確度。因此採用將 P 和 V 量化至 E4M3。

下圖展示了 SageAttention2 的算法流程:

SageAttention2 共實現了兩種 Kernel,區別在於對 Q, K 進行 INT4 量化還是 INT8 量化:

此外,SageAttention2 還提出一種可選的對矩陣 V 進行平滑處理的技術,可以進一步提高 PV 矩陣乘法的準確度。具體來說,當某些模型中 V 矩陣具有通道維度的偏移時,可以將 V 減去其通道維度的平均值 mean (V) 來去除偏移,之後進行正常的量化 Attention 運算。只需要對最終 Attention 的 Output 加上 mean (V) 即可保持計算的正確性。

這種做法可以提升準確度的原因如下圖所示。在 FP22 的表示範圍內,數值越大,相比 FP32 的誤差越大。而 P 的範圍是 0~1 之間,那麼當 V 矩陣的列有較大的數值偏移時,PV 的 FP22 累加器的精度就越差,通過平滑 V 去除偏移後,就可以加強 PV 矩陣乘法的準確度。

實驗效果

SageAttention 實現了底層的 GPU CUDA Kernel,在算子速度以及各個模型端到端準確度上都有十分不錯的表現。

具體來說,算子速度相比於 FlashAttention2 和 xformers 有大約 3 倍以及 4.5 倍的加速:

算子的準確度方面也是比對 Q, K 進行 SmoothQuant 和 Hadamard 變換要更加準確:

各模型在真實場景的端到端精度表現中,在視頻、圖像、文本生成等大模型上均保持了端到端的精度表現:

下圖是在 HunyuanVideo 中的可視化實例:

下圖是在 Cogvideo 中的可視化實例:

下表展示了各個語言、視頻、圖像生成模型中 SageAttention2 的端到端精度表現:

端到端的速度表現上,SageAttention2 兩個 Kernel 的實現均可以有效地對長序列模型進行加速,比如可以端到端1.8倍加速 CogVideoX1.5-5B,其他模型上也均有1.6到1.8倍的提速。