https://arxiv.org/abs/2402.17764
The main addition of the new paper seems to be the implementation of optimized and fused kernels using triton, as seen here:
https://github.com/ridgerchu/matmulfreellm/blob/master/mmfre...
This is quite useful, as it should make training this type of LLM much more efficient.
So this is a ternary-weight LLM using quantization-aware training (QAT). The activations are quantized to 8 bits. The matmul is still there, but it is multiplying the 8-bit activations by ternary weights (-1, 0, or +1), so each product reduces to an add, a subtract, or a skip.
Quantization-aware training with low-bit weights seems to reduce overfitting through an intrinsic tendency to regularize. However, model capacity should also be lower than that of a model with the same number of weights and more bits per weight. It's quite possible that this only becomes apparent after the models have been trained on a significant number of tokens, as LLMs seem to be quite sparse.
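To make the "matmul is still there" point concrete, here is a minimal pure-Python sketch, not the paper's Triton kernels. The helper names `quantize_ternary` and `ternary_matvec` are made up for illustration, and the absmean scaling is an assumption about how the weights get rounded to ternary values.

```python
def quantize_ternary(weights):
    """Round each weight to -1, 0 or +1 after dividing by the mean |w| (assumed scheme)."""
    flat = [abs(w) for row in weights for w in row]
    scale = sum(flat) / len(flat) or 1.0  # fall back to 1.0 if all weights are zero
    ternary = [[max(-1, min(1, round(w / scale))) for w in row] for row in weights]
    return ternary, scale


def ternary_matvec(ternary, x):
    """Compute ternary @ x using only additions and subtractions."""
    out = []
    for row in ternary:
        acc = 0
        for t, xi in zip(row, x):
            if t == 1:
                acc += xi      # weight +1: add the activation
            elif t == -1:
                acc -= xi      # weight -1: subtract the activation
            # weight 0: skip the activation entirely
        out.append(acc)
    return out


weights = [[0.9, -0.05, -1.2], [0.3, 0.7, -0.6]]
ternary, scale = quantize_ternary(weights)
x = [10, -3, 4]  # stand-in for int8 activations
print(ternary)                     # → [[1, 0, -1], [0, 1, -1]]
print(ternary_matvec(ternary, x))  # → [6, -7]
```

The result would still need to be multiplied by `scale` (once per output, not once per weight) to approximate the full-precision product.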
Edit: In addition to the QAT they also changed the model architecture to use a linear transformer to reduce reliance on multiplications in the attention mechanism. Thanks to logicchains for pointing this out.
This is the kind of stuff that can only be done so quickly by bringing more and more people into the field to try these ideas out. The more people, the more permutations of ideas.
1) It’s weird to choose linear attention for their implementation because that’s not what their paper is about and they claim no insights relevant to attention mechanisms.
2) Benchmarking all models this way (linear vs linear) likely inflated their numbers compared with measuring the removal of matmul in a quadratic vs quadratic scenario.
3) This claim implies a comparison to the state of the art in language models where the standard is quadratic attention, and is therefore a flawed comparison:
“We processed billion-parameter scale models at 13W beyond human readable throughput, moving LLMs closer to brain-like efficiency.”
4) Those types of brain comparisons fall apart under scrutiny, are not standard in ML research and don’t mean much anyway.
5) Right up front in the abstract they make specific performance claims and imply they come from removing matmul, but don’t mention linear attention until section 4 on experiments.
They get real (61%!?) memory savings during training, and inference too.
On top of all that, they then go build an FPGA core which is programmed with a custom assembler. And their code is posted and works seamlessly with huggingface transformers?! Absolutely going to test this out.
> we show that we can eliminate all multiplications in the entire training process, including operations in the forward pass, backward pass and optimizer update, demonstrating the first successful training of modern neural network architectures in a fully multiplication-free fashion.
This reaches demoscene levels of crazy/impressive!
The exp/log trick to multiply with addition does indeed look very familiar. I know that a number of demos used it in the 90s to simplify matrix multiplications for 3D graphics.
Bill Dally, the Nvidia Chief Scientist, also seems to be a big fan of log8 representations: https://youtu.be/gofI47kfD28?t=1953
It seems it has not made it into the hardware yet, though...
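For anyone who hasn't seen the exp/log trick mentioned above, a rough sketch: since log(a*b) = log(a) + log(b), a multiply becomes an add in the log domain. The function name `log_mul` is made up here, and real implementations (old demos, log-number hardware) would use lookup tables or a fixed-point log format rather than `math.log`, so treat this only as an illustration of the identity.

```python
import math


def log_mul(a, b):
    """Multiply two numbers by adding their logarithms; signs are handled separately."""
    if a == 0 or b == 0:
        return 0.0  # log(0) is undefined, so zero is a special case
    sign = -1.0 if (a < 0) != (b < 0) else 1.0
    # The only arithmetic on the magnitudes is one addition in the log domain.
    return sign * math.exp(math.log(abs(a)) + math.log(abs(b)))


print(abs(log_mul(3, -7) - (-21)) < 1e-9)  # → True
```

If the values are stored as logs in the first place, the `log` and `exp` conversions disappear from the inner loop and only the addition remains, which is the appeal of log-number formats.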
So what’s the extra trick to make the model stay quantized? Does one evaluate the gradients on a whole bunch of training inputs, add them up, apply some randomness, and then re-quantize the model? Or is it something else?
>To train our 1-bit model, we employ the straight-through estimator (STE)[BLC13] to approximate the gradient during backpropagation. This method bypasses the nondifferentiable functions, such as the Sign (Eq. 2) and Clip (Eq. 5) functions, during the backward pass. STE allows gradients to flow through the network without being affected by these non-differentiable functions, making it possible to train our quantized model.
>While the weights and the activations are quantized to low precision, the gradients and the optimizer states are stored in high precision to ensure training stability and accuracy. Following the previous work [LSL+21], we maintain a latent weight in a high-precision format for the learnable parameters to accumulate the parameter updates. The latent weights are binarized on the fly during the forward pass and never used for the inference process.
This seems a bit unfortunate: the training process will end up using a whole lot more memory than inference. I wonder whether one could get away with storing the high-precision weights in slow host memory and using the quantized weights for the backward pass, thus keeping the high-precision copies out of GPU memory.
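Here is a toy sketch of the scheme described in the quotes above, as I understand it: high-precision latent weights are kept for the optimizer, binarized with sign() on every forward pass, and the gradient is passed straight through the sign() as if it were the identity. Everything here (the names `sign` and `train_ste`, the single linear neuron, the squared-error loss, the learning rate) is made up for illustration and is not the paper's code.

```python
def sign(v):
    """Binarize a latent weight to +1.0 or -1.0 (zero maps to +1.0)."""
    return 1.0 if v >= 0 else -1.0


def train_ste(data, latent, lr=0.05, epochs=50):
    """Update float latent weights while the forward pass only sees their signs."""
    for _ in range(epochs):
        for x, y in data:
            w_bin = [sign(w) for w in latent]            # binarize on the fly
            pred = sum(wb * xi for wb, xi in zip(w_bin, x))
            err = pred - y                               # gradient of 0.5 * err**2 w.r.t. pred
            for i, xi in enumerate(x):
                grad_bin = err * xi                      # gradient w.r.t. the binarized weight
                # STE: pretend d(sign(w))/dw == 1 and apply it to the latent weight
                latent[i] -= lr * grad_bin
    return latent


target = [1.0, -1.0, 1.0]
inputs = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]
data = [(x, sum(t * xi for t, xi in zip(target, x))) for x in inputs]

latent = train_ste(data, [-0.3, 0.2, 0.4])
print([sign(w) for w in latent])  # → [1.0, -1.0, 1.0]
```

Only `latent` needs to be high precision; inference would ship just the signs, which is why training costs so much more memory than inference.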
The OP outlines what could be an entirely different compute paradigm for LLMs, hence the FPGA study. They just happen to also get impressive performance on GPUs by making the most of the available interface.
It is super easy to try out: the 2.7B, 1.3B, and 0.37B models are on huggingface, and the generate.py example just works if you have triton 2.2 installed.