Pascrayon committed
Commit 41ed25e
Parent: 05657f4

Update README.md

Files changed (1): README.md (+162 -1)
README.md CHANGED

pipeline_tag: text-generation
tags:
- crayon
- language-technologies
---

# Bloom 560M Finetuned on Instructions

BLOOM-560M fine-tuned to follow instructions: LoRA adapters trained on top of the 8-bit base model with Alpaca-style instruction prompts, using the training script below.

# Training Code

```python
# coding=utf-8
# Code 99.99% copied and adapted from:
# https://github.com/bofenghuang/vigogne
# https://colab.research.google.com/drive/1jCkpikz0J2o20FBQmYmAGdiKmJGOMo-o?usp=sharing#scrollTo=DpYr24pR8T_0


import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence

import bitsandbytes as bnb
import fire
import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}


def generate_prompt(example):
    # Use the "input" template only when the example actually provides an input field.
    return (
        PROMPT_DICT["prompt_input"].format_map(example)
        if example["input"]
        else PROMPT_DICT["prompt_no_input"].format_map(example)
    )


# Modified from: https://github.com/bofenghuang/stanford_alpaca/blob/eb5b171d9b103a12a8e14e0edca9cbc45fe1d512/train.py#L166-L182
# Almost the same as transformers.DataCollatorForSeq2Seq
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Gather the pre-tokenized ids and labels from each instance.
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))

        if self.pad_to_multiple_of is not None:
            max_length_index, max_length = max(enumerate([len(input_ids_) for input_ids_ in input_ids]),
                                               key=lambda x: x[1])
            n_padding = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of - max_length
            # Pad the longest example so the final batch length is a multiple of pad_to_multiple_of.
            input_ids[max_length_index].extend([self.tokenizer.pad_token_id] * n_padding)
            labels[max_length_index].extend([IGNORE_INDEX] * n_padding)

        input_ids = [torch.LongTensor(input_ids_) for input_ids_ in input_ids]
        labels = [torch.LongTensor(labels_) for labels_ in labels]

        # Right-pad every other example up to the batch maximum.
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True,
                                                    padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))


def train(model_name_or_path: str, output_dir: str, data_path: str, val_set_size: int = 500,
          model_max_length: int = 512, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05,
          target_modules: List[str] = ["query_key_value"], num_train_epochs: int = 3, learning_rate: float = 0.0001,
          per_device_train_batch_size: int = 8, gradient_accumulation_steps: int = 16, **kwargs):
    device_map = "auto"

    # Load the base model in 8-bit and prepare it for int8 LoRA training.
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map=device_map)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, model_max_length=model_max_length,
                                              padding_side="right", use_fast=False)

    model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout,
                             bias="none", task_type=TaskType.CAUSAL_LM)

    model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()

    # Load data
    data = load_dataset("json", data_files=data_path)

    def preprocess_function(example):
        # Format prompt
        user_prompt = generate_prompt(example)

        # Get prompt length for masking
        len_user_prompt_tokens = len(tokenizer(user_prompt, truncation=True)["input_ids"])

        input_ids = tokenizer(user_prompt + example["output"] + tokenizer.eos_token, truncation=True)["input_ids"]
        # Mask the prompt tokens so the loss is only computed on the response.
        labels = [IGNORE_INDEX] * len_user_prompt_tokens + input_ids[len_user_prompt_tokens:]

        return {"input_ids": input_ids, "labels": labels}

    if val_set_size > 0:
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
        train_data = train_val["train"].shuffle().map(preprocess_function, remove_columns=data["train"].column_names)
        val_data = train_val["test"].map(preprocess_function, remove_columns=data["train"].column_names)
    else:
        train_data = data["train"].shuffle().map(preprocess_function, remove_columns=data["train"].column_names)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            fp16=True,
            output_dir=output_dir,
            load_best_model_at_end=True if val_set_size > 0 else False,
            **kwargs,
        ),
        data_collator=DataCollatorForSupervisedDataset(tokenizer=tokenizer, pad_to_multiple_of=8),
    )
    print(trainer.args)

    # Silence the warnings. Please re-enable for inference!
    model.config.use_cache = False

    # Make state_dict() return only the LoRA adapter weights, so save_pretrained saves just the adapter.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    # torch.compile is only available on PyTorch 2.x (and not supported on Windows).
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train()

    model.save_pretrained(output_dir)


if __name__ == "__main__":
    fire.Fire(train)
```
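
The script exposes `train` through `fire`, so it can be driven either from Python or from the command line. Below is a minimal launch sketch, assuming the code above is saved as `train.py`, the base checkpoint is `bigscience/bloom-560m`, and the data is an Alpaca-style JSON file with `instruction`/`input`/`output` fields; the file name, checkpoint id and paths are illustrative assumptions, not values taken from this repository.

```python
# Minimal launch sketch -- file name, paths and base checkpoint are assumptions.
from train import train  # assumes the script above was saved as train.py

# data_path should point to a JSON file of records such as:
#   {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}
train(
    model_name_or_path="bigscience/bloom-560m",  # assumed 560M BLOOM base checkpoint
    output_dir="./bloom-560m-instruct-lora",     # hypothetical output directory
    data_path="./instructions.json",             # Alpaca-style instruction data
)
```

The equivalent CLI call would pass the same values as flags (e.g. `python train.py --model_name_or_path bigscience/bloom-560m --output_dir ./bloom-560m-instruct-lora --data_path ./instructions.json`), since `fire.Fire(train)` maps keyword arguments to command-line options.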
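
Because the `state_dict` override keeps only the LoRA weights, `output_dir` ends up holding a PEFT adapter rather than a full model. The sketch below shows one way such an adapter could be loaded back onto the base model for generation, reusing the hypothetical paths from the launch sketch above; note that `use_cache` is re-enabled here, since the training code disabled it.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "bigscience/bloom-560m"   # assumed base checkpoint (not confirmed by this repo)
adapter_dir = "./bloom-560m-instruct-lora"  # hypothetical path from the launch sketch above

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_dir)  # attach the trained LoRA adapter
model.config.use_cache = True  # re-enable the cache that the training script turned off
model.eval()

# Same prompt template as training (the "no input" variant).
prompt = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGive three tips for staying healthy.\n\n### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)

# Print only the newly generated tokens after the prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```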