File size: 5,741 Bytes
af247a4 c3666a9 af247a4 c3666a9 cba0a96 c3666a9 20e9915 c3666a9 2dc97bb 3ebba7d 092838f c3666a9 f0003b0 81babae c3666a9 342bf46 c3666a9 b9643c4 c3666a9 4b20f21 c3666a9 b9643c4 c3666a9 4b20f21 c3666a9 4b20f21 c3666a9 57862da c3666a9 57862da b8e0c41 c3666a9 f0c7a8b c3666a9 f0c7a8b c3666a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
---
license: llama2
train: false
inference: false
pipeline_tag: text-generation
---
This is an experimental <a href="https://github.com/mobiusml/hqq/">HQQ</a> 1-bit quantized (<b>binary weights</b>) <a href="https://huggingface.co/meta-llama/Llama-2-7b-chat-hf"> Llama2-7B-chat model </a> using a low-rank adapter to improve the performance (referred to as <a href="https://mobiusml.github.io/1bit_blog/">HQQ+</a>).
![image/gif](https://cdn-uploads.huggingface.co/production/uploads/636b945ef575d3705149e982/3fOfrg-5WtJwC5cpcVDub.gif)
Quantizing small models at extreme low-bits is a challenging task. The purpose of this model is to show the community what to expect when fine-tuning such models.
We notice that, 1-bit quantization doesn't work well when applied directly on small models such as the Llama2-7B. However, when fine-tuned, the model's ouput significantly improves. In fact, the 1-bit base model outperforms Quip# 2-bit after fine-tuning on ~2.8K samples.
Note that the weights here are unsigned 1-bit (0 or 1), <a href="https://arxiv.org/abs/2402.17764">not ternary like the recent 1.58-bit work </a>. This is a more challenging task since we lose the sign of the weights and only fine-tune a small fraction of the parameters (~94MB worth of weights).
The dequantization step can be rewriten as a 1-bit matmul which could potentially require only additions + a very low-rank matmul which is fast to compute.
This version offloads the meta-data to the CPU, so only the binary weights and the low-rank adapters are stored in the GPU memory.
## Datasets
The adapter was trained via SFT on random subsets of the following:
### Base Model
* <a href="https://huggingface.co/datasets/wikitext">wikitext-2-raw-v1</a> (full)
### Chat Model
* <a href="https://huggingface.co/datasets/timdettmers/openassistant-guana"> timdettmers/openassistant-guanaco </a> (full)
* <a href="https://huggingface.co/datasets/icrosoft/orca-math-word-problems-200k"> microsoft/orca-math-word-problems-200k </a> (25K)
* <a href="https://huggingface.co/datasets/meta-math/MetaMathQA"> meta-math/MetaMathQA </a> (25K)
* <a href="https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized"> HuggingFaceH4/ultrafeedback_binarized </a> (25K - chosen answers only)
## Performance
| Models | Llama2-7B (fp16)| Llama2-7B (HQQ 1-bit)| Llama2-7B (HQQ+ 1-bit)| Quip# (2-bit)|
|-------------------|------------------|------------------|------------------|------------------|
| Wiki Perpexlity | 5.18 | 9866 | <b>8.53</b> | 8.54 |
| VRAM (GB) | 13.5 | <b>1.76</b> | 1.85 | 2.72 |
| forward time (sec)| <b>0.1<b> | 0.231 | 0.257 | 0.353 |
| Models | Llama2-7B-chat (fp16)| Llama2-7B-chat (HQQ 1-bit)| Llama2-7B-chat (HQQ+ 1-bit)|
|-------------------|------------------|------------------|------------------|
| ARC (25-shot) | 53.67 | 21.59 | 31.14 |
| HellaSwag (10-shot)| 78.56 | 25.66 | 52.96 |
| MMLU (5-shot) | 48.16 | 25.08 | 26.54 |
| TruthfulQA-MC2 | 45.32 | 47.81 | 43.16 |
| Winogrande (5-shot)| 72.53 | 49.72 | 60.54 |
| GSM8K (5-shot) | 23.12 | 0 | 11 |
| Average | 53.56 | 28.31 | 37.56 |
## Usage
First, install <a href="https://github.com/mobiusml/hqq/">HQQ</a>:
```
pip install hqq==0.1.8
pip install transformers==4.40
```
Then you can use the sample code below:
``` Python
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
#Load the model
model_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'
model = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora')
tokenizer = AutoTokenizer.from_pretrained(model_id)
#Setup Inference Mode
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if not tokenizer.pad_token: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.config.use_cache = True
model.eval();
# Optional: torch compile for faster inference
# model = torch.compile(model)
#Streaming Inference
import torch, transformers
from threading import Thread
def chat_processor(chat, max_new_tokens=100, do_sample=True, device='cuda'):
tokenizer.use_default_system_prompt = False
streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_params = dict(
tokenizer("<s> [INST] " + chat + " [/INST] ", return_tensors="pt").to(device),
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
pad_token_id=tokenizer.pad_token_id,
top_p=0.90 if do_sample else None,
top_k=50 if do_sample else None,
temperature= 0.6 if do_sample else None,
num_beams=1,
repetition_penalty=1.2,
)
t = Thread(target=model.generate, kwargs=generate_params)
t.start()
print("User: ", chat);
print("Assistant: ");
outputs = ""
for text in streamer:
outputs += text
print(text, end="", flush=True)
torch.cuda.empty_cache()
return outputs
```
### Example
``` Python
outputs = chat_processor("What is the solution to x^2 - 1 = 0", max_new_tokens=1000, do_sample=False)
```
```
User: What is the solution to x^2 - 1 = 0
Assistant:
The equation $x^2 - 1 = 0$ can be factored as $(x-1)(x+1) = 0$.
You want to find a value of $x$ that makes this true for all values of $x$. This means that either $x=1$ or $-1$, or $x=-1$. So, there are two solutions: $x=\boxed{1}$ and $x=\boxed{-1}$. The answer is: 1
```
|