Committed by KumaTea
Commit: 1136001
Parent: 693e7b1

sync from upstream

Files changed (3):
  1. configuration_chatglm.py  +2 -0
  2. modeling_chatglm.py       +42 -6
  3. quantization.py           +370 -56
configuration_chatglm.py CHANGED

@@ -73,6 +73,7 @@ class ChatGLMConfig(PretrainedConfig):
             inner_hidden_size=16384,
             position_encoding_2d=True,
             quantization_bit=0,
+            quantization_embeddings=False,
             pre_seq_len=None,
             prefix_projection=False,
             **kwargs
@@ -92,6 +93,7 @@ class ChatGLMConfig(PretrainedConfig):
         self.gmask_token_id = gmask_token_id
         self.position_encoding_2d = position_encoding_2d
         self.quantization_bit = quantization_bit
+        self.quantization_embeddings = quantization_embeddings
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection
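The new `quantization_embeddings` flag sits alongside `quantization_bit`, so a checkpoint can record whether its word embeddings (and `lm_head`) were quantized in addition to the linear layers. Below is a minimal sketch of setting both options; it is not part of the commit and assumes this repo's configuration_chatglm.py is importable from the working directory:

    from configuration_chatglm import ChatGLMConfig

    config = ChatGLMConfig(
        quantization_bit=4,            # existing option: 4-bit quantized linear layers
        quantization_embeddings=True,  # new option introduced by this commit
    )
    print(config.quantization_bit, config.quantization_embeddings)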
 
modeling_chatglm.py CHANGED

@@ -32,6 +32,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList

 from .configuration_chatglm import ChatGLMConfig

+
 # flags required to enable jit fusion kernels

 if sys.platform != 'darwin':
@@ -224,7 +225,6 @@ class RotaryEmbedding(torch.nn.Module):
         self.sin_cached = fn(self.sin_cached)
         return super()._apply(fn)

-
 def rotate_half(x):
     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions
@@ -1059,7 +1059,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.quantized = False

         if self.config.quantization_bit:
-            self.quantize(self.config.quantization_bit, empty_init=True)
+            self.quantize(self.config.quantization_bit, self.config.quantization_embeddings, use_quantization_cache=True, empty_init=True)

     def get_output_embeddings(self):
         return self.lm_head
@@ -1418,19 +1418,55 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 break
             yield input_ids

-    def quantize(self, bits: int, empty_init=False, **kwargs):
+    def quantize(self, bits: int, quantize_embeddings=False, use_quantization_cache=False, empty_init=False, **kwargs):
         if bits == 0:
             return

-        from .quantization import quantize
+        from .quantization import quantize, QuantizedEmbedding, QuantizedLinear, load_cpu_kernel

         if self.quantized:
-            logger.info("Already quantized.")
+            if self.device == torch.device("cpu"):
+                logger.info("Already quantized, reloading cpu kernel.")
+                load_cpu_kernel(**kwargs)
+            else:
+                logger.info("Already quantized.")
             return self

         self.quantized = True

         self.config.quantization_bit = bits
+        self.config.quantization_embeddings = quantize_embeddings
+
+        self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
+
+        if self.device == torch.device("cpu"):
+            dtype = torch.float32
+        else:
+            dtype = torch.half
+
+        if quantize_embeddings:
+            logger.info("Applying quantization to embeddings")
+            self.transformer.word_embeddings = QuantizedEmbedding(
+                weight_bit_width=bits,
+                weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
+                num_embeddings=self.transformer.word_embeddings.num_embeddings,
+                embedding_dim=self.transformer.word_embeddings.embedding_dim,
+                dtype=dtype,
+                empty_init=empty_init,
+                device=self.transformer.word_embeddings.weight.device,
+            )
+            self.lm_head = QuantizedLinear(
+                weight_bit_width=bits,
+                weight_tensor=self.lm_head.weight.to(self.device),
+                bias_tensor=None,
+                in_features=self.lm_head.in_features,
+                out_features=self.lm_head.out_features,
+                bias=False,
+                quantized_weight=self.transformer.word_embeddings.weight,
+                quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
+                dtype=dtype,
+                empty_init=empty_init,
+                device=self.lm_head.weight.device,
+            )

-        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
         return self
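With the extended signature above, quantize() can also quantize the embeddings and preallocate a CPU dequantization cache. A usage sketch, not part of the commit; the repository path is a placeholder and loading follows the usual transformers trust_remote_code route:

    from transformers import AutoModel

    # "path/to/this/repo" is a placeholder for this model repository.
    model = AutoModel.from_pretrained("path/to/this/repo", trust_remote_code=True).float()
    model = model.quantize(
        bits=4,                        # weight bit width, as before
        quantize_embeddings=True,      # new: also quantize word_embeddings and lm_head
        use_quantization_cache=True,   # new: reuse one dequantization buffer per layer type on CPU
    )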
quantization.py CHANGED

@@ -1,6 +1,8 @@
-from torch.nn import Linear
+from torch.nn import Linear, Embedding
 from torch.nn.parameter import Parameter
+import torch.nn.functional as F

+import os
 import bz2
 import torch
 import base64
@@ -38,7 +40,7 @@ try:
     )
 except Exception as exception:
     kernels = None
-    logger.warning("Failed to load cpm_kernels:" + str(exception))
+    logger.warning("Failed to load cpm_kernels:", exception)


 class W8A16Linear(torch.autograd.Function):
@@ -64,25 +66,193 @@ class W8A16Linear(torch.autograd.Function):
         return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None


+class W8A16LinearCPU(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
+        ctx.inp_shape = inp.size()
+        ctx.weight_bit_width = weight_bit_width
+        out_features = quant_w.size(0)
+        inp = inp.contiguous().view(-1, inp.size(-1))
+        weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
+        ctx.weight_shape = weight.size()
+        output = inp.mm(weight.t())
+        ctx.save_for_backward(inp, quant_w, scale_w)
+        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        inp, quant_w, scale_w = ctx.saved_tensors
+        weight = extract_weight_to_float(quant_w, scale_w, ctx.weight_bit_width)
+        grad_output = grad_output.contiguous().view(-1, weight.size(0))
+        grad_input = grad_output.mm(weight)
+        grad_weight = grad_output.t().mm(inp)
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
+
+
+default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
+default_cpu_kernel_code = "QlpoOTFBWSZTWXLbSoQAAgzbgERwQXxmTwAAr/ff3kABt0Q2oRVT0hpo9RtEAAAAyBEiSQ9EGjQGQAAAwANGhowjJoNGmgMEUplMTNSMJ5TQaDJpsoMyRMj8P4mZzFSVVwqSXG8GG7MlVwiToYEQwVD7noBxMhNfkeZYtYFtbgOBUSIGtIQjhNHCEnPJsadhb3yBmRIOD3TeAtNLSaU5GgvKUBWSNuuOIHmVt0YhW6rsmDMDUjeUJGJ64R1Jm5lrh0Aa0tKjhFwPdWcGogxLDSXPWQUWTM8Sd3Qz1HMYNxx3HMeiNqNo4jeRDEfZ3gUSHIcU/heomq0vEzL1Msz5KKGxH8FrNOYw3KaxdqaEmNHYMxJFgQbR0DyRknL2L4kwUSxKRdhjRpEtUqilVfggFL1klaMS3PPRDfNqbBOPWO7m4JTVGhS9QTBDDJaEbLbrUQNB+IpJSKQbG5SZZ5gkwJEhJ3aYKJipZ/i7kinChIOW2lQg"
+default_cpu_parallel_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels_parallel.c")
+default_cpu_parallel_kernel_code = "QlpoOTFBWSZTWUzax5EAALXbgERwSX1mTwAAr/ff3kACNyXUbZYwBpoaNGIyAaADQwRSaVP9QoMg0A2oAPU0AEUkU9GaaKMaQB6gA09T1ARRKnpk0niaJkaaNDJ6g0DTIKVKfZ/g6v1Kem5LJLa0WmkukkuCIHUqWbtJGJMsCSQFiPEIYHgBIZDzR8R6REbYxIqD2Cu7lMkFoPu6LmHeOAy0GF83Tc40jgmTs4HnCe60QfJa2bDBZ0Y1lhgbiZjW8SNsAKCk42UOEdjWN3KoiCIYeQUCCKWIyHewhtSoInLKSG22l4jKM2ZDCVKtBm3OTYBl3jsVqMImtj7PQw7xKxLXQzwgJaPPgW1fRhrvPJICl4YFDYfNbkbBh5JDgrazFml50xEQQwQUjxNwE0IDSofLzSg7UNVKn+Rr1KErzBHUxBqdHRlXzqYsIa5K9Y0UuE2ugw3g5KYofm7AaGNTzJSMhcchhxdaU4JZ0F1UNgQ8XcGDguypqYza8yFaEoGgNRcLej+g2t0feGKFE5OY2PFluQ3q4HgycxlfvzHqo0KcM0JI8OKXtzayJFgsqC1NdUQVu8rChnA6FO3MFyGOoC9KO8ITPpYM5pRqTlczFkLES/4u5IpwoSCZtY8i"
+
+cpu_kernels = None
+
+
+class CPUKernel:
+    def __init__(self, kernel_file="", source_code=default_cpu_kernel_code_path, compile_parallel_kernel=None, parallel_num=None):
+        self.load =False
+        self.int8WeightExtractionFloat = None
+        self.int4WeightExtractionFloat = None
+        self.int4WeightCompression = None
+        self.SetNumThreads = lambda x: x
+
+        try:
+            if not os.path.exists(default_cpu_kernel_code_path):
+                with open(default_cpu_kernel_code_path, "w", encoding="utf-8") as file:
+                    code = default_cpu_kernel_code
+                    cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
+                    file.write(cpu_quantization_code)
+
+            if not os.path.exists(default_cpu_parallel_kernel_code_path):
+                with open(default_cpu_parallel_kernel_code_path, "w", encoding="utf-8") as file:
+                    code = default_cpu_parallel_kernel_code
+                    cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
+                    file.write(cpu_quantization_code)
+
+        except Exception as ex:
+            print("Error when generating default cpu kernel code(can be ignored when using custom kernels).")
+
+        if compile_parallel_kernel is None:
+            compile_parallel_kernel = bool(int(os.cpu_count()) >= 4)
+
+        if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
+            source_code = default_cpu_parallel_kernel_code_path
+
+        kernels = None
+
+        if (not kernel_file) or (not os.path.exists(kernel_file)):
+            print("No compiled kernel found.")
+            try:
+                if os.path.exists(source_code):
+                    print("Compiling kernels :", source_code)
+                    kernel_file = source_code[:-2] + ".so"
+
+                    if compile_parallel_kernel:
+                        compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file)
+                        print("Compiling", compile_command)
+                        exit_state = os.system(compile_command)
+                        if not exit_state:
+                            try:
+                                kernels = ctypes.cdll.LoadLibrary(kernel_file)
+                                print("Load kernel :", kernel_file)
+                            except:
+                                kernels = None
+                                print("Load parallel cpu kernel failed, using default cpu kernel code:")
+                                import traceback
+                                exception = traceback.format_exc()
+                                print(exception)
+                        else:
+                            print("Compile default cpu kernel failed, using default cpu kernel code.")
+
+                        if kernels is None:  # adjust config, use default cpu kernel
+                            compile_parallel_kernel = False
+                            source_code = default_cpu_kernel_code_path
+                            kernel_file = source_code[:-2] + ".so"
+
+                    if kernels is None:
+                        compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
+                        print("Compiling", compile_command)
+                        exit_state = os.system(compile_command)
+                        if not exit_state:
+                            try:
+                                kernels = ctypes.cdll.LoadLibrary(kernel_file)
+                                print("Load kernel :", kernel_file)
+                            except:
+                                kernels = None
+                                print("Load default cpu kernel failed:")
+                                import traceback
+                                exception = traceback.format_exc()
+                                print(exception)
+                        else:
+                            print("Compile default cpu kernel failed.")
+                else:
+                    print("Kernel source code not found.")
+                    return
+            except:
+                print("Failed to build cpu kernel:")
+                import traceback
+                exception = traceback.format_exc()
+                print(exception)
+                return
+        else:
+            try:
+                kernels = ctypes.cdll.LoadLibrary(kernel_file)
+                print("Load kernel :", kernel_file)
+            except:
+                kernels = None
+                print("Load custom cpu kernel failed:")
+                import traceback
+                exception = traceback.format_exc()
+                print(exception)
+
+        if kernels is not None:
+            self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
+            self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
+            self.int4WeightCompression = kernels.compress_int4_weight
+            if compile_parallel_kernel:
+                try:
+                    self.SetNumThreads = kernels.set_num_threads
+                except:
+                    print("No set_num_threads() found in kernel.")
+            self.load = True
+        else:
+            print("Failed to load kernel.")
+            return
+
+        if compile_parallel_kernel:
+            if parallel_num is None:
+                parallel_num = max(os.cpu_count() // 2, 1)
+            print("Setting CPU quantization kernel threads to", parallel_num)
+            if parallel_num < 4:
+                print("Parallel kernel is not recommended when parallel num < 4.")
+            self.SetNumThreads(parallel_num)
+
+        self.parallel_num = parallel_num
+
+
 def compress_int4_weight(weight: torch.Tensor):  # (n, m)
-    with torch.cuda.device(weight.device):
+    """compress weight on cpu or cuda to int4"""
+    if weight.device == torch.device("cpu"):
+        assert isinstance(cpu_kernels, CPUKernel)
         n, m = weight.size(0), weight.size(1)
         assert m % 2 == 0
         m = m // 2
-        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
-        stream = torch.cuda.current_stream()
-
-        gridDim = (n, 1, 1)
-        blockDim = (min(round_up(m, 32), 1024), 1, 1)
-
-        kernels.int4WeightCompression(
-            gridDim,
-            blockDim,
-            0,
-            stream,
-            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
+        out = torch.empty(n, m, dtype=torch.int8, device="cpu")
+        cpu_kernels.int4WeightCompression(
+            ctypes.c_void_p(weight.data_ptr()),
+            ctypes.c_void_p(out.data_ptr()),
+            ctypes.c_int32(n),
+            ctypes.c_int32(m)
         )
         return out
+    else:
+        with torch.cuda.device(weight.device):
+            n, m = weight.size(0), weight.size(1)
+            assert m % 2 == 0
+            m = m // 2
+            out = torch.empty(n, m, dtype=torch.int8, device="cuda")
+            stream = torch.cuda.current_stream()
+
+            gridDim = (n, 1, 1)
+            blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+            kernels.int4WeightCompression(
+                gridDim,
+                blockDim,
+                0,
+                stream,
+                [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
+            )
+            return out


 def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
@@ -117,85 +287,229 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
     return out


+def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int, quantization_cache=None):
+    """extract weight on cpu to float32"""
+    if source_bit_width == 8:
+        func = cpu_kernels.int8WeightExtractionFloat
+    elif source_bit_width == 4:
+        func = cpu_kernels.int4WeightExtractionFloat
+    else:
+        assert False, "Unsupported bit-width"
+
+    n, m = weight.size(0), weight.size(1)
+
+    if quantization_cache is not None:
+        out = quantization_cache
+        func(
+            ctypes.c_void_p(weight.data_ptr()),
+            ctypes.c_void_p(scale_list.data_ptr()),
+            ctypes.c_void_p(out.data_ptr()),
+            ctypes.c_int32(n),
+            ctypes.c_int32(m)
+        )
+        return out.tensor
+    else:
+        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.float, device="cpu")
+        func(
+            ctypes.c_void_p(weight.data_ptr()),
+            ctypes.c_void_p(scale_list.data_ptr()),
+            ctypes.c_void_p(out.data_ptr()),
+            ctypes.c_int32(n),
+            ctypes.c_int32(m)
+        )
+        return out
+
+
+class CacheTensor():
+    def __init__(self, *args, **kwargs):
+        self.tensor = torch.empty(*args, **kwargs)
+
+    def to(self, *args, **kwargs):
+        self.tensor = self.tensor.to(*args, **kwargs)
+
+    def data_ptr(self):
+        return self.tensor.data_ptr()
+
+
 class QuantizedLinear(Linear):
-    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, quantized_weight=None, quantized_weight_scale=None, quantization_cache=None, empty_init=False, *args, **kwargs):
         super(QuantizedLinear, self).__init__(*args, **kwargs)
         self.weight_bit_width = weight_bit_width
+        self.quantization_cache = quantization_cache

-        shape = self.weight.shape
-        del self.weight
-
-        if weight_tensor is None or empty_init:
-            self.weight = torch.empty(
-                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
-            )
-            self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
+        if (quantized_weight is not None) and (quantized_weight_scale is not None):
+            del self.weight
+            self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False)
+            self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False)
         else:
-            self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
-            self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
-            if weight_bit_width == 4:
-                self.weight = compress_int4_weight(self.weight)
+            shape = self.weight.shape
+            del self.weight
+
+            if weight_tensor is None or empty_init:
+                self.weight = torch.empty(
+                    shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+                )
+                self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
+            else:
+                self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).to(kwargs["dtype"])
+                self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
+                if weight_bit_width == 4:
+                    self.weight = compress_int4_weight(self.weight)
+
+            self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+            self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)

-        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
-        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
         if bias_tensor is not None:
             self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False)
         else:
             self.bias = None

+    def reset_parameters(self):
+        """To accelerate initialization"""
+        pass
+
     def forward(self, input):
-        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
+        if self.weight.device == torch.device("cpu"):
+            output = W8A16LinearCPU.apply(input, self.weight, self.weight_scale, self.weight_bit_width, self.quantization_cache)
+        else:
+            output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
         if self.bias is not None:
             output = output + self.bias
         return output

+    def _apply(self, fn):
+        self_obj = super()._apply(fn)
+        if self.quantization_cache is not None:
+            self.quantization_cache.to(self_obj.weight.device)
+            self.quantization_cache.to(self_obj.weight_scale.dtype)
+        return self_obj
+
+
+class QuantizedEmbedding(Embedding):  # TODO: backward, check empty_init
+    def __init__(self, weight_bit_width: int, weight_tensor=None, quantized_weight=None, quantized_weight_scale=None, empty_init=False, *args, **kwargs):
+        super(QuantizedEmbedding, self).__init__(*args, **kwargs)
+        self.weight_bit_width = weight_bit_width
+
+        if (quantized_weight is not None) and (quantized_weight_scale is not None):
+            del self.weight
+            self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False)
+            self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False)
+        else:
+            shape = self.weight.shape
+            del self.weight
+
+            if weight_tensor is None or empty_init:
+                self.weight = torch.empty(
+                    shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
+                )
+                self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
+            else:
+                self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
+                self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
+                if weight_bit_width == 4:
+                    self.weight = compress_int4_weight(self.weight)
+
+            self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
+            self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
+
+    def forward(self, input):
+        if self.weight.device == torch.device("cpu"):
+            original_weight = extract_weight_to_float(weight=self.weight, scale_list=self.weight_scale, source_bit_width=self.weight_bit_width)
+        else:
+            original_weight = extract_weight_to_half(weight=self.weight, scale_list=self.weight_scale, source_bit_width=self.weight_bit_width)
+        output = F.embedding(
+            input, original_weight, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse
+        )
+        return output
+
+
+def load_cpu_kernel(**kwargs):
+    global cpu_kernels
+    cpu_kernels = CPUKernel(**kwargs)
+    assert cpu_kernels.load
+

-def quantize(model, weight_bit_width, empty_init=False, **kwargs):
+def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
     """Replace fp16 linear with quantized linear"""

+    query_key_value_quantization_cache = None
+    dense_quantization_cache = None
+    dense_h_to_4h_quantization_cache = None
+    dense_4h_to_h_quantization_cache = None
+
+    try:
+        load_cpu_kernel(**kwargs)
+    except:
+        if kernels is None:  # CUDA kernels failed
+            print("Cannot load cpu or cuda kernel, quantization failed:")
+            assert kernels is not None
+        print("Cannot load cpu kernel, don't use quantized model on cpu.")
+
+    current_device = model.device
+
+    if model.device == torch.device("cpu"):
+        dtype=torch.float32
+    else:
+        dtype = torch.half
+
+    QuantizedLinearWithPara = partial(
+        QuantizedLinear,
+        weight_bit_width=weight_bit_width,
+        bias=True,
+        dtype=dtype,
+        empty_init=empty_init
+    )
+
+    if use_quantization_cache:
+        print("Using quantization cache")
+        layer = model.layers[0]
+        weight = layer.attention.query_key_value.weight
+        n, m = weight.size(0), weight.size(1)
+        query_key_value_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
+        weight = layer.attention.dense.weight
+        n, m = weight.size(0), weight.size(1)
+        dense_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
+        weight = layer.mlp.dense_h_to_4h.weight
+        n, m = weight.size(0), weight.size(1)
+        dense_h_to_4h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
+        weight = layer.mlp.dense_4h_to_h.weight
+        n, m = weight.size(0), weight.size(1)
+        dense_4h_to_h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False)
+
+    print("Applying quantization to glm layers")
+
     for layer in model.layers:
-        layer.attention.query_key_value = QuantizedLinear(
-            weight_bit_width=weight_bit_width,
-            weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),
+        layer.attention.query_key_value = QuantizedLinearWithPara(
+            weight_tensor=layer.attention.query_key_value.weight.to(current_device),
             bias_tensor=layer.attention.query_key_value.bias,
             in_features=layer.attention.query_key_value.in_features,
             out_features=layer.attention.query_key_value.out_features,
-            bias=True,
-            dtype=torch.half,
             device=layer.attention.query_key_value.weight.device,
-            empty_init=empty_init
+            quantization_cache=query_key_value_quantization_cache
         )
-        layer.attention.dense = QuantizedLinear(
-            weight_bit_width=weight_bit_width,
-            weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()),
+        layer.attention.dense = QuantizedLinearWithPara(
+            weight_tensor=layer.attention.dense.weight.to(current_device),
             bias_tensor=layer.attention.dense.bias,
             in_features=layer.attention.dense.in_features,
             out_features=layer.attention.dense.out_features,
-            bias=True,
-            dtype=torch.half,
             device=layer.attention.dense.weight.device,
-            empty_init=empty_init
+            quantization_cache=dense_quantization_cache
         )
-        layer.mlp.dense_h_to_4h = QuantizedLinear(
-            weight_bit_width=weight_bit_width,
-            weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
+        layer.mlp.dense_h_to_4h = QuantizedLinearWithPara(
+            weight_tensor=layer.mlp.dense_h_to_4h.weight.to(current_device),
             bias_tensor=layer.mlp.dense_h_to_4h.bias,
             in_features=layer.mlp.dense_h_to_4h.in_features,
             out_features=layer.mlp.dense_h_to_4h.out_features,
-            bias=True,
-            dtype=torch.half,
             device=layer.mlp.dense_h_to_4h.weight.device,
-            empty_init=empty_init
+            quantization_cache=dense_h_to_4h_quantization_cache
        )
-        layer.mlp.dense_4h_to_h = QuantizedLinear(
-            weight_bit_width=weight_bit_width,
-            weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
+        layer.mlp.dense_4h_to_h = QuantizedLinearWithPara(
+            weight_tensor=layer.mlp.dense_4h_to_h.weight.to(current_device),
             bias_tensor=layer.mlp.dense_4h_to_h.bias,
             in_features=layer.mlp.dense_4h_to_h.in_features,
             out_features=layer.mlp.dense_4h_to_h.out_features,
-            bias=True,
-            dtype=torch.half,
             device=layer.mlp.dense_4h_to_h.weight.device,
-            empty_init=empty_init
+            quantization_cache=dense_4h_to_h_quantization_cache
         )
     return model
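Both QuantizedLinear and QuantizedEmbedding rely on the same symmetric per-row scheme: the scale is max(|w|) / (2**(bits - 1) - 1) per output row, and the stored weight is round(w / scale) cast to int8 (packed further for 4-bit). A small illustration of that round trip in plain PyTorch, not part of the commit and without the C/CUDA kernels:

    import torch

    bits = 8
    w = torch.randn(4, 8)                                       # fp weight, one row per output feature
    scale = w.abs().max(dim=-1).values / (2 ** (bits - 1) - 1)  # per-row scale, as in QuantizedLinear
    q = torch.round(w / scale[:, None]).to(torch.int8)          # stored int8 weight
    w_restored = q.to(torch.float32) * scale[:, None]           # what extract_weight_to_float reproduces
    print((w - w_restored).abs().max())                         # small round-off error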