Spaces:
Runtime error
Runtime error
File size: 1,050 Bytes
e6fd727 557fb53 0030bc6 e6fd727 557fb53 e6fd727 557fb53 e6fd727 3b31903 557fb53 c914273 3b31903 0030bc6 c914273 e6fd727 c914273 3b31903 e6fd727 557fb53 3b31903 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from typing import Callable
import importlib
import yaml
from argparse import ArgumentParser
import os
# Final path component of the directory containing this file. This is a bare
# directory name such as "project", not a full path, and it is "" when
# __file__ has no directory part.
_this_dir = os.path.dirname(__file__)
ROOT_DIR = os.path.split(_this_dir)[1]
del _this_dir
def get_training_fn(id: str, prefix: str = "models") -> Callable:
    """Resolve a dotted ``"<module>.<function>"`` id to a callable.

    Everything before the last ``.`` of ``id`` is the module path. It is
    imported as ``"<prefix>.<module path>"``. The part after the last ``.``
    is then looked up on that module with ``getattr``.

    Args:
        id: Dotted identifier such as ``"some_module.train"``. The name
            shadows the builtin ``id`` but is kept so keyword callers keep
            working.
        prefix: Package the module path is resolved under. The default,
            ``"models"``, matches the previously hard-coded package. Pass
            ``""`` to import the module path unchanged.

    Returns:
        The callable named by the last component of ``id``.

    Raises:
        ValueError: if ``id`` contains no ``.`` or has an empty module or
            function part.
        ModuleNotFoundError: if the module cannot be imported.
        AttributeError: if the module has no attribute with that name.
        TypeError: if the resolved attribute is not callable.
    """
    module_name, sep, fn_name = id.rpartition(".")
    if not sep or not module_name or not fn_name:
        raise ValueError(
            f"Expected a dotted '<module>.<function>' id, got {id!r}"
        )

    full_module_name = f"{prefix}.{module_name}" if prefix else module_name
    # The name is absolute. importlib's ``package`` argument only anchors
    # relative names (those starting with '.'), so it is not passed.
    module = importlib.import_module(full_module_name)

    try:
        fn = getattr(module, fn_name)
    except AttributeError as err:
        raise AttributeError(
            f"Module {full_module_name!r} has no attribute {fn_name!r}"
        ) from err

    # Callers invoke the result directly, so fail early with a clear message.
    if not callable(fn):
        raise TypeError(
            f"{id!r} resolved to a non-callable {type(fn).__name__}"
        )
    return fn
def get_config(filepath: str) -> dict:
    """Load a YAML training configuration file.

    Args:
        filepath: Path to the YAML file.

    Returns:
        The top-level mapping parsed from the file.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the top level of the document is not a mapping. This
            includes an empty file, which ``yaml.safe_load`` parses to
            ``None``.
    """
    # Pin the encoding so the result does not depend on the platform locale.
    with open(filepath, "r", encoding="utf-8") as f:
        # safe_load only constructs standard YAML types, never arbitrary
        # Python objects.
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {filepath!r} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )
    return config
def main(argv=None) -> None:
    """Parse CLI arguments, load the config, and run its training function.

    Args:
        argv: Argument list to parse (excluding the program name). If None,
            ``sys.argv[1:]`` is used, as ``ArgumentParser.parse_args`` does.

    The config's ``training_fn`` entry names the trainer as
    ``"<module>.<function>"``. That callable is resolved and then called with
    the whole config dict.
    """
    # Resolve the default config next to this file rather than against the
    # current working directory. The `models` package import already resolves
    # relative to the script's directory (sys.path[0] when run as a script),
    # so the two now agree when the script is launched from elsewhere.
    default_config = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "models",
        "config",
        "train_local.yaml",
    )

    parser = ArgumentParser(
        description="Trains models on the dance dataset and saves weights."
    )
    parser.add_argument(
        "--config",
        help="Path to the yaml file that defines the training configuration.",
        default=default_config,
    )
    args = parser.parse_args(argv)

    config = get_config(args.config)
    # Give a usage error instead of a bare KeyError/TypeError when the file
    # is not a mapping or lacks the entry naming the training function.
    if not isinstance(config, dict) or "training_fn" not in config:
        parser.error(f"{args.config} does not define a 'training_fn' entry")

    training_fn_path = config["training_fn"]
    print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
    train = get_training_fn(training_fn_path)
    train(config)


if __name__ == "__main__":
    main()
|