bstraehle committed
Commit d66c584
1 parent: 70d0c0d

Update app.py

Files changed (1)
  1. app.py +16 -7
app.py CHANGED
@@ -102,6 +102,7 @@ def test(base_model_id, dataset):
     ###################
     # Hyper-parameters
     ###################
+    print("111")
     training_config = {
         "bf16": True,
         "do_eval": False,
@@ -125,7 +126,8 @@ def test(base_model_id, dataset):
         "gradient_accumulation_steps": 1,
         "warmup_ratio": 0.2,
     }
-
+
+    print("222")
     peft_config = {
         "r": 16,
         "lora_alpha": 32,
@@ -142,6 +144,7 @@ def test(base_model_id, dataset):
     ###############
     # Setup logging
     ###############
+    print("333")
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
@@ -156,6 +159,7 @@ def test(base_model_id, dataset):
     transformers.utils.logging.enable_explicit_format()
 
     # Log on each process a small summary
+    print("444")
     logger.warning(
         f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
         + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bits training: {train_conf.fp16}"
@@ -167,6 +171,7 @@ def test(base_model_id, dataset):
     ################
     # Model Loading
     ################
+    print("444")
     checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
     # checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
     model_kwargs = dict(
@@ -176,6 +181,7 @@ def test(base_model_id, dataset):
         torch_dtype=torch.bfloat16,
         device_map=None
     )
+    print("555")
     model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
     tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
     tokenizer.model_max_length = 2048
@@ -187,10 +193,8 @@ def test(base_model_id, dataset):
     ##################
     # Data Processing
     ##################
-    def apply_chat_template(
-        example,
-        tokenizer,
-    ):
+    print("666")
+    def apply_chat_template(example, tokenizer):
         messages = example["messages"]
         example["text"] = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=False)
@@ -200,7 +204,8 @@ def test(base_model_id, dataset):
     train_dataset = raw_dataset["train_sft"]
     test_dataset = raw_dataset["test_sft"]
     column_names = list(train_dataset.features)
-
+
+    print("777")
     processed_train_dataset = train_dataset.map(
         apply_chat_template,
         fn_kwargs={"tokenizer": tokenizer},
@@ -208,7 +213,8 @@ def test(base_model_id, dataset):
         remove_columns=column_names,
         desc="Applying chat template to train_sft",
     )
-
+
+    print("888")
     processed_test_dataset = test_dataset.map(
         apply_chat_template,
         fn_kwargs={"tokenizer": tokenizer},
@@ -221,6 +227,7 @@ def test(base_model_id, dataset):
     ###########
     # Training
     ###########
+    print("999")
     trainer = SFTTrainer(
         model=model,
         args=train_conf,
@@ -242,6 +249,7 @@ def test(base_model_id, dataset):
     #############
     # Evaluation
     #############
+    print("aaa")
     tokenizer.padding_side = 'left'
     metrics = trainer.evaluate()
     metrics["eval_samples"] = len(processed_test_dataset)
@@ -252,6 +260,7 @@ def test(base_model_id, dataset):
     # ############
     # # Save model
     # ############
+    print("bbb")
     trainer.save_model(train_conf.output_dir)
 
 def download_model(base_model_id):
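
The hunks above consume a train_conf object (train_conf.local_rank, args=train_conf, train_conf.output_dir) and a PEFT config, but their construction falls outside the changed context. Below is a minimal sketch of the conventional wiring for these dicts, assuming transformers.TrainingArguments and peft.LoraConfig; only the keys visible in the hunks are shown, and the output_dir value is illustrative, not taken from the commit.

from transformers import TrainingArguments
from peft import LoraConfig

# Keys shown in the visible hunks; the full dicts in app.py contain more.
training_config = {
    "bf16": True,
    "do_eval": False,
    "gradient_accumulation_steps": 1,
    "warmup_ratio": 0.2,
    "output_dir": "./checkpoint_dir",  # illustrative; app.py's actual value is not in the diff
}
peft_config = {
    "r": 16,
    "lora_alpha": 32,
}

train_conf = TrainingArguments(**training_config)  # later passed as SFTTrainer(args=...)
peft_conf = LoraConfig(**peft_config)              # LoRA adapter configuration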
 
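For reference, here is the reworked helper from the data-processing hunk, pulled out of the diff as a self-contained sketch together with its map invocation; the trailing return is assumed, since the hunk is truncated before the function body ends.

def apply_chat_template(example, tokenizer):
    # Render one record's chat messages into a single training string
    # using the tokenizer's built-in chat template.
    messages = example["messages"]
    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False)
    return example  # assumed: the visible hunk ends before the return

processed_train_dataset = train_dataset.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=column_names,
    desc="Applying chat template to train_sft",
)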
 
 
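The training hunk is cut off after args=train_conf, so the trainer's remaining keyword arguments are not visible in this commit view. A sketch of the usual trl wiring consistent with the names used elsewhere in the diff follows; the exact argument set is an assumption, not part of the commit.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=train_conf,
    peft_config=peft_conf,                  # assumed: the LoRA config built above
    train_dataset=processed_train_dataset,
    eval_dataset=processed_test_dataset,
    dataset_text_field="text",              # assumed: the column written by apply_chat_template
    tokenizer=tokenizer,                    # assumed: older trl API that takes a tokenizer directly
)
train_result = trainer.train()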
 
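The evaluation hunk stops right after the sample count is attached to metrics. A common follow-up, assumed here rather than taken from the visible diff, is to report and persist the metrics through the Trainer helpers before saving the model.

tokenizer.padding_side = 'left'
metrics = trainer.evaluate()
metrics["eval_samples"] = len(processed_test_dataset)
trainer.log_metrics("eval", metrics)   # assumed follow-up; not part of the visible hunk
trainer.save_metrics("eval", metrics)  # writes eval_results.json into the output dir
trainer.save_model(train_conf.output_dir)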