# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import os
from pathlib import Path
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from soundstorm.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high')
from soundstorm.utils import get_newest_ckpt

def main(args):
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)
    seed_everything(config["train"]["seed"], workers=True)
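    # Keys this script reads from the YAML config (a minimal sketch with example
    # values; the real conf/default.yaml also carries the model/data settings
    # consumed by the LightningModule and DataModule):
    #   train:
    #     seed: 0
    #     epochs: 100
    #     save_every_n_epoch: 1
    #     precision: 32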
    ckpt_callback: ModelCheckpoint = ModelCheckpoint(
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir)
    logger = WandbLogger(
        project="AR_S1",
        name=output_dir.stem,
        save_dir=output_dir,
        # resume=True keeps appending to the existing run's loss curves
        # when training is restarted
        resume=True,
        # id='k19kvsq8'
    )
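    # devices=-1 trains on all visible GPUs with DDP; numeric precision is
    # taken from the config.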
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(find_unused_parameters=True),
        precision=config["train"]["precision"],
        logger=logger,
        callbacks=[ckpt_callback])
    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)

    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_path=args.train_semantic_path,
        train_phoneme_path=args.train_phoneme_path,
        dev_semantic_path=args.dev_semantic_path,
        dev_phoneme_path=args.dev_phoneme_path)
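    # Auto-resume: if ckpt_dir already contains checkpoints, continue from the
    # newest one; otherwise ckpt_path stays None and training starts from scratch.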
    try:
        # match the numeric part of each checkpoint file name with a regular
        # expression and sort numerically to find the newest one
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)

    trainer.fit(model, data_module, ckpt_path=ckpt_path)

# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --config_file conf/default.yaml
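# Example local launch relying on the argparse defaults below (assumed layout):
#   python train.py --config_file conf/default.yaml --output_dir exp/default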

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    # args for dataset
    parser.add_argument(
        '--train_semantic_path',
        type=str,
        default='dump/train/semantic_token.tsv')
    parser.add_argument(
        '--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
    parser.add_argument(
        '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
    parser.add_argument(
        '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='exp/default',
        help='directory to save the results')

    args = parser.parse_args()
    logging.info(str(args))
    main(args)