import os

import pandas as pd
import torch
import torch.nn as nn

from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoModelForCausalLM
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast

import datasets
from datasets import IterableDataset, disable_caching

disable_caching()

from conditional_gpt2_model import ConditionalGPT2LMHeadModel

# encoder checkpoint that provides the tokenizer (and hence the decoder vocabulary)
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"
TOKENIZER_MAX_LEN = 256

# each on-disk dataset file is iterated as this many contiguous shards
DATA_SUBSHARDS = 10

DATA_DIR = None          # directory containing the preprocessed .hf dataset files
TRAINER_SAVE_DIR = None  # output directory for Trainer checkpoints

assert DATA_DIR is not None, "data directory must be specified"
assert TRAINER_SAVE_DIR is not None, "trainer save directory must be specified"


def gen_dataset():
    """Stream pre-tokenized examples (with precomputed conditioning states) from disk."""
    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])

    for filename in data_filenames:
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        # keep only the decoder inputs and the precomputed encoder states
        keep_cols = ['input_ids', 'encoder_hidden_states']
        dataset = dataset.remove_columns([i for i in dataset.column_names
                                          if i not in keep_cols]).with_format("torch")

        # iterate the file as contiguous sub-shards
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for shard in shards:
            for example in shard:
                # prepend a length-1 axis to the conditioning states
                # (e.g. (hidden_dim,) -> (1, hidden_dim)) so cross-attention sees a sequence
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example


# stream examples lazily so the full corpus never has to sit in memory
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")

tokenizer = RobertaTokenizerFast.from_pretrained(
    ENCODER_MODEL_NAME, model_max_length=TOKENIZER_MAX_LEN
)
# mlm=False -> causal language modeling; labels are input_ids with padding masked to -100
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# small GPT-2 decoder with cross-attention enabled so it can attend to the
# precomputed encoder hidden states supplied with each batch
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,
)

model = ConditionalGPT2LMHeadModel(config)
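
# Optional sanity check before launching training: report the size of the
# randomly initialized decoder.
n_params = sum(p.numel() for p in model.parameters())
print(f"decoder parameters: {n_params / 1e6:.1f}M")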

args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
    max_steps=50000,  # takes precedence over num_train_epochs when set
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()
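
# Optionally persist the final weights and tokenizer once training finishes;
# the "final" subdirectory name here is just a placeholder.
trainer.save_model(f"{TRAINER_SAVE_DIR}/final")
tokenizer.save_pretrained(f"{TRAINER_SAVE_DIR}/final")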