"""
In this mixin, I use a different implementation than sat/model/finetune/lora.py
I just use a fake linear layer to replace any model with lora mixin.
"""

import torch
import torch.nn as nn
from sat.model.base_model import BaseMixin
import math
from sat.helpers import print_all
from sat.model.transformer import RowParallelLinear, ColumnParallelLinear

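# The Hack* linear layers below override _load_from_state_dict so that weights/biases are
# copied in place when present in the checkpoint and silently skipped otherwise. This lets
# the same module load either a full checkpoint or a LoRA-only checkpoint without key errors.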
class HackLinear(nn.Linear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

class HackRowParallelLinear(RowParallelLinear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

class HackColumnParallelLinear(ColumnParallelLinear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix+'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix+'bias'])

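# Optional QLoRA support: bitsandbytes provides the 4-bit NF4 linear layer. The HackLinearNF4
# subclass below additionally saves/restores the quant_state that accompanies the packed uint8
# weight, so quantized checkpoints round-trip through state_dict.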
try:
    from bitsandbytes.nn import LinearNF4
    def copy_nested_list(src, dst):
        for i in range(len(dst)):
            if type(dst[i]) is torch.Tensor:
                dst[i].copy_(src[i])
            elif type(dst[i]) is list:
                copy_nested_list(src[i], dst[i])
            else:
                dst[i] = src[i]
    class HackLinearNF4(LinearNF4):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if prefix + 'weight' in state_dict:
                self.weight.data.copy_(state_dict[prefix+'weight'])
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)
            if prefix + 'bias' in state_dict:
                self.bias.data.copy_(state_dict[prefix+'bias'])
        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix+'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')


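# A ParameterList that, like the Hack* linears above, loads whatever entries exist in the
# checkpoint and ignores the rest (used for the per-partition LoRA matrices A and B).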
class HackParameterList(nn.ParameterList):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for i in range(len(self)):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix+str(i)])

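# Map each supported linear class to its Hack replacement plus the extra kwargs needed to keep
# the original tensor-parallel behaviour.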
map_cls = {
    nn.Linear: (HackLinear, {}),
    ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}),
    RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True})
}

class LoraLinear(nn.Module):
    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        You can use safely with this layer, ONLY WHEN query_key_value output is query_key_value order.
        If you use a different order like ChatGLM
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            try:
                self.original = HackLinearNF4(in_dim, out_dim)
            except Exception:
                raise Exception('Building the 4-bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still get an error after installation, run `from bitsandbytes.nn import LinearNF4` in python and fix that error first.')
        else:
            base_cls, kwargs = map_cls[original_cls]
            self.original = base_cls(in_dim, out_dim, **kwargs)
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
        for i in range(partition):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        self.partition = partition
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            
    def forward(self, x):
        # Output of the frozen base projection.
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        # Low-rank update for each partition: scaling * (dropout(x) @ A_i^T @ B_i^T).
        for i in range(self.partition):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            # Reshape so the partitions are interleaved per attention head (e.g. ChatGLM-style qkv layout).
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(self.partition):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            # Partitions are concatenated along the output dimension.
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)

        return mixed_raw_layer


def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
    # linear layers without bias are not supported for now
    out_dim, in_dim = lin.weight.shape
    original_cls = type(lin)
    del lin
    return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args)

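# Fold the low-rank update back into a dense weight: W_merged = W0 + scaling * cat_i(B_i @ A_i).
# NF4 weights are dequantized first so the merge happens in floating point.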
def merge_linear_lora(lin):
    if lin.original.weight.data.dtype is not torch.uint8:
        weight = lin.original.weight
        out_dim, in_dim = weight.shape
        new_lin = nn.Linear(in_dim, out_dim)
    else:
        import bitsandbytes.functional as F
        weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to(lin.original.bias.data.dtype)
        out_dim, in_dim = weight.shape
        new_lin = HackLinearNF4(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for i in range(lin.partition):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], lin.partition*ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
    new_lin.weight.data = weight + new_qkv.T.to(lin.original.bias.data.dtype)
    return new_lin.cuda() if torch.cuda.is_available() else new_lin

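# Typical usage (a sketch, assuming a SAT BaseModel subclass `model` and its argparse `args`;
# names like `args.num_layers` are illustrative):
#   model.add_mixin('lora', LoraMixin(args.num_layers, r=8, lora_alpha=16, lora_dropout=0.05))
#   # ... freeze the base weights and train only the LoRA matrices ...
#   model.mixins['lora'].merge_lora()  # optionally fold LoRA back into plain linear layers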
class LoraMixin(BaseMixin):
    def __init__(self, 
                layer_num,
                r: int = 0,
                lora_alpha: int = 1,
                lora_dropout: float = 0.,
                layer_range = None,
                head_first = False,
                num_attention_heads = None,
                hidden_size_per_attention_head = None,
                qlora = False,
                cross_attention = True):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora
        self.cross_attention = cross_attention

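    # Swap the attention (and, for decoder layers, cross-attention) projections of the selected
    # layers for LoRA-wrapped versions; with qlora=True, additionally quantize the remaining
    # linear layers of the transformer to 4-bit NF4.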
    def reinit(self, parent_model):
        for i in self.layer_range:
            print(f'replacing layer {i} attention with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
            if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
                print(f'replacing layer {i} cross attention with lora')
                parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora)
        if self.qlora:
            print('replacing linear layers with 4-bit (NF4) layers')
            def replace_linear_with_nf4(model, name=None, cache=None):
                if cache is None:
                    cache = {}
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

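    # Replace every LoRA-wrapped linear in the touched layers with a plain merged linear, so the
    # model can be saved or served without the LoRA machinery.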
    def merge_lora(self):
        for i in self.layer_range:
            print(f'merge layer {i} lora attention back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
            if self.transformer.layers[i].is_decoder:
                print(f'merge layer {i} lora cross attention back to linear')
                self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
                self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
                self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(self.transformer.layers[i].cross_attention.key_value)

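# Minimal smoke test: matrix_B is zero-initialized, so wrapping a linear layer with LoRA (and
# reloading either checkpoint) must leave the output unchanged.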
if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)
        
        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    # wrap the plain linear layer with a LoRA layer (single partition, rank 10)
    model.child = replace_linear_with_lora(model.child, 1, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    # matrix_B starts at zero, so the LoRA wrapper must not change the output
    assert torch.allclose(out1, out2) and torch.allclose(out2, out3)
    print('LoRA smoke test passed.')