"""Factory for the HumanML3D ('t2m') and KIT ('kit') datasets."""
from os.path import join as pjoin

from .t2m_dataset import HumanML3D, KIT

__all__ = [
    'HumanML3D', 'KIT', 'get_dataset',
]

# Registry mapping each accepted ``opt.dataset_name`` value to the dataset
# class that ``get_dataset`` constructs for it. Add new datasets here.
_DATASETS = {
    't2m': HumanML3D,
    'kit': KIT,
}


def get_dataset(opt, split='train', mode='train', accelerator=None):
    """Build the dataset selected by ``opt.dataset_name``.

    Args:
        opt: Options object; only ``opt.dataset_name`` is read here, and the
            whole object is forwarded to the dataset constructor.
        split: Dataset split name forwarded to the constructor.
        mode: Mode string forwarded to the constructor.
        accelerator: Optional accelerator object. It is forwarded to the
            constructor, and when truthy its ``print`` method is used for the
            completion message instead of the builtin ``print``.

    Returns:
        An instance of ``HumanML3D`` (for ``'t2m'``) or ``KIT`` (for ``'kit'``).

    Raises:
        KeyError: If ``opt.dataset_name`` is not a registered dataset name.
    """
    name = opt.dataset_name
    try:
        dataset_cls = _DATASETS[name]
    except KeyError:
        # Report which name was requested and what is available; the original
        # exception only repeats the key, so suppress it with ``from None``.
        raise KeyError(
            f'Dataset Does Not Exist: {name!r} '
            f'(expected one of {sorted(_DATASETS)})'
        ) from None

    dataset = dataset_cls(opt, split, mode, accelerator)

    # Prefer the accelerator's print (when one is supplied) over the builtin.
    message = 'Completing loading %s dataset' % name
    if accelerator:
        accelerator.print(message)
    else:
        print(message)
    return dataset