Spaces:
Runtime error
Runtime error
File size: 5,436 Bytes
e79b770 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import os
from pathlib import Path
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module_librilight_6k import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from soundstorm.utils import get_newest_ckpt
from soundstorm.utils.io import load_yaml_config
# Raise the log threshold of chatty third-party loggers to WARNING.
for _logger_name in ('numba', 'matplotlib'):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
del _logger_name  # keep the module namespace free of the loop variable

# Select the 'high' float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision('high')
def main(args):
    """Assemble trainer, model and data module, then start training.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``output_dir``, ``config_file`` and the six dataset options
            ``train_semantic_dirs``, ``train_phoneme_dirs``,
            ``dev_semantic_dirs``, ``dev_phoneme_dirs``,
            ``train_non_speech_dirs`` and ``dev_non_speech_dirs``.
    """
    # Results go to <output_dir>, checkpoints to <output_dir>/ckpt.
    out_root = Path(args.output_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    checkpoints = out_root / 'ckpt'
    checkpoints.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)
    train_cfg = config["train"]
    # Seed before anything that may draw random numbers (model init, loaders).
    seed_everything(train_cfg["seed"], workers=True)

    # Step-based checkpointing into the ckpt directory; save_top_k=-1 is
    # Lightning's setting for "do not prune older checkpoints".
    periodic_ckpt = ModelCheckpoint(
        dirpath=checkpoints,
        every_n_train_steps=train_cfg["every_n_train_steps"],
        save_on_train_epoch_end=False,
        save_top_k=-1)

    wandb_logger = WandbLogger(
        project="AR_S1_LibriLight",
        # the run is named after the last component of the output directory
        name=out_root.stem,
        save_dir=out_root,
        # resume the loss curve
        resume=True,
        # id='k19kvsq8'
    )

    # devices=-1 together with the DDP strategy: multi-GPU data parallel run.
    trainer = Trainer(
        accelerator='gpu',
        devices=-1,
        strategy=DDPStrategy(find_unused_parameters=True),
        precision=train_cfg["precision"],
        max_epochs=train_cfg["epochs"],
        benchmark=False,
        fast_dev_run=False,
        logger=wandb_logger,
        callbacks=[periodic_ckpt])

    model = Text2SemanticLightningModule(config, out_root)

    dataset_dirs = dict(
        train_semantic_dirs=args.train_semantic_dirs,
        train_phoneme_dirs=args.train_phoneme_dirs,
        dev_semantic_dirs=args.dev_semantic_dirs,
        dev_phoneme_dirs=args.dev_phoneme_dirs,
        train_non_speech_dirs=args.train_non_speech_dirs,
        dev_non_speech_dirs=args.dev_non_speech_dirs)
    data_module = Text2SemanticDataModule(config, **dataset_dirs)

    # Best effort: pick the newest checkpoint in the ckpt directory. Any
    # failure during the lookup falls back to ckpt_path=None for trainer.fit.
    try:
        resume_from = checkpoints / get_newest_ckpt(os.listdir(checkpoints))
    except Exception:
        resume_from = None
    print("ckpt_path:", resume_from)

    trainer.fit(model, data_module, ckpt_path=resume_from)
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --config_file conf/default.yaml
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    The ``*_dirs`` options accept one or more paths. They are declared with
    ``type=str`` and ``nargs='+'`` so argparse yields a list of path strings
    directly (no per-character splitting followed by a manual re-join).

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config_file`` and ``output_dir`` as strings,
        the four semantic/phoneme ``*_dirs`` options as lists of strings, and
        the two ``*_non_speech_dirs`` options as lists of strings or ``None``
        when they are not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    # args for dataset
    parser.add_argument(
        '--train_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train semantic')
    parser.add_argument(
        '--train_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train phoneme')
    parser.add_argument(
        '--dev_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev semantic')
    parser.add_argument(
        '--dev_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev phoneme')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='exp/default',
        help='directory to save the results')
    # The non-speech inputs are optional; they stay None unless provided.
    parser.add_argument(
        '--train_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of train non_speech data')
    parser.add_argument(
        '--dev_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of dev non_speech data')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    # NOTE(review): the root logger defaults to WARNING, so this INFO record
    # is only emitted if logging is configured elsewhere - confirm.
    logging.info(str(args))
    main(args)
|