# Author: Bingxin Ke
# Last modified: 2024-05-17

from .marigold_trainer import MarigoldTrainer
from .marigold_xl_trainer import MarigoldXLTrainer
from .marigold_inpaint_trainer import MarigoldInpaintTrainer

# Registry mapping trainer names to trainer classes; looked up by
# ``get_trainer_cls``. Register new trainers here.
trainer_cls_name_dict = {
    "MarigoldTrainer": MarigoldTrainer,
    "MarigoldXLTrainer": MarigoldXLTrainer,
    "MarigoldInpaintTrainer": MarigoldInpaintTrainer,
}


def get_trainer_cls(trainer_name):
    """Return the trainer class registered under ``trainer_name``.

    Args:
        trainer_name: A key of ``trainer_cls_name_dict``, e.g.
            ``"MarigoldTrainer"``.

    Returns:
        The registered trainer class itself (not an instance).

    Raises:
        KeyError: If ``trainer_name`` is not registered. The message names
            the requested key and lists the registered trainer names.
    """
    try:
        return trainer_cls_name_dict[trainer_name]
    except KeyError:
        # Same exception type as a plain dict lookup (callers catching
        # KeyError keep working), but with an actionable message.
        raise KeyError(
            f"Unknown trainer {trainer_name!r}; "
            f"available trainers: {sorted(trainer_cls_name_dict)}"
        ) from None