僅激活3.8B參數,表現媲美同類7B模型,適用於訓練和微調,微軟最新研究成果

研究人員揭示了大型語言模型推理過程中的優化規律。通過對不同規模模型的分析,他們發現了一種可預測的縮放法則,能夠指導如何更高效地進行模型推理。這一發現為優化大規模語言模型的推理性能提供了新的思路和方法。

使用Top-K函數實現稀疏化

Q-Sparse所做的最核心的操作,是對輸入的張量應用Top-K稀疏化函數。

具體來說,Transformer架構在注意力層和前饋層中都使用nn.Linear線性層(矩陣乘法)進行投影,可以表示為Y=X·W^T。(其中X就是輸入張量,W代表其權重,Y為輸出張量)

Q-Sparse中,對於一個輸入激活張量X,首先會計算其絕對值|X|並進行排序,找出其中絕對值最大的K個元素。

這裡的K是預先設定的超參數,決定了稀疏化的程度。

之後Q-Sparse會創建一個與X形狀相同的二進制掩碼張量M,對於一系列|X|中絕對值最大的K個元素對應的位置,將M中的相應位置設置為1,其餘位置設置為0。

接著,將輸入張量X與掩碼張量M進行Hadamard積(逐元素相乘)運算,就得到了稀疏化的張量X_sparse。

在前向傳播過程中,稀疏化後的張量X_sparse將代替原始的輸入張量X參與後續的計算(如矩陣乘法)。

由於X_sparse中大部分元素已經被設置為零,因此可以顯著減少計算量和內存帶寬需求。

在反向傳播過程中,Q-Sparse使用了直通估計器(Straight-Through Estimator,STE)來計算Top-K函數的梯度。

傳統的訓練方式中,通常需要計算損失函數對網絡參數的梯度,並使用梯度下降法更新參數以最小化損失。

但當網絡中存在量化、Top-K等一些不可微的操作時,梯度的計算就會遇到問題,因為這些操作的輸出對輸入的梯度在大多數點上都是0,導致梯度無法有效傳播。

STE通過直接將梯度傳遞給稀疏化之前的張量,避免了梯度消失的問題。

一般的反向傳播中,損失函數L對x的梯度∂L/∂x=∂L/∂y⋅∂y/∂x,但由於不可微分無法直接計算。

STE的解決方案是只計算損失函數對稀疏化張量y的梯度,然後將其直接複製給原始張量x,也就是直接將∂L/∂y作為∂L/∂x的估計。

對於前饋層,Q-Sparse使用平方ReLU函數代替常規的ReLU激活函數,平方運算可以進一步提高激活的稀疏性(⊙表示Hadamard積)。

另外,為了適配量化模型,Q-Sparse在應用Top-K稀疏化之前,會先對輸入張量進行量化,以確保稀疏化操作與量化表示兼容,其函數表示如下:

其中,ε是一個小常數,用於避免出現分母為零的情況。

特別的,對於1-bit量化的權重,Q-Sparse使用以下量化函數,其中α是權重張量W的平均絕對值。

60%激活參數達到相同效果

對比實驗表明,無論是稀疏率還是模型表現,Q-Sparse都顯著優於此前的ReLU方法。

針對Q-Sparse的具體效果,作者對其在從頭訓練、繼續訓練和微調三項任務上的性能進行了評估。

從頭訓練實驗使用的模型為Llama,結果在700M和7B模型上,使用70% top-K(即40%的整體稀疏率)的Q-Sparse可以達到與密集baseline相當的訓練損失。

繼續訓練的目的是將稠密模型稀疏化,這裡的實驗對象是Mistral-7B。

結果,在激活參數為2.9B和3.8B的情況下,模型在ARC、MMLU等數據集中的得分均未發生明顯下降。

在微調實驗中,對於Qwen-7B和Mistral-7B兩種模型,Q-Sparse顯示出了與繼續訓練相似的結果,用60%左右的激活參數實現了與密集模型十分接近的表現。

這些結果意味著,在相同的性能下,與密集模型相比,稀疏激活模型在推理過程中可以顯著減少激活參數,進而降低消耗FLOPS的數量。

對於量化模型,團隊在自研的BitNet b1.58模型上應用了Q-Sparse,並在多個數據集上進行了訓練和評估。

可以看到,在700M和7B兩種規模下,使用Q-Sparse的量化模型的收斂速度和最終損失函數值與未使用Q-Sparse的量化模型(BitNet b1.58)相當。

這說明Q-Sparse可以無縫集成到量化模型中,而不會顯著影響模型的訓練和收斂。

據此作者認為,將Q-Sparse與量化技術相結合,可以進一步提高大語言模型在推理階段的效率。

發現推理優化新"Scaling Law"

除了測評這些模型採取稀疏激活時的表現,作者也對模型性能、規模和稀疏率三者之間的關係進行了探究,並有了一些新的發現。

稀疏激活模型的性能縮放定律: 作者發現,與密集模型類似,稀疏激活模型的性能也遵循一個冪律縮放關係。

具體來說,給定稀疏率S,模型在收斂時的損失函數值L(N,S)可以用以下公式近似: