diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
class VQGANTextureAwareSpatialHierarchyInferenceModel():
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
         self.top_encoder = Encoder(
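
The same constructor change recurs in every file below: the hardcoded CUDA device is replaced by whatever the options dict specifies. As a rough illustration (not code from this repository), the caller could fill in opt['device'] with a CPU fallback along these lines; opt here is a plain dict standing in for the real options object:

    import torch

    # Assumed caller-side setup: prefer CUDA when available, otherwise fall back to CPU.
    opt = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
    device = torch.device(opt['device'])  # mirrors what each patched __init__ now does
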
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
class HierarchyVQSpatialTextureAwareModel():
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
class ParsingGenModel():
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
         self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
class BaseSampleModel():
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         # hierarchical VQVAE
         self.decoder = Decoder(
class BaseSampleModel():
     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)
class BaseSampleModel():
         self.top_post_quant_conv.eval()
     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
class BaseSampleModel():
     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(
class BaseSampleModel():
         self.segm_quant_conv.eval()
     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(
class BaseSampleModel():
         self.index_pred_decoder.eval()
     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()
class SampleFromPoseModel(BaseSampleModel):
                 [185, 210, 205], [130, 165, 180], [225, 141, 151]]
     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)
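
Each load_* method above now passes map_location so that checkpoints written on a GPU machine can be deserialized onto whichever device the model was constructed with, including CPU-only hosts. A minimal sketch of the pattern, with a placeholder path and key rather than the option names used in this repository:

    import torch

    device = torch.device('cpu')  # in the models above this is self.device
    # map_location remaps CUDA-saved tensors onto the target device while loading.
    checkpoint = torch.load('placeholder_checkpoint.pth', map_location=device)
    # the state dict is then restored as in the methods above, e.g.:
    # decoder.load_state_dict(checkpoint['decoder'], strict=True)
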
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
class TransformerTextureAwareModel():
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
         # VQVAE for image
class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()
-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))
         texture_mask_flatten = self.texture_tokens.view(-1)
class TransformerTextureAwareModel():
         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))
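
Inside sample_fn the same idea is applied to the sampling loop: tensors that used to be allocated on the literal 'cuda' string are now allocated on self.device. A standalone sketch of that allocation pattern, with b, shape and mask_id as made-up stand-ins for the model's attributes:

    import numpy as np
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    b, shape, mask_id = 2, (16, 16), 0  # illustrative values only

    # every sampling tensor lives on the configured device, never on a hardcoded 'cuda'
    x_t = torch.ones((b, np.prod(shape)), device=device).long() * mask_id
    unmasked = torch.zeros_like(x_t).bool()
    t = torch.full((b, ), 5, device=device, dtype=torch.long)
    changes = torch.rand(x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
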
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
class VQImageSegmTextureModel(VQImageModel):
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],