import logging
import math
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import Any, Dict, Optional, Tuple, List, Union

import numpy as np
import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from fragment_creator import fragment_creator_factory
from model import ContextArgs, ModelArgs, Transformer
from preprocess_dataset import SmilesTask
from tokenizer import SmilesTokenizer

logger = logging.getLogger(__name__)


@dataclass
class IOConfig:
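    """I/O options: output directory, evaluation/logging cadence and resume behaviour."""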
    out_dir: str = "out"
    eval_interval: int = 500
    log_interval: int = 10
    eval_iters: int = 25
    eval_only: bool = False
    always_save_checkpoint: bool = False
    init_from: str = "scratch"  # "scratch" or "resume"
    resume_when_snapshot_available: bool = True


@dataclass
class LoaderConfig:
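    """Data-loading options: batch shape, which processed dataset to read and the optional fragment creator."""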
    batch_size: int = 384
    max_seq_len: int = 768
    dataset: str = "smiles"
    processed_dataset_ckpt: str = "processed_dataset_None.pkl"
    fragment_creator: Union[str, None] = None


@dataclass
class OptimizerConfig:
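    """Optimizer and learning-rate schedule hyperparameters (linear warmup, then cosine decay to min_lr)."""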
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-4
    max_iters: int = 100000
    weight_decay: float = 1e-1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    decay_lr: bool = True
    warmup_iters: int = 1000
    lr_decay_iters: int = 100000
    min_lr: float = 0.0


@dataclass
class TrainerArgs:
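    """All configuration consumed by Trainer, plus the run name used for checkpoint and snapshot file names."""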
    io_conf: IOConfig
    loader_conf: LoaderConfig
    model_conf: ModelArgs
    context_conf: ContextArgs
    optimizer_conf: OptimizerConfig
    run_name: str


class Trainer:
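    """Training loop for the SMILES Transformer.

    `train()` initialises DDP when the usual torchrun environment variables
    (RANK, LOCAL_RANK, WORLD_SIZE) are present, builds the model, optimizer
    and data loaders, and then runs the gradient-accumulation loop with
    mixed precision, periodic evaluation and checkpointing.
    """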
    def __init__(
        self, train_args: TrainerArgs, dtype: str = "float16", compile: bool = False
    ) -> None:
        self.train_conf = train_args
        self.dtype = dtype
        self.compile = compile

        self.run_name = train_args.run_name
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.CKPT_PT = f"{self.run_name}.pt"
        self.SNAPSHOT_PT = f"snapshot_{self.run_name}.pt"

    def _init_ddp_if_possible(self):
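        """Initialise torch.distributed (NCCL backend) when launched via torchrun; otherwise fall back to a single process."""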
        self.ddp = int(os.environ.get("RANK", -1)) != -1
        if self.ddp:
            logger.info("Using ddp!")
            init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            logger.info(
                f"{self.ddp_rank}, {self.ddp_local_rank}, {self.ddp_world_size}"
            )

            self.device = f"cuda:{self.ddp_local_rank}"
            torch.cuda.set_device(self.device)
            self.master_process = self.ddp_rank == 0

            logger.info(f"Is master process {self.device}? {self.master_process}")
            self.seed_offset = self.ddp_rank

            # Each rank accumulates its share of the configured steps, so the
            # total must divide evenly across the world size.
            assert (
                self.train_conf.optimizer_conf.gradient_accumulation_steps
                % self.ddp_world_size
                == 0
            )
            self.train_conf.optimizer_conf.gradient_accumulation_steps //= (
                self.ddp_world_size
            )
        else:
            self.master_process = True
            self.seed_offset = 0
            self.ddp_world_size = 1

    def _init_train(self):
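        """Seed RNGs, set up the AMP context, data loader, model, GradScaler and optimizer, resuming from a snapshot when one exists."""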
        self.tokens_per_iter = (
            self.train_conf.optimizer_conf.gradient_accumulation_steps
            * self.ddp_world_size
            * self.train_conf.loader_conf.batch_size
            * self.train_conf.loader_conf.max_seq_len
        )
        if self.master_process:
            logger.info(f"tokens per iteration will be: {self.tokens_per_iter:,}")
            logger.info(
                f"breaks down as: {self.train_conf.optimizer_conf.gradient_accumulation_steps} grad accum steps"
                f" * {self.ddp_world_size} processes"
                f" * {self.train_conf.loader_conf.batch_size} batch size"
                f" * {self.train_conf.loader_conf.max_seq_len} max seq len"
            )

        if self.master_process:
            os.makedirs(self.train_conf.io_conf.out_dir, exist_ok=True)

        torch.manual_seed(1337 + self.seed_offset)
        np.random.seed(1337 + self.seed_offset)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        self.device_type = "cuda" if "cuda" in self.device else "cpu"

        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[self.dtype]
        self.ctx = (
            nullcontext()
            if self.device_type == "cpu"
            else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
        )

        task = {"smiles": SmilesTask}[self.train_conf.loader_conf.dataset]
        self.iter_batches = partial(
            task.iter_batches,
            batch_size=self.train_conf.loader_conf.batch_size,
            device=self.device,
            context_keys=self.train_conf.context_conf.context_keys,
            num_workers=0,
            dataset=self.train_conf.loader_conf.processed_dataset_ckpt,
            fragment_creator=fragment_creator_factory(
                self.train_conf.loader_conf.fragment_creator
            ),
        )

        self.iter_num = 0
        self.best_val_loss = 1e9
        self.epoch = 1

        self.tokenizer = SmilesTokenizer()

        has_resumed = False
        snapshot = None
        if (
            self.train_conf.io_conf.init_from == "resume"
            or self.train_conf.io_conf.resume_when_snapshot_available
        ):
            snapshot_path = os.path.join(
                self.train_conf.io_conf.out_dir, self.SNAPSHOT_PT
            )
            if os.path.exists(snapshot_path):
                has_resumed = True
                logger.info(f"Resuming training from {self.train_conf.io_conf.out_dir}")

                ckpt_path = os.path.join(self.train_conf.io_conf.out_dir, self.CKPT_PT)
                self.model = Transformer.load(ckpt_path, device=self.device)
                snapshot = torch.load(snapshot_path, map_location=self.device)
                self.iter_num = snapshot["iter_num"]
                self.best_val_loss = snapshot["best_val_loss"]
                self.epoch = snapshot["epoch"]

        if self.train_conf.io_conf.init_from == "scratch" and not has_resumed:
            logger.info("Initializing a new model from scratch")
            logger.info(self.device)

            model_conf = self.train_conf.model_conf
            model_conf.vocab_size = self.tokenizer.vocab_size

            self.model = Transformer(model_conf, self.train_conf.context_conf).to(
                self.device
            )
        logger.info(
            f"Number of params: {self.model.getNumberParams()} "
            f"Number Trainable Params: {self.model.getNumberTrainableParams()}"
        )

        self.model = self.model.to(self.device)

        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == "float16"))

        self.optimizer = self.model.configure_optimizers(
            self.train_conf.optimizer_conf.weight_decay,
            self.train_conf.optimizer_conf.learning_rate,
            (
                self.train_conf.optimizer_conf.beta1,
                self.train_conf.optimizer_conf.beta2,
            ),
            self.device_type,
        )

        if (
            self.train_conf.io_conf.init_from == "resume"
            and snapshot is not None
            and "optimizer_state" in snapshot
        ):
            logger.info("Loading optimizer state from snapshot")
            self.optimizer.load_state_dict(snapshot["optimizer_state"])
            snapshot = None

        if self.compile:
            logger.info("compiling the model... (takes a ~minute)")
            self.unoptimized_model = self.model
            self.model = torch.compile(self.model, dynamic=False)

        if self.ddp:
            # Do not sync the freqs_cis buffer across ranks; its key is
            # prefixed with "_orig_mod." when the model has been compiled.
            prefix = "_orig_mod." if self.compile else ""
            self.model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])

    @torch.no_grad()
    def estimate_loss(self):
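        """Average the loss over eval_iters batches of the train and val splits with the model in eval mode."""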
        out = {}
        self.model.eval()
        for split in ["train", "val"]:
            batch_iter = self.iter_batches(split)
            losses = torch.zeros(self.train_conf.io_conf.eval_iters)
            for k in tqdm(
                range(self.train_conf.io_conf.eval_iters),
                total=self.train_conf.io_conf.eval_iters,
                desc="Eval",
            ):
                try:
                    X = next(batch_iter)
                    with self.ctx:
                        logits = self.model(
                            X["src"],
                            targets=X["tgt"],
                            context=X["context"],
                            fragment=X["fragment"],
                        )

                    loss = self.raw_model.last_loss
                    losses[k] = loss.item()
                except StopIteration:
                    # The split ran out of batches before eval_iters was reached.
                    logger.info("Early Eval Stop")
                    break

            out[split] = losses.mean()
        self.model.train()
        return out

    def get_lr(self, it: int):
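        """Learning-rate schedule: linear warmup to learning_rate over warmup_iters, then cosine decay down to min_lr at lr_decay_iters."""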
        warmup_iters = self.train_conf.optimizer_conf.warmup_iters
        learning_rate = self.train_conf.optimizer_conf.learning_rate
        lr_decay_iters = self.train_conf.optimizer_conf.lr_decay_iters
        min_lr = self.train_conf.optimizer_conf.min_lr

        # Linear warmup for the first warmup_iters steps.
        if it < warmup_iters:
            return learning_rate * it / warmup_iters

        # Past the decay window, hold at the minimum learning rate.
        if it > lr_decay_iters:
            return min_lr

        # In between, cosine-decay from learning_rate down to min_lr.
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (learning_rate - min_lr)

    def train(self):
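        """Run the training loop: LR scheduling, periodic evaluation and checkpointing on the master process, and AMP gradient-accumulation steps."""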
        self._init_ddp_if_possible()
        self._init_train()

        train_batch_iter = self.iter_batches("train")
        X = next(train_batch_iter)
        t0 = time.time()
        local_iter_num = 0
        self.raw_model = self.model.module if self.ddp else self.model
        running_mfu = -1.0

        gradient_accumulation_steps = (
            self.train_conf.optimizer_conf.gradient_accumulation_steps
        )
        while True:
            # Set the learning rate for this iteration.
            lr = (
                self.get_lr(self.iter_num)
                if self.train_conf.optimizer_conf.decay_lr
                else self.train_conf.optimizer_conf.learning_rate
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr

            # Periodically evaluate and checkpoint on the master process.
            if (
                self.iter_num % self.train_conf.io_conf.eval_interval == 0
                and self.master_process
                and self.iter_num != 0
            ):
                logger.info(
                    f"Estimating loss for master_process({self.master_process}) on iter {self.iter_num}"
                )
                losses = self.estimate_loss()
                logger.info(
                    f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
                )
                log_dict = {
                    "iter": self.iter_num,
                    "tokens": self.iter_num * self.tokens_per_iter,
                    "loss/train": losses["train"],
                    "loss/val": losses["val"],
                    "lr": lr,
                    "mfu": running_mfu * 100,
                }
                logger.info(f"{log_dict}")

                if (
                    losses["val"] < self.best_val_loss
                    or self.train_conf.io_conf.always_save_checkpoint
                ):
                    self.best_val_loss = losses["val"]
                    if self.iter_num > 0:
                        logger.info(
                            f"saving checkpoint to {self.train_conf.io_conf.out_dir}"
                        )
                        self.raw_model.save(
                            os.path.join(self.train_conf.io_conf.out_dir, self.CKPT_PT)
                        )

                        torch.save(
                            {
                                "iter_num": self.iter_num,
                                "epoch": self.epoch,
                                "best_val_loss": self.best_val_loss,
                                "optimizer_state": self.optimizer.state_dict(),
                            },
                            os.path.join(
                                self.train_conf.io_conf.out_dir, self.SNAPSHOT_PT
                            ),
                        )

            if self.iter_num == 0 and self.train_conf.io_conf.eval_only:
                break
            # Forward/backward over gradient_accumulation_steps micro-batches,
            # scaling the loss so the accumulated gradients are averaged.
            for micro_step in range(gradient_accumulation_steps):
                if self.ddp:
                    # Only sync gradients on the last micro step.
                    self.model.require_backward_grad_sync = (
                        micro_step == gradient_accumulation_steps - 1
                    )
                with self.ctx:
                    context = X["context"]
                    fragment = X["fragment"]

                    # Randomly drop the fragment and individual context keys so
                    # the model also learns to generate without conditioning.
                    if np.random.random() < 0.15 or fragment is None:
                        fragment = None

                    current_context_keys = list(context.keys())
                    for k in current_context_keys:
                        if np.random.random() < 0.15:
                            del context[k]

                    logits = self.model(
                        X["src"], targets=X["tgt"], context=context, fragment=fragment
                    )
                    loss = self.raw_model.last_loss
                    loss = loss / gradient_accumulation_steps

                # Fetch the next batch before backward so data loading overlaps
                # with GPU work; restart the iterator at the end of an epoch.
                try:
                    X = next(train_batch_iter)
                except StopIteration:
                    logger.info(f"Done Epoch {self.epoch}")
                    train_batch_iter = self.iter_batches("train")
                    X = next(train_batch_iter)
                    self.epoch += 1

                self.scaler.scale(loss).backward()

            # Gradient clipping and the optimizer step, via the AMP grad scaler.
            if self.train_conf.optimizer_conf.grad_clip != 0.0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.train_conf.optimizer_conf.grad_clip
                )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            self.optimizer.zero_grad(set_to_none=True)

            t1 = time.time()
            dt = t1 - t0
            t0 = t1

            if (
                self.iter_num % self.train_conf.io_conf.log_interval == 0
                and self.master_process
            ):
                lossf = loss.item() * gradient_accumulation_steps
                if local_iter_num >= 5:
                    mfu = self.raw_model.estimate_mfu(
                        self.train_conf.loader_conf.batch_size
                        * gradient_accumulation_steps,
                        dt,
                    )
                    running_mfu = (
                        mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
                    )
                logger.info(
                    f"{self.iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
                )

            self.iter_num += 1
            local_iter_num += 1

            if self.iter_num > self.train_conf.optimizer_conf.max_iters:
                logger.info("Done with training iters!")
                break

        if self.ddp:
            destroy_process_group()


if __name__ == "__main__":
    pass
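    # Minimal usage sketch (hypothetical): the ModelArgs()/ContextArgs() defaults
    # and the processed dataset pickle below are assumptions; see model.py and
    # preprocess_dataset.py for the real fields and preprocessing step.
    #
    # logging.basicConfig(level=logging.INFO)
    # args = TrainerArgs(
    #     io_conf=IOConfig(out_dir="out"),
    #     loader_conf=LoaderConfig(
    #         batch_size=32,
    #         max_seq_len=256,
    #         processed_dataset_ckpt="processed_dataset_None.pkl",
    #     ),
    #     model_conf=ModelArgs(),
    #     context_conf=ContextArgs(),
    #     optimizer_conf=OptimizerConfig(max_iters=1000),
    #     run_name="debug_run",
    # )
    # Trainer(args, dtype="float16", compile=False).train()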