sync from upstream
Browse files- configuration_chatglm.py +2 -0
- modeling_chatglm.py +42 -6
- quantization.py +370 -56
configuration_chatglm.py
CHANGED
@@ -73,6 +73,7 @@ class ChatGLMConfig(PretrainedConfig):
|
|
73 |
inner_hidden_size=16384,
|
74 |
position_encoding_2d=True,
|
75 |
quantization_bit=0,
|
|
|
76 |
pre_seq_len=None,
|
77 |
prefix_projection=False,
|
78 |
**kwargs
|
@@ -92,6 +93,7 @@ class ChatGLMConfig(PretrainedConfig):
|
|
92 |
self.gmask_token_id = gmask_token_id
|
93 |
self.position_encoding_2d = position_encoding_2d
|
94 |
self.quantization_bit = quantization_bit
|
|
|
95 |
self.pre_seq_len = pre_seq_len
|
96 |
self.prefix_projection = prefix_projection
|
97 |
|
|
|
73 |
inner_hidden_size=16384,
|
74 |
position_encoding_2d=True,
|
75 |
quantization_bit=0,
|
76 |
+
quantization_embeddings=False,
|
77 |
pre_seq_len=None,
|
78 |
prefix_projection=False,
|
79 |
**kwargs
|
|
|
93 |
self.gmask_token_id = gmask_token_id
|
94 |
self.position_encoding_2d = position_encoding_2d
|
95 |
self.quantization_bit = quantization_bit
|
96 |
+
self.quantization_embeddings = quantization_embeddings
|
97 |
self.pre_seq_len = pre_seq_len
|
98 |
self.prefix_projection = prefix_projection
|
99 |
|
modeling_chatglm.py
CHANGED
@@ -32,6 +32,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
|
|
32 |
|
33 |
from .configuration_chatglm import ChatGLMConfig
|
34 |
|
|
|
35 |
# flags required to enable jit fusion kernels
|
36 |
|
37 |
if sys.platform != 'darwin':
|
@@ -224,7 +225,6 @@ class RotaryEmbedding(torch.nn.Module):
|
|
224 |
self.sin_cached = fn(self.sin_cached)
|
225 |
return super()._apply(fn)
|
226 |
|
227 |
-
|
228 |
def rotate_half(x):
|
229 |
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
230 |
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
|
@@ -1059,7 +1059,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1059 |
self.quantized = False
|
1060 |
|
1061 |
if self.config.quantization_bit:
|
1062 |
-
self.quantize(self.config.quantization_bit, empty_init=True)
|
1063 |
|
1064 |
def get_output_embeddings(self):
|
1065 |
return self.lm_head
|
@@ -1418,19 +1418,55 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1418 |
break
|
1419 |
yield input_ids
|
1420 |
|
1421 |
-
def quantize(self, bits: int, empty_init=False, **kwargs):
|
1422 |
if bits == 0:
|
1423 |
return
|
1424 |
|
1425 |
-
from .quantization import quantize
|
1426 |
|
1427 |
if self.quantized:
|
1428 |
-
|
|
|
|
|
|
|
|
|
1429 |
return self
|
1430 |
|
1431 |
self.quantized = True
|
1432 |
|
1433 |
self.config.quantization_bit = bits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1434 |
|
1435 |
-
self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
|
1436 |
return self
|
|
|
32 |
|
33 |
from .configuration_chatglm import ChatGLMConfig
|
34 |
|
35 |
+
|
36 |
# flags required to enable jit fusion kernels
|
37 |
|
38 |
if sys.platform != 'darwin':
|
|
|
225 |
self.sin_cached = fn(self.sin_cached)
|
226 |
return super()._apply(fn)
|
227 |
|
|
|
228 |
def rotate_half(x):
|
229 |
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
230 |
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
|
|
|
1059 |
self.quantized = False
|
1060 |
|
1061 |
if self.config.quantization_bit:
|
1062 |
+
self.quantize(self.config.quantization_bit, self.config.quantization_embeddings, use_quantization_cache=True, empty_init=True)
|
1063 |
|
1064 |
def get_output_embeddings(self):
|
1065 |
return self.lm_head
|
|
|
1418 |
break
|
1419 |
yield input_ids
|
1420 |
|
1421 |
+
def quantize(self, bits: int, quantize_embeddings=False, use_quantization_cache=False, empty_init=False, **kwargs):
|
1422 |
if bits == 0:
|
1423 |
return
|
1424 |
|
1425 |
+
from .quantization import quantize, QuantizedEmbedding, QuantizedLinear, load_cpu_kernel
|
1426 |
|
1427 |
if self.quantized:
|
1428 |
+
if self.device == torch.device("cpu"):
|
1429 |
+
logger.info("Already quantized, reloading cpu kernel.")
|
1430 |
+
load_cpu_kernel(**kwargs)
|
1431 |
+
else:
|
1432 |
+
logger.info("Already quantized.")
|
1433 |
return self
|
1434 |
|
1435 |
self.quantized = True
|
1436 |
|
1437 |
self.config.quantization_bit = bits
|
1438 |
+
self.config.quantization_embeddings = quantize_embeddings
|
1439 |
+
|
1440 |
+
self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
|
1441 |
+
|
1442 |
+
if self.device == torch.device("cpu"):
|
1443 |
+
dtype = torch.float32
|
1444 |
+
else:
|
1445 |
+
dtype = torch.half
|
1446 |
+
|
1447 |
+
if quantize_embeddings:
|
1448 |
+
logger.info("Applying quantization to embeddings")
|
1449 |
+
self.transformer.word_embeddings = QuantizedEmbedding(
|
1450 |
+
weight_bit_width=bits,
|
1451 |
+
weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
|
1452 |
+
num_embeddings=self.transformer.word_embeddings.num_embeddings,
|
1453 |
+
embedding_dim=self.transformer.word_embeddings.embedding_dim,
|
1454 |
+
dtype=dtype,
|
1455 |
+
empty_init=empty_init,
|
1456 |
+
device=self.transformer.word_embeddings.weight.device,
|
1457 |
+
)
|
1458 |
+
self.lm_head = QuantizedLinear(
|
1459 |
+
weight_bit_width=bits,
|
1460 |
+
weight_tensor=self.lm_head.weight.to(self.device),
|
1461 |
+
bias_tensor=None,
|
1462 |
+
in_features=self.lm_head.in_features,
|
1463 |
+
out_features=self.lm_head.out_features,
|
1464 |
+
bias=False,
|
1465 |
+
quantized_weight=self.transformer.word_embeddings.weight,
|
1466 |
+
quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
|
1467 |
+
dtype=dtype,
|
1468 |
+
empty_init=empty_init,
|
1469 |
+
device=self.lm_head.weight.device,
|
1470 |
+
)
|
1471 |
|
|
|
1472 |
return self
|
quantization.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
-
from torch.nn import Linear
|
2 |
from torch.nn.parameter import Parameter
|
|
|
3 |
|
|
|
4 |
import bz2
|
5 |
import torch
|
6 |
import base64
|
@@ -38,7 +40,7 @@ try:
|
|
38 |
)
|
39 |
except Exception as exception:
|
40 |
kernels = None
|
41 |
-
logger.warning("Failed to load cpm_kernels:"
|
42 |
|
43 |
|
44 |
class W8A16Linear(torch.autograd.Function):
|
@@ -64,25 +66,193 @@ class W8A16Linear(torch.autograd.Function):
|
|
64 |
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
65 |
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
68 |
-
|
|
|
|
|
69 |
n, m = weight.size(0), weight.size(1)
|
70 |
assert m % 2 == 0
|
71 |
m = m // 2
|
72 |
-
out = torch.empty(n, m, dtype=torch.int8, device="
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
kernels.int4WeightCompression(
|
79 |
-
gridDim,
|
80 |
-
blockDim,
|
81 |
-
0,
|
82 |
-
stream,
|
83 |
-
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
84 |
)
|
85 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
|
88 |
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
@@ -117,85 +287,229 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
|
|
117 |
return out
|
118 |
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
class QuantizedLinear(Linear):
|
121 |
-
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
|
122 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
123 |
self.weight_bit_width = weight_bit_width
|
|
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
self.weight = torch.empty(
|
130 |
-
shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
131 |
-
)
|
132 |
-
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
133 |
else:
|
134 |
-
|
135 |
-
self.weight
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
-
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
140 |
-
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
141 |
if bias_tensor is not None:
|
142 |
self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False)
|
143 |
else:
|
144 |
self.bias = None
|
145 |
|
|
|
|
|
|
|
|
|
146 |
def forward(self, input):
|
147 |
-
|
|
|
|
|
|
|
148 |
if self.bias is not None:
|
149 |
output = output + self.bias
|
150 |
return output
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
-
def quantize(model, weight_bit_width, empty_init=False, **kwargs):
|
154 |
"""Replace fp16 linear with quantized linear"""
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
for layer in model.layers:
|
157 |
-
layer.attention.query_key_value =
|
158 |
-
|
159 |
-
weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
|
160 |
bias_tensor=layer.attention.query_key_value.bias,
|
161 |
in_features=layer.attention.query_key_value.in_features,
|
162 |
out_features=layer.attention.query_key_value.out_features,
|
163 |
-
bias=True,
|
164 |
-
dtype=torch.half,
|
165 |
device=layer.attention.query_key_value.weight.device,
|
166 |
-
|
167 |
)
|
168 |
-
layer.attention.dense =
|
169 |
-
|
170 |
-
weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()),
|
171 |
bias_tensor=layer.attention.dense.bias,
|
172 |
in_features=layer.attention.dense.in_features,
|
173 |
out_features=layer.attention.dense.out_features,
|
174 |
-
bias=True,
|
175 |
-
dtype=torch.half,
|
176 |
device=layer.attention.dense.weight.device,
|
177 |
-
|
178 |
)
|
179 |
-
layer.mlp.dense_h_to_4h =
|
180 |
-
|
181 |
-
weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
182 |
bias_tensor=layer.mlp.dense_h_to_4h.bias,
|
183 |
in_features=layer.mlp.dense_h_to_4h.in_features,
|
184 |
out_features=layer.mlp.dense_h_to_4h.out_features,
|
185 |
-
bias=True,
|
186 |
-
dtype=torch.half,
|
187 |
device=layer.mlp.dense_h_to_4h.weight.device,
|
188 |
-
|
189 |
)
|
190 |
-
layer.mlp.dense_4h_to_h =
|
191 |
-
|
192 |
-
weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
193 |
bias_tensor=layer.mlp.dense_4h_to_h.bias,
|
194 |
in_features=layer.mlp.dense_4h_to_h.in_features,
|
195 |
out_features=layer.mlp.dense_4h_to_h.out_features,
|
196 |
-
bias=True,
|
197 |
-
dtype=torch.half,
|
198 |
device=layer.mlp.dense_4h_to_h.weight.device,
|
199 |
-
|
200 |
)
|
201 |
return model
|
|
|
1 |
+
from torch.nn import Linear, Embedding
|
2 |
from torch.nn.parameter import Parameter
|
3 |
+
import torch.nn.functional as F
|
4 |
|
5 |
+
import os
|
6 |
import bz2
|
7 |
import torch
|
8 |
import base64
|
|
|
40 |
)
|
41 |
except Exception as exception:
|
42 |
kernels = None
|
43 |
+
logger.warning("Failed to load cpm_kernels:", exception)
|
44 |
|
45 |
|
46 |
class W8A16Linear(torch.autograd.Function):
|
|
|
66 |
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
67 |
|
68 |
|
69 |
+
class W8A16LinearCPU(torch.autograd.Function):
|
70 |
+
@staticmethod
|
71 |
+
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
|
72 |
+
ctx.inp_shape = inp.size()
|
73 |
+
ctx.weight_bit_width = weight_bit_width
|
74 |
+
out_features = quant_w.size(0)
|
75 |
+
inp = inp.contiguous().view(-1, inp.size(-1))
|
76 |
+
weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
|
77 |
+
ctx.weight_shape = weight.size()
|
78 |
+
output = inp.mm(weight.t())
|
79 |
+
ctx.save_for_backward(inp, quant_w, scale_w)
|
80 |
+
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def backward(ctx, grad_output: torch.Tensor):
|
84 |
+
inp, quant_w, scale_w = ctx.saved_tensors
|
85 |
+
weight = extract_weight_to_float(quant_w, scale_w, ctx.weight_bit_width)
|
86 |
+
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
87 |
+
grad_input = grad_output.mm(weight)
|
88 |
+
grad_weight = grad_output.t().mm(inp)
|
89 |
+
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
90 |
+
|
91 |
+
|
92 |
+
default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
|
93 |
+
default_cpu_kernel_code = "QlpoOTFBWSZTWXLbSoQAAgzbgERwQXxmTwAAr/ff3kABt0Q2oRVT0hpo9RtEAAAAyBEiSQ9EGjQGQAAAwANGhowjJoNGmgMEUplMTNSMJ5TQaDJpsoMyRMj8P4mZzFSVVwqSXG8GG7MlVwiToYEQwVD7noBxMhNfkeZYtYFtbgOBUSIGtIQjhNHCEnPJsadhb3yBmRIOD3TeAtNLSaU5GgvKUBWSNuuOIHmVt0YhW6rsmDMDUjeUJGJ64R1Jm5lrh0Aa0tKjhFwPdWcGogxLDSXPWQUWTM8Sd3Qz1HMYNxx3HMeiNqNo4jeRDEfZ3gUSHIcU/heomq0vEzL1Msz5KKGxH8FrNOYw3KaxdqaEmNHYMxJFgQbR0DyRknL2L4kwUSxKRdhjRpEtUqilVfggFL1klaMS3PPRDfNqbBOPWO7m4JTVGhS9QTBDDJaEbLbrUQNB+IpJSKQbG5SZZ5gkwJEhJ3aYKJipZ/i7kinChIOW2lQg"
|
94 |
+
default_cpu_parallel_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels_parallel.c")
|
95 |
+
default_cpu_parallel_kernel_code = "QlpoOTFBWSZTWUzax5EAALXbgERwSX1mTwAAr/ff3kACNyXUbZYwBpoaNGIyAaADQwRSaVP9QoMg0A2oAPU0AEUkU9GaaKMaQB6gA09T1ARRKnpk0niaJkaaNDJ6g0DTIKVKfZ/g6v1Kem5LJLa0WmkukkuCIHUqWbtJGJMsCSQFiPEIYHgBIZDzR8R6REbYxIqD2Cu7lMkFoPu6LmHeOAy0GF83Tc40jgmTs4HnCe60QfJa2bDBZ0Y1lhgbiZjW8SNsAKCk42UOEdjWN3KoiCIYeQUCCKWIyHewhtSoInLKSG22l4jKM2ZDCVKtBm3OTYBl3jsVqMImtj7PQw7xKxLXQzwgJaPPgW1fRhrvPJICl4YFDYfNbkbBh5JDgrazFml50xEQQwQUjxNwE0IDSofLzSg7UNVKn+Rr1KErzBHUxBqdHRlXzqYsIa5K9Y0UuE2ugw3g5KYofm7AaGNTzJSMhcchhxdaU4JZ0F1UNgQ8XcGDguypqYza8yFaEoGgNRcLej+g2t0feGKFE5OY2PFluQ3q4HgycxlfvzHqo0KcM0JI8OKXtzayJFgsqC1NdUQVu8rChnA6FO3MFyGOoC9KO8ITPpYM5pRqTlczFkLES/4u5IpwoSCZtY8i"
|
96 |
+
|
97 |
+
cpu_kernels = None
|
98 |
+
|
99 |
+
|
100 |
+
class CPUKernel:
|
101 |
+
def __init__(self, kernel_file="", source_code=default_cpu_kernel_code_path, compile_parallel_kernel=None, parallel_num=None):
|
102 |
+
self.load =False
|
103 |
+
self.int8WeightExtractionFloat = None
|
104 |
+
self.int4WeightExtractionFloat = None
|
105 |
+
self.int4WeightCompression = None
|
106 |
+
self.SetNumThreads = lambda x: x
|
107 |
+
|
108 |
+
try:
|
109 |
+
if not os.path.exists(default_cpu_kernel_code_path):
|
110 |
+
with open(default_cpu_kernel_code_path, "w", encoding="utf-8") as file:
|
111 |
+
code = default_cpu_kernel_code
|
112 |
+
cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
|
113 |
+
file.write(cpu_quantization_code)
|
114 |
+
|
115 |
+
if not os.path.exists(default_cpu_parallel_kernel_code_path):
|
116 |
+
with open(default_cpu_parallel_kernel_code_path, "w", encoding="utf-8") as file:
|
117 |
+
code = default_cpu_parallel_kernel_code
|
118 |
+
cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
|
119 |
+
file.write(cpu_quantization_code)
|
120 |
+
|
121 |
+
except Exception as ex:
|
122 |
+
print("Error when generating default cpu kernel code(can be ignored when using custom kernels).")
|
123 |
+
|
124 |
+
if compile_parallel_kernel is None:
|
125 |
+
compile_parallel_kernel = bool(int(os.cpu_count()) >= 4)
|
126 |
+
|
127 |
+
if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
|
128 |
+
source_code = default_cpu_parallel_kernel_code_path
|
129 |
+
|
130 |
+
kernels = None
|
131 |
+
|
132 |
+
if (not kernel_file) or (not os.path.exists(kernel_file)):
|
133 |
+
print("No compiled kernel found.")
|
134 |
+
try:
|
135 |
+
if os.path.exists(source_code):
|
136 |
+
print("Compiling kernels :", source_code)
|
137 |
+
kernel_file = source_code[:-2] + ".so"
|
138 |
+
|
139 |
+
if compile_parallel_kernel:
|
140 |
+
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file)
|
141 |
+
print("Compiling", compile_command)
|
142 |
+
exit_state = os.system(compile_command)
|
143 |
+
if not exit_state:
|
144 |
+
try:
|
145 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
146 |
+
print("Load kernel :", kernel_file)
|
147 |
+
except:
|
148 |
+
kernels = None
|
149 |
+
print("Load parallel cpu kernel failed, using default cpu kernel code:")
|
150 |
+
import traceback
|
151 |
+
exception = traceback.format_exc()
|
152 |
+
print(exception)
|
153 |
+
else:
|
154 |
+
print("Compile default cpu kernel failed, using default cpu kernel code.")
|
155 |
+
|
156 |
+
if kernels is None: # adjust config, use default cpu kernel
|
157 |
+
compile_parallel_kernel = False
|
158 |
+
source_code = default_cpu_kernel_code_path
|
159 |
+
kernel_file = source_code[:-2] + ".so"
|
160 |
+
|
161 |
+
if kernels is None:
|
162 |
+
compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
|
163 |
+
print("Compiling", compile_command)
|
164 |
+
exit_state = os.system(compile_command)
|
165 |
+
if not exit_state:
|
166 |
+
try:
|
167 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
168 |
+
print("Load kernel :", kernel_file)
|
169 |
+
except:
|
170 |
+
kernels = None
|
171 |
+
print("Load default cpu kernel failed:")
|
172 |
+
import traceback
|
173 |
+
exception = traceback.format_exc()
|
174 |
+
print(exception)
|
175 |
+
else:
|
176 |
+
print("Compile default cpu kernel failed.")
|
177 |
+
else:
|
178 |
+
print("Kernel source code not found.")
|
179 |
+
return
|
180 |
+
except:
|
181 |
+
print("Failed to build cpu kernel:")
|
182 |
+
import traceback
|
183 |
+
exception = traceback.format_exc()
|
184 |
+
print(exception)
|
185 |
+
return
|
186 |
+
else:
|
187 |
+
try:
|
188 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
189 |
+
print("Load kernel :", kernel_file)
|
190 |
+
except:
|
191 |
+
kernels = None
|
192 |
+
print("Load custom cpu kernel failed:")
|
193 |
+
import traceback
|
194 |
+
exception = traceback.format_exc()
|
195 |
+
print(exception)
|
196 |
+
|
197 |
+
if kernels is not None:
|
198 |
+
self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
|
199 |
+
self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
|
200 |
+
self.int4WeightCompression = kernels.compress_int4_weight
|
201 |
+
if compile_parallel_kernel:
|
202 |
+
try:
|
203 |
+
self.SetNumThreads = kernels.set_num_threads
|
204 |
+
except:
|
205 |
+
print("No set_num_threads() found in kernel.")
|
206 |
+
self.load = True
|
207 |
+
else:
|
208 |
+
print("Failed to load kernel.")
|
209 |
+
return
|
210 |
+
|
211 |
+
if compile_parallel_kernel:
|
212 |
+
if parallel_num is None:
|
213 |
+
parallel_num = max(os.cpu_count() // 2, 1)
|
214 |
+
print("Setting CPU quantization kernel threads to", parallel_num)
|
215 |
+
if parallel_num < 4:
|
216 |
+
print("Parallel kernel is not recommended when parallel num < 4.")
|
217 |
+
self.SetNumThreads(parallel_num)
|
218 |
+
|
219 |
+
self.parallel_num = parallel_num
|
220 |
+
|
221 |
+
|
222 |
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
223 |
+
"""compress weight on cpu or cuda to int4"""
|
224 |
+
if weight.device == torch.device("cpu"):
|
225 |
+
assert isinstance(cpu_kernels, CPUKernel)
|
226 |
n, m = weight.size(0), weight.size(1)
|
227 |
assert m % 2 == 0
|
228 |
m = m // 2
|
229 |
+
out = torch.empty(n, m, dtype=torch.int8, device="cpu")
|
230 |
+
cpu_kernels.int4WeightCompression(
|
231 |
+
ctypes.c_void_p(weight.data_ptr()),
|
232 |
+
ctypes.c_void_p(out.data_ptr()),
|
233 |
+
ctypes.c_int32(n),
|
234 |
+
ctypes.c_int32(m)
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
)
|
236 |
return out
|
237 |
+
else:
|
238 |
+
with torch.cuda.device(weight.device):
|
239 |
+
n, m = weight.size(0), weight.size(1)
|
240 |
+
assert m % 2 == 0
|
241 |
+
m = m // 2
|
242 |
+
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
|
243 |
+
stream = torch.cuda.current_stream()
|
244 |
+
|
245 |
+
gridDim = (n, 1, 1)
|
246 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
247 |
+
|
248 |
+
kernels.int4WeightCompression(
|
249 |
+
gridDim,
|
250 |
+
blockDim,
|
251 |
+
0,
|
252 |
+
stream,
|
253 |
+
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
254 |
+
)
|
255 |
+
return out
|
256 |
|
257 |
|
258 |
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
|
|
287 |
return out
|
288 |
|
289 |
|
290 |
+
def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int, quantization_cache=None):
|
291 |
+
"""extract weight on cpu to float32"""
|
292 |
+
if source_bit_width == 8:
|
293 |
+
func = cpu_kernels.int8WeightExtractionFloat
|
294 |
+
elif source_bit_width == 4:
|
295 |
+
func = cpu_kernels.int4WeightExtractionFloat
|
296 |
+
else:
|
297 |
+
assert False, "Unsupported bit-width"
|
298 |
+
|
299 |
+
n, m = weight.size(0), weight.size(1)
|
300 |
+
|
301 |
+
if quantization_cache is not None:
|
302 |
+
out = quantization_cache
|
303 |
+
func(
|
304 |
+
ctypes.c_void_p(weight.data_ptr()),
|
305 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
306 |
+
ctypes.c_void_p(out.data_ptr()),
|
307 |
+
ctypes.c_int32(n),
|
308 |
+
ctypes.c_int32(m)
|
309 |
+
)
|
310 |
+
return out.tensor
|
311 |
+
else:
|
312 |
+
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.float, device="cpu")
|
313 |
+
func(
|
314 |
+
ctypes.c_void_p(weight.data_ptr()),
|
315 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
316 |
+
ctypes.c_void_p(out.data_ptr()),
|
317 |
+
ctypes.c_int32(n),
|
318 |
+
ctypes.c_int32(m)
|
319 |
+
)
|
320 |
+
return out
|
321 |
+
|
322 |
+
|
323 |
+
class CacheTensor():
|
324 |
+
def __init__(self, *args, **kwargs):
|
325 |
+
self.tensor = torch.empty(*args, **kwargs)
|
326 |
+
|
327 |
+
def to(self, *args, **kwargs):
|
328 |
+
self.tensor = self.tensor.to(*args, **kwargs)
|
329 |
+
|
330 |
+
def data_ptr(self):
|
331 |
+
return self.tensor.data_ptr()
|
332 |
+
|
333 |
+
|
334 |
class QuantizedLinear(Linear):
|
335 |
+
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, quantized_weight=None, quantized_weight_scale=None, quantization_cache=None, empty_init=False, *args, **kwargs):
|
336 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
337 |
self.weight_bit_width = weight_bit_width
|
338 |
+
self.quantization_cache = quantization_cache
|
339 |
|
340 |
+
if (quantized_weight is not None) and (quantized_weight_scale is not None):
|
341 |
+
del self.weight
|
342 |
+
self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False)
|
343 |
+
self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False)
|
|
|
|
|
|
|
|
|
344 |
else:
|
345 |
+
shape = self.weight.shape
|
346 |
+
del self.weight
|
347 |
+
|
348 |
+
if weight_tensor is None or empty_init:
|
349 |
+
self.weight = torch.empty(
|
350 |
+
shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
351 |
+
)
|
352 |
+
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
353 |
+
else:
|
354 |
+
self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).to(kwargs["dtype"])
|
355 |
+
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
356 |
+
if weight_bit_width == 4:
|
357 |
+
self.weight = compress_int4_weight(self.weight)
|
358 |
+
|
359 |
+
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
360 |
+
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
361 |
|
|
|
|
|
362 |
if bias_tensor is not None:
|
363 |
self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False)
|
364 |
else:
|
365 |
self.bias = None
|
366 |
|
367 |
+
def reset_parameters(self):
|
368 |
+
"""To accelerate initialization"""
|
369 |
+
pass
|
370 |
+
|
371 |
def forward(self, input):
|
372 |
+
if self.weight.device == torch.device("cpu"):
|
373 |
+
output = W8A16LinearCPU.apply(input, self.weight, self.weight_scale, self.weight_bit_width, self.quantization_cache)
|
374 |
+
else:
|
375 |
+
output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
376 |
if self.bias is not None:
|
377 |
output = output + self.bias
|
378 |
return output
|
379 |
|
380 |
+
def _apply(self, fn):
|
381 |
+
self_obj = super()._apply(fn)
|
382 |
+
if self.quantization_cache is not None:
|
383 |
+
self.quantization_cache.to(self_obj.weight.device)
|
384 |
+
self.quantization_cache.to(self_obj.weight_scale.dtype)
|
385 |
+
return self_obj
|
386 |
+
|
387 |
+
|
388 |
+
class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init
|
389 |
+
def __init__(self, weight_bit_width: int, weight_tensor=None, quantized_weight=None, quantized_weight_scale=None, empty_init=False, *args, **kwargs):
|
390 |
+
super(QuantizedEmbedding, self).__init__(*args, **kwargs)
|
391 |
+
self.weight_bit_width = weight_bit_width
|
392 |
+
|
393 |
+
if (quantized_weight is not None) and (quantized_weight_scale is not None):
|
394 |
+
del self.weight
|
395 |
+
self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False)
|
396 |
+
self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False)
|
397 |
+
else:
|
398 |
+
shape = self.weight.shape
|
399 |
+
del self.weight
|
400 |
+
|
401 |
+
if weight_tensor is None or empty_init:
|
402 |
+
self.weight = torch.empty(
|
403 |
+
shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
|
404 |
+
)
|
405 |
+
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
406 |
+
else:
|
407 |
+
self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
408 |
+
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
409 |
+
if weight_bit_width == 4:
|
410 |
+
self.weight = compress_int4_weight(self.weight)
|
411 |
+
|
412 |
+
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
413 |
+
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
|
414 |
+
|
415 |
+
def forward(self, input):
|
416 |
+
if self.weight.device == torch.device("cpu"):
|
417 |
+
original_weight = extract_weight_to_float(weight=self.weight, scale_list=self.weight_scale, source_bit_width=self.weight_bit_width)
|
418 |
+
else:
|
419 |
+
original_weight = extract_weight_to_half(weight=self.weight, scale_list=self.weight_scale, source_bit_width=self.weight_bit_width)
|
420 |
+
output = F.embedding(
|
421 |
+
input, original_weight, self.padding_idx, self.max_norm,
|
422 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse
|
423 |
+
)
|
424 |
+
return output
|
425 |
+
|
426 |
+
|
427 |
+
def load_cpu_kernel(**kwargs):
|
428 |
+
global cpu_kernels
|
429 |
+
cpu_kernels = CPUKernel(**kwargs)
|
430 |
+
assert cpu_kernels.load
|
431 |
+
|
432 |
|
433 |
+
def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
|
434 |
"""Replace fp16 linear with quantized linear"""
|
435 |
|
436 |
+
query_key_value_quantization_cache = None
|
437 |
+
dense_quantization_cache = None
|
438 |
+
dense_h_to_4h_quantization_cache = None
|
439 |
+
dense_4h_to_h_quantization_cache = None
|
440 |
+
|
441 |
+
try:
|
442 |
+
load_cpu_kernel(**kwargs)
|
443 |
+
except:
|
444 |
+
if kernels is None: # CUDA kernels failed
|
445 |
+
print("Cannot load cpu or cuda kernel, quantization failed:")
|
446 |
+
assert kernels is not None
|
447 |
+
print("Cannot load cpu kernel, don't use quantized model on cpu.")
|
448 |
+
|
449 |
+
current_device = model.device
|
450 |
+
|
451 |
+
if model.device == torch.device("cpu"):
|
452 |
+
dtype=torch.float32
|
453 |
+
else:
|
454 |
+
dtype = torch.half
|
455 |
+
|
456 |
+
QuantizedLinearWithPara = partial(
|
457 |
+
QuantizedLinear,
|
458 |
+
weight_bit_width=weight_bit_width,
|
459 |
+
bias=True,
|
460 |
+
dtype=dtype,
|
461 |
+
empty_init=empty_init
|
462 |
+
)
|
463 |
+
|
464 |
+
if use_quantization_cache:
|
465 |
+
print("Using quantization cache")
|
466 |
+
layer = model.layers[0]
|
467 |
+
weight = layer.attention.query_key_value.weight
|
468 |
+
n, m = weight.size(0), weight.size(1)
|
469 |
+
query_key_value_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
|
470 |
+
weight = layer.attention.dense.weight
|
471 |
+
n, m = weight.size(0), weight.size(1)
|
472 |
+
dense_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
|
473 |
+
weight = layer.mlp.dense_h_to_4h.weight
|
474 |
+
n, m = weight.size(0), weight.size(1)
|
475 |
+
dense_h_to_4h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
|
476 |
+
weight = layer.mlp.dense_4h_to_h.weight
|
477 |
+
n, m = weight.size(0), weight.size(1)
|
478 |
+
dense_4h_to_h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
|
479 |
+
|
480 |
+
print("Applying quantization to glm layers")
|
481 |
+
|
482 |
for layer in model.layers:
|
483 |
+
layer.attention.query_key_value = QuantizedLinearWithPara(
|
484 |
+
weight_tensor=layer.attention.query_key_value.weight.to(current_device),
|
|
|
485 |
bias_tensor=layer.attention.query_key_value.bias,
|
486 |
in_features=layer.attention.query_key_value.in_features,
|
487 |
out_features=layer.attention.query_key_value.out_features,
|
|
|
|
|
488 |
device=layer.attention.query_key_value.weight.device,
|
489 |
+
quantization_cache=query_key_value_quantization_cache
|
490 |
)
|
491 |
+
layer.attention.dense = QuantizedLinearWithPara(
|
492 |
+
weight_tensor=layer.attention.dense.weight.to(current_device),
|
|
|
493 |
bias_tensor=layer.attention.dense.bias,
|
494 |
in_features=layer.attention.dense.in_features,
|
495 |
out_features=layer.attention.dense.out_features,
|
|
|
|
|
496 |
device=layer.attention.dense.weight.device,
|
497 |
+
quantization_cache=dense_quantization_cache
|
498 |
)
|
499 |
+
layer.mlp.dense_h_to_4h = QuantizedLinearWithPara(
|
500 |
+
weight_tensor=layer.mlp.dense_h_to_4h.weight.to(current_device),
|
|
|
501 |
bias_tensor=layer.mlp.dense_h_to_4h.bias,
|
502 |
in_features=layer.mlp.dense_h_to_4h.in_features,
|
503 |
out_features=layer.mlp.dense_h_to_4h.out_features,
|
|
|
|
|
504 |
device=layer.mlp.dense_h_to_4h.weight.device,
|
505 |
+
quantization_cache=dense_h_to_4h_quantization_cache
|
506 |
)
|
507 |
+
layer.mlp.dense_4h_to_h = QuantizedLinearWithPara(
|
508 |
+
weight_tensor=layer.mlp.dense_4h_to_h.weight.to(current_device),
|
|
|
509 |
bias_tensor=layer.mlp.dense_4h_to_h.bias,
|
510 |
in_features=layer.mlp.dense_4h_to_h.in_features,
|
511 |
out_features=layer.mlp.dense_4h_to_h.out_features,
|
|
|
|
|
512 |
device=layer.mlp.dense_4h_to_h.weight.device,
|
513 |
+
quantization_cache=dense_4h_to_h_quantization_cache
|
514 |
)
|
515 |
return model
|