Fine-tuning LLMs to 1.58bit: extreme quantization made easy




As Large Language Models (LLMs) grow in size and complexity, finding ways to reduce their computational and energy costs has become a critical challenge. One popular solution is quantization, where the precision of parameters is reduced from the usual 16-bit floating-point (FP16) or 32-bit floating-point (FP32) to lower-bit formats like 8-bit or 4-bit. While this approach significantly cuts down on memory usage and speeds up computation, it often comes at the expense of accuracy. Reducing the precision too much can cause models to lose crucial information, leading to degraded performance.

BitNet is a special transformers architecture that represents each parameter with only three values: (-1, 0, 1), offering an extreme quantization of just $\log_2(3) \approx 1.58$ bits per parameter.






TL;DR

BitNet is an architecture introduced by Microsoft Research that uses extreme quantization, representing each parameter with only three values: -1, 0, and 1. This results in a model that uses just 1.58 bits per parameter, significantly reducing computational and memory requirements.

This architecture uses INT8 addition calculations when performing matrix multiplication, in contrast to LLaMA LLM’s FP16 addition and multiplication operations.

The new computation paradigm of BitNet b1.58 (source: BitNet paper https://arxiv.org/abs/2402.17764)

This results in theoretically reduced energy consumption, with BitNet b1.58 saving 71.4 times the arithmetic operation energy for matrix multiplication compared to the Llama baseline.

Energy consumption of BitNet b1.58 compared to LLaMA (source: BitNet paper https://arxiv.org/abs/2402.17764)

We have successfully fine-tuned a Llama3 8B model using the BitNet architecture, achieving strong performance on downstream tasks. The 8B models we developed are released under the HF1BitLLM organization. Two of these models were fine-tuned on 10B tokens with different training setups, while the third was fine-tuned on 100B tokens. Notably, our models surpass the Llama 1 7B model in MMLU benchmarks.



How to Use with Transformers

To integrate the BitNet architecture into Transformers, we introduced a new quantization method called "bitnet" (PR). This method involves replacing the standard Linear layers with specialized BitLinear layers that are compatible with the BitNet architecture, with appropriate dynamic quantization of activations, weight unpacking, and matrix multiplication.

Loading and testing the model in Transformers is incredibly straightforward; there are zero changes to the API:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

With this code, everything is managed seamlessly behind the scenes, so there is no need to worry about additional complexities; you only need to install the latest version of transformers.

For a quick test of the model, check out this notebook.



What Is BitNet In More Depth?

BitNet replaces traditional Linear layers in Multi-Head Attention and Feed-Forward Networks with specialized layers called BitLinear that use ternary precision (and even binary, in the initial version). The BitLinear layers we use in this project quantize the weights using ternary precision (with values of -1, 0, and 1), and we quantize the activations to 8-bit precision. We use a different implementation of BitLinear for training than we do for inference, as we'll see in the next section.

The main obstacle to training in ternary precision is that the weight values are discretized (via the round() function) and thus non-differentiable. BitLinear solves this with a nice trick: the STE (Straight-Through Estimator). The STE allows gradients to flow through the non-differentiable rounding operation by approximating its gradient as 1 (treating round() as equivalent to the identity function). Another way to view it is that, instead of stopping the gradient at the rounding step, the STE lets the gradient pass through as if the rounding never occurred, enabling weight updates using standard gradient-based optimization techniques.
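As a minimal, self-contained illustration of the STE trick (separate from the actual BitLinear code shown later), the following sketch shows how detach() lets gradients bypass round():

import torch

x = torch.tensor([0.3, -1.2, 0.8], requires_grad=True)

# The forward pass uses the rounded values, but the backward pass sees the identity:
# the gradient of x + (round(x) - x).detach() with respect to x is exactly 1.
y = x + (x.round() - x).detach()
loss = y.sum()
loss.backward()

print(y)       # tensor([ 0., -1.,  1.], ...)
print(x.grad)  # tensor([1., 1., 1.]) -- gradients flow as if round() were the identity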

The architecture of BitNet with BitLinear layers (source: BitNet paper https://arxiv.org/pdf/2310.11453)



Training

We train in full precision, but quantize the weights into ternary values as we go, using symmetric per-tensor quantization. First, we compute the average of the absolute values of the weight matrix; its inverse is the scale. We then multiply the weights by the scale, round the values, constrain them between -1 and 1, and finally dequantize them (dividing by the scale) so that training continues in full precision.

$$ scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} $$

$$ W_q = \text{clamp}_{[-1,1]}(\text{round}(W \times scale_w)) $$

$$ W_{dequantized} = W_q \times \frac{1}{scale_w} $$

Activations are then quantized to a specified bit-width (8-bit, in our case) using absmax per-token quantization (for a comprehensive introduction to quantization methods, check out this post). This involves scaling the activations into the range [-128, 127] for an 8-bit bit-width. The quantization formula is:

$$ scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}} $$

$$ X_q = \text{clamp}_{[-128,127]}(\text{round}(X \times scale_x)) $$

$$ X_{dequantized} = X_q \times \frac{1}{scale_x} $$

To make the formulas clearer, here are examples of weight and activation quantization using a 3×3 matrix:


Example 1: Weight Matrix Quantization

Let the weight matrix $W$ be:

$$ W = \begin{bmatrix} 0.8 & -0.5 & 1.2 \\ -1.5 & 0.4 & -0.9 \\ 1.3 & -0.7 & 0.2 \end{bmatrix} $$

Step 1: Compute the Scale for Weights

Using the formula:

$$ scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} $$

we calculate the average absolute value of $W$:

$$ \frac{1}{nm} \sum_{ij} |W_{ij}| = \frac{1}{9}(0.8 + 0.5 + 1.2 + 1.5 + 0.4 + 0.9 + 1.3 + 0.7 + 0.2) = \frac{1}{9}(7.5) = 0.8333 $$

Now, the scale factor is:

$$ scale_w = \frac{1}{0.8333} \approx 1.2 $$

Step 2: Quantize the Weight Matrix

Using the formula:

$$ W_q = \text{clamp}_{[-1, 1]}(\text{round}(W \times scale_w)) $$

We first scale the weights by $scale_w \approx 1.2$:

$$ W \times scale_w = \begin{bmatrix} 0.8 \times 1.2 & -0.5 \times 1.2 & 1.2 \times 1.2 \\ -1.5 \times 1.2 & 0.4 \times 1.2 & -0.9 \times 1.2 \\ 1.3 \times 1.2 & -0.7 \times 1.2 & 0.2 \times 1.2 \end{bmatrix} = \begin{bmatrix} 0.96 & -0.6 & 1.44 \\ -1.8 & 0.48 & -1.08 \\ 1.56 & -0.84 & 0.24 \end{bmatrix} $$

Next, we round the values and clamp them to the range $[-1, 1]$:

$$ W_q = \begin{bmatrix} 1 & -1 & 1 \\ -1 & 0 & -1 \\ 1 & -1 & 0 \end{bmatrix} $$

Step 3: Dequantize the Weights

Finally, we dequantize the weights using:

$$ W_{dequantized} = W_q \times \frac{1}{scale_w} $$

Substituting $\frac{1}{scale_w} = 0.8333$, we get:

$$ W_{dequantized} = \begin{bmatrix} 1 \times 0.8333 & -1 \times 0.8333 & 1 \times 0.8333 \\ -1 \times 0.8333 & 0 \times 0.8333 & -1 \times 0.8333 \\ 1 \times 0.8333 & -1 \times 0.8333 & 0 \times 0.8333 \end{bmatrix} = \begin{bmatrix} 0.83 & -0.83 & 0.83 \\ -0.83 & 0 & -0.83 \\ 0.83 & -0.83 & 0 \end{bmatrix} $$
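As a quick sanity check, these numbers can be reproduced with a few lines of PyTorch (this mirrors the formulas above, not the actual BitLinear implementation):

import torch

W = torch.tensor([[ 0.8, -0.5,  1.2],
                  [-1.5,  0.4, -0.9],
                  [ 1.3, -0.7,  0.2]])

scale_w = 1.0 / W.abs().mean()            # 1 / 0.8333 = 1.2
W_q = (W * scale_w).round().clamp(-1, 1)  # ternary matrix of -1 / 0 / 1 values
W_dequantized = W_q / scale_w             # entries of about ±0.83, back in full precision

print(scale_w, W_q, W_dequantized, sep="\n")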

Example 2: Activation Matrix Quantization

Let the activation matrix $X$ be:

$$ X = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix} $$

Step 1: Compute the Scale for Activations

For each row (or channel), compute the maximum absolute value:

  • Row 1: Maximum absolute value = 1.0
  • Row 2: Maximum absolute value = 1.2
  • Row 3: Maximum absolute value = 0.8

Compute the scale factors for each row:

$$ \text{scale} = \begin{bmatrix} \frac{127}{1.0} \\ \frac{127}{1.2} \\ \frac{127}{0.8} \end{bmatrix} = \begin{bmatrix} 127 \\ 105.83 \\ 158.75 \end{bmatrix} $$

Step 2: Quantize the Activation Matrix

Using the formula:

$$ X_q = \text{clamp}_{[-128,127]}(\text{round}(X \times \text{scale})) $$

Scale the activations:

$$ X \times \text{scale} = \begin{bmatrix} 1.0 \times 127 & -0.6 \times 127 & 0.7 \times 127 \\ -0.9 \times 105.83 & 0.4 \times 105.83 & -1.2 \times 105.83 \\ 0.8 \times 158.75 & -0.5 \times 158.75 & 0.3 \times 158.75 \end{bmatrix} = \begin{bmatrix} 127 & -76.2 & 88.9 \\ -95.2 & 42.3 & -127 \\ 127 & -79.4 & 47.6 \end{bmatrix} $$

Round the values and clamp them to the range $[-128, 127]$:

$$ X_q = \begin{bmatrix} 127 & -76 & 89 \\ -95 & 42 & -127 \\ 127 & -79 & 48 \end{bmatrix} $$

Step 3: Dequantize the Activations

Finally, dequantize the activations using:

$$ X_{dequantized} = X_q \times \frac{1}{\text{scale}} $$

Substituting the scales:

$$ X_{dequantized} = \begin{bmatrix} 127 \times \frac{1}{127} & -76 \times \frac{1}{127} & 89 \times \frac{1}{127} \\ -95 \times \frac{1}{105.83} & 42 \times \frac{1}{105.83} & -127 \times \frac{1}{105.83} \\ 127 \times \frac{1}{158.75} & -79 \times \frac{1}{158.75} & 48 \times \frac{1}{158.75} \end{bmatrix} \approx \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix} $$
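The activation example can be checked the same way (again, just a sketch of the formulas, with the per-row scales kept as a column vector for broadcasting):

import torch

X = torch.tensor([[ 1.0, -0.6,  0.7],
                  [-0.9,  0.4, -1.2],
                  [ 0.8, -0.5,  0.3]])

scale_x = 127.0 / X.abs().max(dim=-1, keepdim=True).values  # per-row scales 127, 105.83, 158.75
X_q = (X * scale_x).round().clamp(-128, 127)                # values in the int8 range
X_dequantized = X_q / scale_x                               # approximately the original activations

print(scale_x.squeeze(), X_q, X_dequantized, sep="\n")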


We apply Layer Normalization (LN) before quantizing the activations to maintain the variance of the output:

$$ \text{LN}(x) = \frac{x - E(x)}{\sqrt{\text{Var}(x) + \epsilon}} $$

where $\epsilon$ is a small number to prevent overflow.
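The training and inference snippets below call an LN(x) helper; a minimal version matching the formula above (a sketch without learnable affine parameters, not necessarily the exact implementation we used) could look like this:

import torch

def LN(x, eps=1e-5):
    # Layer normalization over the last dimension, as in the formula above
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)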

The round() function is not differentiable, as mentioned before. We use detach() as a trick to implement a differentiable straight-through estimator in the backward pass:


import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x):
    # Per-token absmax quantization to 8 bits, followed by immediate dequantization
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

def weight_quant(w):
    # Per-tensor ternary quantization, followed by immediate dequantization
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)  # LN is the layer normalization defined above

        # Straight-through estimator: the quantized values are used in the
        # forward pass, but gradients flow as if no quantization happened
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()

        # Standard linear layer on the (fake-)quantized activations and weights
        y = F.linear(x_quant, w_quant)
        return y



Inference

During inference, we simply quantize the weights to ternary values without rescaling. We apply the same approach to activations using 8-bit precision, then perform a matrix multiplication with an efficient kernel, followed by dividing by both the weight and activation scales. This can significantly improve inference speed, particularly with optimized hardware. You can see that the rescaling process differs from the one used during training, where matrix multiplications are kept in fp16/bf16/fp32 for correct training.


import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant_inference(x):
    x = LN(x)  # LN is the layer normalization defined above
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale  # quantized activations (int8 range) and their per-token scale

class BitLinear(nn.Linear):
    """
    Only for inference
    """
    def forward(self, x):
        w = self.weight          # packed ternary weights
        w_scale = self.w_scale   # weight scale computed at quantization time
        x_quant, x_scale = activation_quant_inference(x)
        # efficient_kernel performs the low-precision matmul (e.g. the Triton
        # kernel shown later); the result is rescaled by both scales
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y



Pre-training Results in 1.58b

Before attempting fine-tuning, we first tried to reproduce the results of the BitNet paper with pre-training. We started with a small dataset, tinystories, and a Llama3 8B model. We confirmed that adding a normalization function, like the paper does, improves performance. For example, after 2,000 steps of training, the perplexity on the validation set was 6.3 without normalization and 5.9 with normalization. Training was stable in both cases.

Pre-training plots without (blue) & with (green) layer normalisation

While this approach looks very interesting for pre-training, only a few institutions can afford to do it at the necessary scale. However, there is already a wide selection of strong pretrained models, and it would be extremely useful if they could be converted to 1.58bit after pre-training. Other groups had reported that fine-tuning results were not as strong as those achieved with pre-training, so we set out to investigate whether we could make 1.58bit fine-tuning work.



Fine-tuning in 1.58bit

When we started fine-tuning from the pre-trained Llama3 8B weights, the model performed slightly better, but not as well as we expected.

Note: All our experiments were conducted using Nanotron. If you're interested in trying 1.58bit pre-training or fine-tuning, you can check out this PR.

Fine-tuning plot compared to pre-training plot

To understand why, we inspected the weight distributions of both the randomly initialized model and the pre-trained model to identify potential issues.

Random weights distribution (2 merged stds)
Pre-trained Llama3 weights distribution

And the scale values for the two distributions are, respectively:

Random weights scales distribution
Pre-trained Llama3 weights distribution

The initial random weight distribution is a mixture of two normal distributions:

  • One with a standard deviation (std) of 0.025
  • Another with a std of $\frac{0.025}{\sqrt{2 \cdot \text{num\_hidden\_layers}}} = 0.00325$

This results from using different stds for column-linear and row-linear weights in Nanotron. In the quantized version, all matrices have only 2 weight scales (50.25 and 402), which are the inverse of the mean absolute value of the weights for each matrix: scale = 1.0 / w.abs().mean().clamp_(min=1e-5)

  • For $\text{scale} = 50.25$: this corresponds to the first distribution (std 0.025)
  • For $\text{scale} = 402$: this corresponds to the second distribution (std 0.00325)

On the other hand, the pretrained weights' distribution looks like a normal distribution with an std of 0.013.
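As a back-of-the-envelope check (our own reasoning, not part of the original analysis): for a zero-mean normal distribution the mean absolute value is std × √(2/π), so the per-tensor scale 1 / mean(|w|) can be estimated directly from the std:

import math

for std in (0.025, 0.00325):
    approx_scale = 1.0 / (std * math.sqrt(2 / math.pi))
    print(std, round(approx_scale, 1))
# ~50.1 for std 0.025 and ~385.6 for std 0.00325, in the same ballpark as the
# observed scales of 50.25 and 402.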

Clearly, the pretrained model starts with more information (scales), while the randomly initialized model starts with practically no information and builds it up over time. Our conclusion was that starting with random weights gives the model minimal initial information, enabling a gradual learning process, while during fine-tuning the introduction of BitLinear layers overwhelms the model and makes it lose all its prior information.

To improve the fine-tuning results, we tried different techniques. For example, instead of using per-tensor quantization, we tried per-row and per-column quantization to keep more information from the Llama 3 weights. We also tried to change the way the scale is computed: instead of just taking the mean absolute value of the weights as a scale, we take the mean absolute value of the outliers (an outlier is a value that exceeds k*mean_absolute_value, where k is a constant we varied in our experiments), but we didn't notice big improvements.

def scale_outliers(tensor, threshold_factor=1):
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights

We observed that both the random weights and the Llama 3 weights resulted in losses starting at roughly the same value of 13. This suggests that the Llama 3 model loses all of its prior information when quantization is introduced. To further investigate how much information the model loses during this process, we experimented with per-group quantization.

As a sanity check, we first set the group size to 1, which essentially means no quantization. In this scenario, the loss started at 1.45, the same as during normal fine-tuning. However, when we increased the group size to 2, the loss jumped to around 11. This suggests that even with a minimal group size of 2, the model still loses nearly all of its information.
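For reference, here is a minimal sketch of what per-group ternary quantization could look like (the flattened grouping and the helper name are our assumptions, not the exact Nanotron implementation):

import torch

def weight_quant_per_group(w, group_size):
    # Each group of `group_size` consecutive weights gets its own scale
    orig_shape = w.shape
    grouped = w.reshape(-1, group_size)
    scale = 1.0 / grouped.abs().mean(dim=-1, keepdim=True).clamp_(min=1e-5)
    quantized = (grouped * scale).round().clamp_(-1, 1) / scale
    return quantized.reshape(orig_shape)

# With group_size=1 each weight is its own group, so quantization becomes the
# identity (round(±1) / scale gives back the original weight), which is why the
# loss starts at the usual fine-tuning value in that case.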

To address this issue, we considered introducing quantization gradually rather than applying it abruptly to the weights and activations of every tensor. To achieve this, we introduced a lambda value to control the process:

lambda_ = ?
x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()

When lambda is set to 0, there is essentially no quantization happening, while at lambda=1, full quantization is applied.

We initially tested some discrete lambda values, such as 0.25, 0.5, 0.75, and 1. However, this approach didn't lead to any significant improvement in results, mainly because lambda=0.25 is already high enough for the loss to start very high.

Fine-tuning plot with lambda = 0.25->0.5->0.75->1

As a result, we decided to experiment with a lambda value that adjusts dynamically based on the training step.

lambda_ = training_step / total_training_steps

Using this dynamic lambda value led to better loss convergence, but the perplexity (ppl) results at inference time, with lambda set to 1, were still far from satisfactory. We realized this was likely because the model hadn't been trained long enough with lambda=1. To address this, we adjusted our lambda schedule to improve the training process.

lambda_ = min(2 * training_step / total_training_steps, 1)

With this configuration, after 2,000 steps we have:

Fine-tuning plot with lambda = min(2*training_step/total_training_steps, 1)

Our fine-tuning method shows better convergence overall. You can observe a slight increase in the loss curve around 1,000 steps, which corresponds to when we begin approaching lambda=1, or full quantization. However, immediately after this point, the loss starts to converge again, leading to an improved perplexity of roughly 4.

Despite this progress, when we tested the quantized model on the WikiText dataset (instead of the tinystories one we used for fine-tuning), it showed a very high perplexity. This suggests that fine-tuning the model in low-bit mode on a specific dataset causes it to lose much of its general knowledge. This issue might arise because the minimal representations we aim for with ternary weights can vary significantly from one dataset to another. To address this problem, we scaled our training process to include the larger FineWeb-edu dataset. We maintained a lambda value of:

lambda_ = min(training_step/1000, 1)

We chose this lambda value because it seemed to be a good starting point for warming up the model. We then trained the model with a learning rate of 1e-4 for 5,000 steps on the FineWeb-edu dataset. The training used a batch size (BS) of 2 million tokens, totaling 10 billion tokens.

Finding the right learning rate and the right decay was difficult; it appears to be a crucial factor in the model's performance.

Fine-tuning plot with warmup quantization on Fineweb-edu

After the fine-tuning process on Fineweb-Edu, the perplexity on the WikiText dataset reached 12.2, which is quite impressive given that we only used 10 billion tokens. The other evaluation metrics also show strong performance considering the limited amount of data (see results).

We also tried to smooth out the sharp increase when lambda approaches 1. To do this, we considered using lambda schedulers that grow exponentially at first, then level off as they get closer to 1.

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

For different k values, with the total warmup steps normalized to 1, we have plots like the following:

Exponential scheduler for different k values

We ran 4 experiments using the best-performing learning rate of 1e-4, testing values of k in [4, 6, 8, 10].

Fine-tuning plots with exponential scheduler

The smoothing worked well, as there is no spike like with the linear scheduler. However, the perplexity isn't great, staying around ~15, and the performance on downstream tasks is not better.

We also noticed the spike at the beginning, which the model struggled to recover from. With lambda = 0, there is essentially no quantization, so the loss starts low, around ~2. But right after the first step, there is a spike, similar to what happened with the linear scheduler (as seen in the blue plot above). So, we tried a different scheduler, a sigmoid one, that starts slowly, rises sharply to 1, and then levels off as it approaches 1.

import numpy as np

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-shaped ramp from ~0 to ~1, centered at the midpoint of training
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))

For different k values we have the following curves:

Sigmoid scheduler for different k values

We ran 5 experiments this time with k in [15, 20, 25, 40, 100]:

Fine-tuning plots with sigmoid scheduler

The sharp increase in lambda caused instability around the 500th step and didn't fix the first divergence issue. However, for $k = 100$

Additionally, we experimented with training models from scratch using random weights and various learning rates. This allowed us to compare the effectiveness of our fine-tuning approach against traditional pre-training methods.

Different pre-training plots with different learning rates

None of the models trained from random weights performed better than our fine-tuned model. The best perplexity we achieved with those models was 26, which falls short compared to the results of our fine-tuning approach.



Scaling to 100B Tokens!

We scaled our experiments to 100 billion tokens to see if we could match the performance of Llama 3 8B. We conducted longer training runs, starting from our best-performing checkpoint from the shorter runs with the linear scheduler, and continued fine-tuning for 45,000 steps. We experimented with different learning rates, and while the model came close to Llama 3 on some metrics, it still lagged behind on average.

Here are some examples of the metrics we evaluated at various checkpoints during training:

Metric evaluations during training for different learning rates

and the average score looks like:

Average evaluation during training for different learning rates



Experiments on Smaller Models

In our initial experiments with smaller models like SmolLM, we observed that the warmup quantization technique did not yield as much improvement as it did with larger models. This suggests that the effectiveness of warmup quantization could be more closely related to model size and complexity.

For example, here are the loss curves for the SmolLM 135M model, comparing warmup quantization with full quantization from the start. Interestingly, the curves closely align, and the resulting perplexities aren't significantly different.

SmolLM fine-tuning experiment with & without warmup quantization



Results & Comparison

BitNet is effective in delivering strong performance compared to baseline methods, especially at lower bit levels. According to the paper, BitNet achieves scores on par with 8-bit models but with significantly lower inference costs. In the case of 4-bit models, methods that only quantize weights outperform those that quantize both weights and activations, as activations are harder to quantize. However, BitNet, which uses 1.58-bit weights, surpasses both weight-only and weight-and-activation quantization methods.

The table below presents the results for various metrics after the 10B-token fine-tuning process of Llama3 8B. These results are compared against those from other model architectures to provide a comprehensive overview of performance. (All evaluations were conducted using Lighteval on the Nanotron-format model.)

Metrics comparison with Llama models: Linear means linear lambda scheduler, and Sigmoid means sigmoid lambda scheduler (in our case k = 100)

After fine-tuning on just 10 billion tokens using ternary weights, the model demonstrates impressive performance, especially compared to other models that underwent far more extensive training. For instance, it outperforms the Bitnet 7B model, which was trained on a significantly larger dataset of 100 billion tokens. Additionally, it performs better than the FBI LLM (Fully Binarized LLM), a model distilled on an even more massive 1.26 trillion tokens. This highlights the model's efficiency and effectiveness despite the relatively small scale of its fine-tuning process.

For the 100B-token experiments, the best-performing checkpoint we have is the following:

Metrics comparison with Llama models for the model trained on 100B tokens

To replicate these results, you can check out this PR to convert models to the Nanotron format, unpack the weights (check the function unpack_weights), and use lighteval.

Note that although the models are fine-tuned from an Instruct-tuned model, they still need to be fine-tuned on an Instruct dataset. These can be considered base models.



Custom Kernels & Benchmarks

To take advantage of the BitNet low-precision weights, we pack them into an int8 tensor (this makes the number of parameters go from 8B to 2.8B!). During inference, these weights must be unpacked before performing the matrix multiplication. We implemented custom kernels in CUDA and Triton to handle the on-the-fly unpacking during the matrix multiplication process. For the matrix multiplication itself, we employed the cached tiled matrix multiplication technique. To fully grasp this approach, let's first review some CUDA programming fundamentals.
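For intuition, here is a sketch of a packing routine consistent with the unpacking logic in the Triton kernel shown later (ternary values mapped to 0/1/2 and stored in 2-bit slots, four per byte along the K dimension); the actual packing code used in transformers may differ in layout:

import torch

def pack_ternary_weights(w):
    # w: ternary matrix with values in {-1, 0, 1} and shape (K, N), K divisible by 4.
    # Returns a uint8 matrix of shape (K // 4, N): row r stores w[r], w[K//4 + r],
    # w[2*K//4 + r] and w[3*K//4 + r], one per 2-bit slot, shifted by +1 so the
    # kernel can subtract 1 to recover -1/0/1.
    K, N = w.shape
    assert K % 4 == 0
    chunks = (w + 1).to(torch.uint8).reshape(4, K // 4, N)
    packed = torch.zeros(K // 4, N, dtype=torch.uint8, device=w.device)
    for i in range(4):
        packed |= chunks[i] << (2 * i)
    return packed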



Basic GPU Concepts: Threads, Blocks, and Shared Memory

Before diving into cached tiled matrix multiplication, it's important to understand a few basic GPU concepts:

  • Threads and Blocks: GPUs execute thousands of threads concurrently. These threads are grouped into blocks, and each block runs independently. The grid is made up of these blocks and represents the entire problem space. For example, in matrix multiplication, each thread might be responsible for computing a single element of the output matrix.
  • Shared Memory: Each block has access to a limited amount of shared memory, which is much faster than global memory (the main memory on the GPU). However, shared memory is limited in size and shared among all threads within a block. Using shared memory effectively is key to improving performance in GPU programs.



Challenges in Matrix Multiplication

A straightforward implementation of matrix multiplication on a GPU might involve each thread computing a single element of the result matrix by directly reading the necessary elements from global memory. However, this approach can be inefficient for the following reasons:

  • Memory Bandwidth: Accessing global memory is relatively slow compared to the speed at which the GPU cores can perform computations. If each thread reads matrix elements directly from global memory, the memory access times can become a bottleneck.
  • Redundant Data Access: In matrix multiplication, many elements of the input matrices are used multiple times. If each thread fetches the required data from global memory independently, the same data may be loaded into the GPU multiple times, leading to inefficiency. For example, if each thread computes a single element of the output matrix, the thread responsible for calculating the element at position (i, j) will need to load the i-th row of matrix A and the j-th column of matrix B from global memory. However, other threads, such as the one computing the element at position (i+1, j), cannot reuse this data and will have to reload the same j-th column from global memory again.



The Idea of Tiling

Tiling is a technique used to address these challenges, and it was notably used in FlashAttention to improve the kernel's efficiency. The basic idea is to divide the matrices into smaller sub-matrices, called tiles, which can fit into the shared memory of the GPU. Instead of computing the entire output matrix in one go, the computation is broken down into smaller pieces that are processed tile by tile.

In the context of matrix multiplication, this means dividing matrices A and B into blocks (tiles), loading these tiles into shared memory, and then performing the multiplication on these smaller blocks. This approach allows the threads to reuse data stored in the fast shared memory, reducing the need to access global memory repeatedly.

Here's how it works:

  • Loading Tiles into Shared Memory: Each block of threads cooperatively loads a tile of matrix A and a corresponding tile of matrix B from global memory into shared memory. This operation is done once per tile, and then the tile is reused multiple times by the threads in the block.
  • Computing Partial Products: Once the tiles are loaded into shared memory, each thread computes a partial product. Since all threads in a block are working on the same tiles in shared memory, they can efficiently reuse the data without additional global memory accesses.
  • Accumulating Results: After computing the partial products for one tile, the threads load the next tiles from matrices A and B into shared memory and repeat the process. The results are accumulated in a register (or local memory), and once all tiles have been processed, the final value of the output matrix element is written back to global memory (a small CPU sketch of this loop structure follows the illustration below).
Tiled matrix multiplication illustration (source: https://cnugteren.github.io/tutorial/pages/page4.html)
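To make the loop structure concrete, here is a minimal CPU sketch of the blocking idea, with the tiles standing in for the shared-memory buffers (no GPU specifics, just the decomposition):

import torch

def tiled_matmul(A, B, tile=32):
    # A: (M, K), B: (K, N). Each output tile (i, j) is accumulated by sweeping
    # over K one tile at a time, the way a thread block would stage tiles of A
    # and B in shared memory before multiplying them.
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = torch.zeros(M, N, dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = torch.zeros(min(tile, M - i), min(tile, N - j), dtype=A.dtype)
            for k in range(0, K, tile):
                a_tile = A[i:i + tile, k:k + tile]  # "load a tile of A into shared memory"
                b_tile = B[k:k + tile, j:j + tile]  # "load a tile of B into shared memory"
                acc += a_tile @ b_tile              # partial product on the tiles
            C[i:i + tile, j:j + tile] = acc         # write the finished tile back
    return C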

Practical Considerations

When implementing cached tiled matrix multiplication, several aspects are considered:

  • Tile Size: The size of the tiles must be chosen to balance the trade-off between the amount of data that can fit into shared memory and the number of global memory accesses.
  • Memory Coalescing: Global memory accesses should be coalesced, which means that adjacent threads access adjacent memory locations.
  • Occupancy: The number of threads per block and the number of blocks in the grid should be chosen to ensure high occupancy, which means having as many active warps (a warp is a set of 32 threads) as possible on the GPU to hide memory latency.



Triton Kernel

Here is the Triton kernel we benchmarked:

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  
        GROUP_SIZE_M: tl.constexpr,
):

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    for i in range(4):
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j

            # Load the next tiles of A and (packed) B, masking out-of-bounds
            # elements along the K dimension with 0
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)

            # Unpack the i-th 2-bit slot of each byte of B: values are stored as
            # 0/1/2, so subtracting 1 recovers the ternary -1/0/1 weights
            mask = 3 << (2 * i)
            b = ((b_uint8 & mask) >> (2 * i))
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            # Accumulate the partial INT8 dot products in INT32
            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    # The INT32 accumulator is cast to float16 when stored into C below
    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix must be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c



Code Breakdown

  1. Determining Tile Positions

The kernel first determines which tile (block) of the output matrix each thread block is responsible for:

  • pid is the unique identifier for each thread block, obtained using tl.program_id(axis=0).
  • The grid is divided into groups of thread blocks (GROUP_SIZE_M). Each group processes a portion of the output matrix.
  • pid_m and pid_n are the coordinates of the tile in the M and N dimensions, respectively.
  • Offsets (offs_am, offs_bn, offs_k) are calculated to determine which elements of matrices A and B each thread in the block will work on.
  2. Loading and Computing Tiles

The kernel uses a loop to iterate over the K dimension in chunks of BLOCK_SIZE_K. For each chunk:

  • Load Tiles: tiles from matrices A and B are loaded from global memory.
  • Unpacking Matrix B: The kernel assumes that matrix B is packed with int8 values, meaning each element actually represents 4 smaller values packed into one byte. The unpacking happens inside the loop:
    • b_uint8 is loaded from global memory as packed int8.
    • Each packed value is unpacked to obtain the actual weight values used for computation.
  • Dot Product: The kernel computes the dot product of the loaded tiles from A and B, accumulating the results in the accumulator. The accumulator stores the partial results for the tile of the output matrix C.
  3. Storing Results

After all tiles along the K dimension have been processed, the final results stored in the accumulator are converted to float16 and written back to the corresponding tile of matrix C in global memory. The writing process respects memory boundaries, using a mask to ensure that only valid elements are written.

For a more detailed explanation of the code, check out this PR.
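As a rough usage sketch, assuming the matmul wrapper above is in scope and a CUDA GPU with Triton is available (the int8/uint8 shapes simply follow the asserts in matmul; in the real model A would be the quantized activations and B the packed ternary weights):

import torch

M, K, N = 16, 256, 64
A = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
B_packed = torch.randint(0, 256, (K // 4, N), dtype=torch.uint8, device="cuda")

C = matmul(A, B_packed)  # float16 output of shape (M, N)
print(C.shape)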



Benchmark

We benchmarked our kernel against the approach of unpacking the weights with @torch.compile and then performing the matmul in BF16 precision, and found that both approaches achieved roughly the same performance. To ensure accurate benchmarking, we performed the matmul operation over 2000 iterations and averaged the time taken over the last 1000 iterations, to eliminate any inefficiencies related to initial loading or compilation. Below is a graph showing the benchmark results. We also tested various matrix sizes, with the x-axis representing the number of multiplications on a log scale and the y-axis showing the average time in ms.
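A minimal sketch of that kind of timing loop (our own illustration of the methodology, not the exact harness we used) looks like this:

import time
import torch

def benchmark(fn, iters=2000, keep_last=1000):
    # Time every iteration, but average only the last `keep_last` ones so that
    # compilation, autotuning and cache-warmup costs are excluded.
    times = []
    for _ in range(iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return 1000 * sum(times[-keep_last:]) / keep_last  # average time in ms

# e.g. benchmark(lambda: matmul(A, B_packed))  # placeholder arguments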

Triton kernel compared to torch.compile

We also tried using BitBlas, a software library designed to perform matrix operations with mixed precision. It helps optimize these operations by allowing calculations to be done in lower-precision formats like INT8, INT4, or even INT2, instead of the standard FP32 or FP16 formats.

The benchmark results are promising, as BitBlas outperforms both our custom kernel and Torch's matmul function in low precision, as shown in the graph.

Bitblas benchmark

However, during model loading, BitBlas needs to compile kernels tailored to the shape of the weight matrix and store them in a local database, which can increase the initial loading time.



Conclusion

In conclusion, as LLMs continue to grow, reducing their computational demands through quantization is essential. This blog post has explored the approach of 1.58-bit quantization, which uses ternary weights. While pre-training models in 1.58 bits is resource-intensive, we've demonstrated that, with some tricks, it's possible to fine-tune existing models to this precision level, achieving efficient performance without sacrificing accuracy. By optimizing inference speed through specialized kernels, BitNet opens new possibilities for making LLMs more practical and scalable.



Acknowledgements

We would like to express our sincere gratitude to Leandro von Werra, Thomas Wolf, and Marc Sun for their invaluable assistance and insights throughout this project. We also extend our thanks to Omar Sanseviero and Pedro Cuenca for their contributions in refining this blog post, helping us communicate our findings clearly and effectively to the AI community.
Additionally, we would like to acknowledge the GeneralAI team for their pioneering work on the BitNet project. Their research has been foundational to our efforts, and we are particularly grateful for the clear and precise figures provided in their paper.



Additional Resources

  1. H. Wang et al., BitNet: Scaling 1-bit Transformers for Large Language Models. arxiv paper
  2. S. Ma et al., The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits. arxiv paper
  3. S. Ma et al., The Era of 1-bit LLMs: Training Tips, Code and FAQ. link
  4. RJ. Honicky, Are All Large Language Models Really in 1.58 Bits?. blogpost
  5. L. Mao, CUDA Matrix Multiplication Optimization. blogpost
  6. Tutorial: OpenCL SGEMM tuning for Kepler. link
  7. CUDAMODE. github, youtube
  8. Wen-mei W. Hwu, David B. Kirk, Izzat El Hajj, Programming Massively Parallel Processors : A Hands-on Approach


