bstraehle committed on
Commit
53b729b
1 Parent(s): 1939ff5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -65,29 +65,25 @@ def fine_tune_model(base_model_name, dataset_name):
65
  print(dataset["train"][:1])
66
  print("###")
67
 
68
- # Split dataset into training and validation sets
69
 
70
  train_dataset = dataset["train"]
71
- test_dataset = dataset["test"]
72
- #train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
73
- #test_dataset = dataset["test"].shuffle(seed=42).select(range(100))
74
 
75
  print("### Training dataset")
76
  print(train_dataset)
77
- print("### Validation dataset")
78
- print(test_dataset)
79
  print("###")
80
 
81
  # Configure training arguments
82
 
83
- # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
84
  training_args = Seq2SeqTrainingArguments(
85
  output_dir=f"./{FT_MODEL_NAME}",
86
- logging_dir="./logs",
87
- num_train_epochs=1,
88
- max_steps=1, # overwrites num_train_epochs
89
  push_to_hub=True, # only pushes model, also need to push tokenizer (see below)
90
- # TODO
91
  )
92
 
93
  print("### Training arguments")
@@ -96,13 +92,12 @@ def fine_tune_model(base_model_name, dataset_name):
96
 
97
  # Create trainer
98
 
99
- # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
100
  trainer = Seq2SeqTrainer(
101
  model=model,
102
  args=training_args,
103
  train_dataset=train_dataset,
104
- eval_dataset=test_dataset,
105
- # TODO
106
  )
107
 
108
  # Train model
 
65
  print(dataset["train"][:1])
66
  print("###")
67
 
68
+ # Split dataset into training and evaluation sets
69
 
70
  train_dataset = dataset["train"]
71
+ eval_dataset = dataset["test"]
 
 
72
 
73
  print("### Training dataset")
74
  print(train_dataset)
75
+ print("### Evaluation dataset")
76
+ print(eval_dataset)
77
  print("###")
78
 
79
  # Configure training arguments
80
 
 
81
  training_args = Seq2SeqTrainingArguments(
82
  output_dir=f"./{FT_MODEL_NAME}",
83
+ num_train_epochs=3,
84
+ #max_steps=1, # overwrites num_train_epochs
 
85
  push_to_hub=True, # only pushes model, also need to push tokenizer (see below)
86
+ # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
87
  )
88
 
89
  print("### Training arguments")
 
92
 
93
  # Create trainer
94
 
 
95
  trainer = Seq2SeqTrainer(
96
  model=model,
97
  args=training_args,
98
  train_dataset=train_dataset,
99
+ eval_dataset=eval_dataset,
100
+ # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
101
  )
102
 
103
  # Train model