Activating Only 3.8B Parameters, It Performs on Par with Comparable 7B Models and Works for Both Training and Fine-Tuning: Microsoft's Latest Research

Researchers have uncovered a pattern governing inference efficiency in large language models. By analyzing models of different scales, they found a predictable scaling law that can guide more efficient inference. This finding offers new ideas and methods for optimizing the inference performance of large language models.

Implementing Sparsification with the Top-K Function

The core operation of Q-Sparse is applying the Top-K sparsification function to the input tensor.

Specifically, the Transformer architecture uses nn.Linear layers (matrix multiplications) for the projections in both the attention and feed-forward layers. Each projection can be written as Y = X·Wᵀ, where X is the input tensor, W is the weight matrix, and Y is the output tensor.

In Q-Sparse, given an input activation tensor X, the method first computes the element-wise absolute value |X| and ranks the entries to find the K elements with the largest absolute values.

K is a predefined hyperparameter that determines the degree of sparsification.

Q-Sparse then builds a binary mask tensor M with the same shape as X: positions holding the K largest values of |X| are set to 1, and all other positions are set to 0.

Next, the Hadamard product (element-wise multiplication) of the input tensor X and the mask tensor M is performed to obtain the sparsified tensor X_sparse.
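The mask-and-multiply step can be sketched on a single activation vector (one token) using plain Python lists. This is a minimal illustration, not the authors' implementation: a real kernel works on batched tensors and would use a built-in top-k routine such as PyTorch's torch.topk. The helper names topk_mask and topk_sparsify are hypothetical.

```python
def topk_mask(x, k):
    """Return a 0/1 mask M marking the k entries of x with the largest |value|."""
    if not 0 <= k <= len(x):
        raise ValueError("k must be between 0 and len(x)")
    # Rank indices by absolute value, largest first; the sort is stable,
    # so ties keep their original order (earlier index wins).
    order = sorted(range(len(x)), key=lambda i: abs(x[i]), reverse=True)
    keep = set(order[:k])
    return [1 if i in keep else 0 for i in range(len(x))]


def topk_sparsify(x, k):
    """X_sparse = X ⊙ M: element-wise (Hadamard) product of X and its Top-K mask."""
    mask = topk_mask(x, k)
    return [xi * mi for xi, mi in zip(x, mask)]


x = [0.1, -2.0, 0.5, 3.0, -0.2]
print(topk_mask(x, 2))      # [0, 1, 0, 1, 0]
print(topk_sparsify(x, 2))  # only -2.0 and 3.0 survive; the rest become zero
```

With d entries per vector, keeping K of them gives a sparsity rate of 1 − K/d for that tensor (here 1 − 2/5 = 60%).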

During forward propagation, the sparsified tensor X_sparse replaces the original input tensor X in subsequent calculations (such as matrix multiplication).

Since most elements in X_sparse have been set to zero, this significantly reduces computational and memory bandwidth requirements.
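To see where the savings come from, the sketch below computes Y = X_sparse·Wᵀ for one input row while skipping zeroed activations, and counts the multiplications it actually performs. It is a hypothetical illustration in plain Python; real speedups depend on sparse kernels and hardware that can exploit the zeros.

```python
def sparse_linear(x_sparse, w):
    """Compute y = x_sparse · W^T for one row, skipping zero activations.

    w holds d_out rows, each of length d_in (the nn.Linear weight layout).
    Returns (y, number_of_multiplications_performed).
    """
    # Only the nonzero activations contribute to any output.
    nonzero = [(j, xj) for j, xj in enumerate(x_sparse) if xj != 0]
    y, mults = [], 0
    for row in w:
        acc = 0.0
        for j, xj in nonzero:
            acc += xj * row[j]
            mults += 1
        y.append(acc)
    return y, mults


x_sparse = [0.0, -2.0, 0.0, 3.0]        # 2 of 4 activations kept by Top-K
w = [[1.0, 0.0, 2.0, 1.0],
     [0.5, 1.0, -1.0, 0.0]]
y, mults = sparse_linear(x_sparse, w)
print(y, mults)  # 4 multiplications instead of the dense 2 * 4 = 8
```

Because each skipped activation also means one column of W never has to be read, the reduction applies to memory traffic as well as arithmetic.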

During backpropagation, Q-Sparse uses the Straight-Through Estimator (STE) to calculate the gradient of the Top-K function.

In conventional training, one computes the gradient of the loss function with respect to the network parameters and updates the parameters by gradient descent to minimize the loss.

However, when the network contains operations such as quantization or Top-K, gradient computation becomes problematic. Quantization is piecewise constant, so its derivative is zero almost everywhere; Top-K passes zero gradient to every element it masks out. In both cases, gradients cannot propagate effectively to the earlier layers.

STE sidesteps this vanishing-gradient problem by passing the gradient straight through to the tensor before sparsification.

With y denoting the sparsified version of x, ordinary backpropagation would compute the gradient of the loss L with respect to x by the chain rule, ∂L/∂x = ∂L/∂y · ∂y/∂x, but ∂y/∂x is not usable here because the sparsification step is non-differentiable.

STE instead computes only the gradient of the loss with respect to the sparsified tensor y and copies it directly to the original tensor x, effectively using ∂L/∂y as the estimate of ∂L/∂x.
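The difference between the exact gradient and the STE estimate can be shown with a hand-written forward and backward pass. This is a minimal sketch assuming an upstream gradient of all ones; the function names are hypothetical.

```python
def topk_forward(x, k):
    """Forward pass: y = x ⊙ M with M the Top-K mask of |x|. Returns (y, mask)."""
    order = sorted(range(len(x)), key=lambda i: abs(x[i]), reverse=True)
    keep = set(order[:k])
    mask = [1.0 if i in keep else 0.0 for i in range(len(x))]
    return [xi * mi for xi, mi in zip(x, mask)], mask


def backward_exact(grad_y, mask):
    # Derivative of y = x ⊙ M with M held fixed: masked entries get zero gradient.
    return [g * m for g, m in zip(grad_y, mask)]


def backward_ste(grad_y, mask):
    # Straight-through estimator: reuse dL/dy unchanged as the estimate of dL/dx.
    return list(grad_y)


x = [0.1, -2.0, 0.5, 3.0]
y, mask = topk_forward(x, 2)
grad_y = [1.0, 1.0, 1.0, 1.0]          # hypothetical upstream gradient dL/dy
print(backward_exact(grad_y, mask))    # pruned positions receive no signal
print(backward_ste(grad_y, mask))      # every position receives the gradient
```

In an autograd framework the same behavior is often written as x + (x * mask - x).detach() in PyTorch: the forward value equals x * mask, while the gradient flows to x as if the operation were the identity.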

For the feed-forward layers, Q-Sparse uses the squared ReLU function instead of the regular ReLU activation, since squaring further increases the sparsity of the activations. The resulting gated feed-forward function is ReLU²GLU(X) = X·W_upᵀ ⊙ ReLU²(X·W_gateᵀ), where ⊙ denotes the Hadamard product.
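A gated feed-forward block with a squared-ReLU gate can be sketched for a single input row as follows. The sketch assumes the GLU-style form X·W_upᵀ ⊙ ReLU²(X·W_gateᵀ); the names w_up, w_gate, and the helper functions are hypothetical, and the weights are toy values.

```python
def relu2(v):
    """Squared ReLU: max(0, v) ** 2. Negative inputs map to exactly zero."""
    return max(0.0, v) ** 2


def matvec_t(x, w):
    """x · W^T for a single row x, with W stored as a list of output rows."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]


def relu2_glu(x, w_up, w_gate):
    """Gated FFN sketch: (x · W_up^T) ⊙ ReLU²(x · W_gate^T)."""
    up = matvec_t(x, w_up)
    gate = matvec_t(x, w_gate)
    # Wherever the gate pre-activation is <= 0, the output is exactly zero.
    return [u * relu2(g) for u, g in zip(up, gate)]


x = [1.0, -1.0]
w_gate = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0]]
w_up = [[1.0, 1.0], [1.0, 0.0], [0.0, 3.0]]
print(relu2_glu(x, w_up, w_gate))
```

The zeros produced by the gate are what make the intermediate activations sparse before they enter the down projection.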

Additionally, to support quantized models, Q-Sparse quantizes the input tensor before applying Top-K sparsification, so that the sparsification operates directly on the quantized representation. The function is defined as follows:

Here, ε is a small constant used to avoid division by zero.
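The exact quantization function is not reproduced in this text. As an assumption, the sketch below uses symmetric absmax quantization of activations to signed 8-bit integers, a common scheme in BitNet-style models, with ε added to the denominator as described above, and then applies Top-K to the quantized values. All names and the default ε are hypothetical.

```python
def absmax_quantize(x, bits=8, eps=1e-5):
    """Assumed scheme: scale x so max|x| maps to the largest signed integer, then round and clip."""
    qmax = 2 ** (bits - 1) - 1            # 127 for 8 bits
    gamma = max(abs(v) for v in x)        # largest magnitude in the tensor
    scale = qmax / (gamma + eps)          # eps avoids division by zero for an all-zero x
    # Python's round() rounds halves to even; the result is clipped to [-qmax - 1, qmax].
    q = [min(max(round(v * scale), -qmax - 1), qmax) for v in x]
    return q, scale


def quantize_then_topk(x, k, bits=8):
    """Quantize first, then keep the k quantized values with the largest magnitude."""
    q, scale = absmax_quantize(x, bits)
    order = sorted(range(len(q)), key=lambda i: abs(q[i]), reverse=True)
    keep = set(order[:k])
    return [qi if i in keep else 0 for i, qi in enumerate(q)], scale


q_sparse, scale = quantize_then_topk([0.1, -2.0, 0.5, 3.0], k=2)
print(q_sparse)                          # integer codes, two of them nonzero
print([v / scale for v in q_sparse])     # approximate dequantized values
```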

Specifically, for 1-bit quantized weights, Q-Sparse uses the following quantization function, where α is the average absolute value of the weight tensor W.
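The weight quantization function itself is not reproduced here. One common 1-bit scheme consistent with the description (α as the mean absolute value of W) is W_q = α·sign(W); the sketch below implements that scheme as an assumption, and it is not necessarily the exact function used by Q-Sparse. The helper name is hypothetical.

```python
def binarize_weights(w):
    """Assumed 1-bit scheme: W_q = alpha * sign(W), alpha = mean(|W|) over the whole tensor."""
    flat = [v for row in w for v in row]
    alpha = sum(abs(v) for v in flat) / len(flat)

    def sign(v):
        # Map zeros to +1 so every weight is exactly one of two values.
        return 1.0 if v >= 0 else -1.0

    return [[alpha * sign(v) for v in row] for row in w], alpha


w = [[0.5, -1.5], [1.0, -1.0]]
w_q, alpha = binarize_weights(w)
print(alpha, w_q)
```

Each weight then needs a single bit (its sign), plus one shared floating-point scale α per tensor.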

Achieving the Same Effect with 60% of the Activated Parameters

Comparative experiments show that Q-Sparse significantly outperforms previous ReLU-based methods in both sparsity rate and model performance.

To assess Q-Sparse in practice, the authors evaluated it in three settings: training from scratch, continued training, and fine-tuning.

The from-scratch training experiments used the Llama architecture. At both the 700M and 7B scales, Q-Sparse with 70% top-K (i.e., 40% overall sparsity) achieves a training loss comparable to that of the dense baseline.

Continued training aims to sparsify an existing dense model; Mistral-7B served as the test subject.

With 2.9B and 3.8B activated parameters, the model's scores on benchmarks such as ARC and MMLU did not drop significantly.

In the fine-tuning experiments, Q-Sparse showed results similar to those of continued training for both Qwen-7B and Mistral-7B, coming very close to the performance of the dense models with about 60% of the activated parameters.

These results imply that, at the same performance level, sparsely activated models need significantly fewer activated parameters at inference time than dense models, which in turn reduces the number of FLOPs consumed.
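The FLOPs saving can be estimated with a common rule of thumb: a forward pass costs roughly 2 FLOPs (one multiply and one add) per active parameter per token. This is an approximation assumed for illustration; it ignores attention-score computation and other overheads, and uses the nominal 7B size for the dense model.

```python
def forward_flops_per_token(active_params):
    """Rule-of-thumb estimate: ~2 FLOPs per active parameter per token."""
    return 2 * active_params


dense = forward_flops_per_token(7e9)     # dense 7B model (nominal size)
sparse = forward_flops_per_token(3.8e9)  # Q-Sparse model with 3.8B activated parameters
ratio = sparse / dense
print(f"sparse/dense FLOPs per token: {ratio:.2f}")
```

Under this approximation the ratio is simply the ratio of activated parameters, 3.8 / 7, or roughly 54% of the dense cost per token.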

For quantized models, the team applied Q-Sparse to their own BitNet b1.58 model and trained and evaluated it on multiple datasets.

At both the 700M and 7B scales, quantized models trained with Q-Sparse converge about as fast as, and reach final loss values comparable to, quantized models without Q-Sparse (BitNet b1.58).

This indicates that Q-Sparse can be seamlessly integrated into quantized models without significantly affecting model training and convergence.

Based on this, the authors believe that combining Q-Sparse with quantization techniques can further improve the efficiency of large language models during inference.

Discovering a New "Scaling Law" for Inference Optimization

Beyond evaluating these sparsely activated models, the authors also studied the relationship among model performance, model size, and sparsity rate, which led to several new findings.

Performance scaling law for sparsely activated models: the authors found that, as with dense models, the performance of sparsely activated models follows a power-law scaling relationship.
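What "follows a power law" means in practice can be illustrated with a generic form L(N) = E + A / N^α for a fixed sparsity rate. The constants below are made up for illustration and are not the paper's fitted values; how the sparsity rate S enters the law is not modeled here.

```python
def power_law_loss(n_params, e, a, alpha):
    """Generic power law L(N) = E + A / N**alpha; e, a, alpha are hypothetical constants."""
    return e + a / n_params ** alpha


E, A, ALPHA = 1.7, 400.0, 0.3          # made-up constants for illustration only
for n in (7e8, 1.4e9, 7e9):
    print(f"N = {n:.1e}: L = {power_law_loss(n, E, A, ALPHA):.3f}")
```

A useful property of this form: each doubling of N shrinks the reducible part of the loss, L − E, by the constant factor 2^α, which is why such laws appear as straight lines on a log-log plot of L − E against N.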

Specifically, for a model of size N at a given sparsity rate S, the converged loss L(N, S) can be approximated by the following formula: