BitMat: Improving Ternary Matrix Multiplication with Triton
This work introduces BitMat, a Python package that employs custom Triton kernels to optimize ternary matrix multiplication, where parameters take values in {-1, 0, 1}. The package builds on the principles outlined in the '1-bit LLM era' family of articles (https://arxiv.org/abs/2402.17764) by operating on weights packed into int8 during inference. BitMat achieves significant improvements in memory usage. Code is available at https://github.com/astramind-ai/BitMat/tree/main.
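To make the packed representation concrete, the sketch below shows one way to store four ternary weights per int8, using 2 bits per value. The encoding (0 -> 0b00, +1 -> 0b01, -1 -> 0b10), the function names, and the padding requirement are illustrative assumptions and do not necessarily match BitMat's internal layout.

```python
import torch

def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    """Pack ternary values {-1, 0, +1} into int8, four values per byte.

    Assumed 2-bit encoding: 0 -> 0b00, +1 -> 0b01, -1 -> 0b10 (illustrative only).
    """
    assert w.numel() % 4 == 0, "pad to a multiple of 4 before packing"
    codes = (w == 1).to(torch.uint8) + 2 * (w == -1).to(torch.uint8)
    codes = codes.reshape(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed.view(torch.int8)  # reinterpret the bytes as int8 storage

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    """Recover the flat ternary values from the packed int8 representation."""
    b = packed.view(torch.uint8)
    fields = torch.stack([(b >> (2 * i)) & 0x3 for i in range(4)], dim=-1)
    return ((fields == 1).to(torch.int8) - (fields == 2).to(torch.int8)).reshape(-1)

if __name__ == "__main__":
    w = torch.randint(-1, 2, (4096,), dtype=torch.int8)  # random ternary weights
    packed = pack_ternary(w)
    assert packed.numel() == w.numel() // 4              # 4x smaller than int8, 8x than float16
    assert torch.equal(unpack_ternary(packed), w)
```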
Using this methodology, memory savings can be achieved during decoding.
As the model scales, these savings become increasingly significant. This trend is illustrated in the table below, where the relative size of the model components that are not ternarized becomes negligible compared to the overall model size, especially as the hidden dimensions grow.
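As a rough illustration of this trend, the back-of-envelope estimate below compares a float16 model against one whose linear-layer weights are packed at 2 bits per value while embeddings, the output head, and norms stay in float16. The layer shapes, vocabulary size, and the split between ternarized and non-ternarized parameters are assumed for illustration and are not the exact figures from the table.

```python
def estimate_gb(hidden, layers, vocab=32000):
    """Back-of-envelope model size: full float16 vs. 2-bit-packed ternary linear layers.

    Assumed shapes: 4*hidden^2 attention + ~8*hidden^2 MLP weights per layer (illustrative).
    """
    ternary_params = layers * (4 * hidden ** 2 + 8 * hidden ** 2)   # ternarizable linear layers
    fp16_params = 2 * vocab * hidden + layers * 2 * hidden          # embeddings, head, norms
    fp16_gb = (ternary_params + fp16_params) * 2 / 1e9              # 2 bytes per weight
    packed_gb = (ternary_params / 4 + fp16_params * 2) / 1e9        # 4 weights per byte
    return fp16_gb, packed_gb

for hidden, layers in [(2048, 24), (4096, 32), (8192, 80)]:
    fp16_gb, packed_gb = estimate_gb(hidden, layers)
    print(f"hidden={hidden}: {fp16_gb:.1f} GB (fp16) -> {packed_gb:.1f} GB (packed), "
          f"{fp16_gb / packed_gb:.1f}x smaller")
```

Under these assumptions the compression ratio climbs toward the ideal 8x as the hidden size grows, because the float16 embedding and norm parameters grow only linearly with the hidden dimension while the ternarized weights grow quadratically.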
Although the method's intermediate unpacking steps might suggest a trade-off in performance and processing time, the empirical data in the figures below indicate that performance remains effectively consistent with PyTorch's standard matrix multiplication. Moreover, when the precision is increased to float32, our custom kernel performs comparatively better. These observations underscore the method's capacity to deliver substantial memory savings without substantially impacting computational performance, with the savings approaching the optimum as model size increases.
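A minimal timing harness of the kind behind such a comparison is sketched below. The unpack-then-matmul path is a stand-in reference implementation, not BitMat's Triton kernel; the matrix sizes, the 2-bit decoding, and the use of triton.testing.do_bench are assumptions made for illustration, and the script requires a CUDA device.

```python
import torch
from triton.testing import do_bench

def unpack_to(packed: torch.Tensor, shape, dtype) -> torch.Tensor:
    """Decode assumed 2-bit fields (0 -> 0, 1 -> +1, 2 -> -1) from packed int8."""
    b = packed.view(torch.uint8)
    fields = torch.stack([(b >> (2 * i)) & 0x3 for i in range(4)], dim=-1)
    return ((fields == 1).to(dtype) - (fields == 2).to(dtype)).reshape(shape)

def compare(m=4096, k=4096, n=4096, dtype=torch.float16):
    x = torch.randn(m, k, device="cuda", dtype=dtype)
    # random packed bytes are enough for timing purposes
    packed = torch.randint(0, 256, (k * n // 4,), device="cuda", dtype=torch.uint8).view(torch.int8)
    w_dense = unpack_to(packed, (k, n), dtype)  # baseline keeps a dense copy of the weights

    dense_ms = do_bench(lambda: torch.matmul(x, w_dense))
    unpack_ms = do_bench(lambda: torch.matmul(x, unpack_to(packed, (k, n), dtype)))
    print(f"{dtype}: torch.matmul {dense_ms:.3f} ms | unpack-then-matmul {unpack_ms:.3f} ms")

for dt in (torch.float16, torch.float32):
    compare(dtype=dt)
```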
From the results, it can be observed that the custom matmul for ternary matrices performs better at higher precision. This may be due to how the operation is optimized within the GPU. We also noted that Triton is not the best framework for this problem, as it can lead to thread-synchronization issues within the GPU. Therefore, this project should be considered a first step in preparation for a more structured implementation in CUDA.

At the moment, BitMat implements an optimized but classical matrix multiplication. Our future goal, as proposed in the original articles, is to introduce a new computational paradigm for BitNet b1.58. This will require the design of new hardware optimized for 1-bit LLMs, which need almost no matrix multiplication.
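To make concrete which part of the pipeline a Triton kernel handles, the sketch below implements only the decode step (packed int8 to ternary int8 values) as a standalone kernel; BitMat's actual kernel is more involved, since, as noted above, it still performs a classical matrix multiplication on the GPU. The kernel name, block size, and 2-bit encoding are illustrative assumptions.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def decode_ternary_kernel(packed_ptr, out_ptr, n_packed, BLOCK: tl.constexpr):
    # Decode four assumed 2-bit fields per int8 (0 -> 0, 1 -> +1, 2 -> -1).
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_packed
    packed = tl.load(packed_ptr + offs, mask=mask, other=0)
    for i in tl.static_range(4):
        field = (packed >> (2 * i)) & 0x3                      # extract the i-th 2-bit field
        val = (field == 1).to(tl.int8) - (field == 2).to(tl.int8)
        tl.store(out_ptr + offs * 4 + i, val, mask=mask)       # interleaved write-back

def decode_ternary(packed: torch.Tensor, BLOCK: int = 1024) -> torch.Tensor:
    out = torch.empty(packed.numel() * 4, dtype=torch.int8, device=packed.device)
    grid = (triton.cdiv(packed.numel(), BLOCK),)
    decode_ternary_kernel[grid](packed, out, packed.numel(), BLOCK=BLOCK)
    return out
```

In a fused kernel, the decoded tile would feed directly into a block-wise tl.dot instead of being written back to global memory.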
We extend our gratitude to the Triton community and the authors of the '1-bit LLM era' papers for their inspirational work. Special thanks also to the developers of BitDelta and UnSloth, whose contributions laid the groundwork for BitMat's development.