liuhaotian committed on
Commit
5c79044
1 Parent(s): 087de09

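Fix OOM issue: load the shared autoencoder, text encoder, and diffusion components once (load_common_ckpt) and reuse them across the three GLIGEN models instead of instantiating a separate copy per model.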

Browse files
Files changed (2) hide show
  1. app.py +29 -24
  2. gligen/task_grounded_generation.py +10 -6
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  import argparse
4
  from omegaconf import OmegaConf
5
- from gligen.task_grounded_generation import grounded_generation_box, load_ckpt
6
 
7
  import json
8
  import numpy as np
@@ -34,41 +34,46 @@ def parse_option():
34
  args = parse_option()
35
 
36
 
37
- def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin'):
38
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
39
  return torch.load(cache_file, map_location='cpu')
40
 
41
  def load_ckpt_config_from_hf(modality):
42
- ckpt = load_from_hf(f'gligen/{modality}')
43
- config = load_from_hf('gligen/demo_config_legacy', filename=f'{modality}.pth')
44
  return ckpt, config
45
 
46
 
47
- if args.load_text_box_generation:
48
- pretrained_ckpt_gligen, config = load_ckpt_config_from_hf('gligen-generation-text-box')
49
  config = OmegaConf.create( config["_content"] ) # config used in training
50
  config.update( vars(args) )
51
- config.model['params']['is_inpaint'] = False
52
- config.model['params']['is_style'] = False
53
- loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen)
54
 
 
 
 
55
 
56
- if args.load_text_box_inpainting:
57
- pretrained_ckpt_gligen_inpaint, config = load_ckpt_config_from_hf('gligen-inpainting-text-box')
58
- config = OmegaConf.create( config["_content"] ) # config used in training
59
- config.update( vars(args) )
60
- config.model['params']['is_inpaint'] = True
61
- config.model['params']['is_style'] = False
62
- loaded_model_list_inpaint = load_ckpt(config, pretrained_ckpt_gligen_inpaint)
63
 
 
64
 
65
- if args.load_text_image_box_generation:
66
- pretrained_ckpt_gligen_style, config = load_ckpt_config_from_hf('gligen-generation-text-image-box')
67
- config = OmegaConf.create( config["_content"] ) # config used in training
68
- config.update( vars(args) )
69
- config.model['params']['is_inpaint'] = False
70
- config.model['params']['is_style'] = True
71
- loaded_model_list_style = load_ckpt(config, pretrained_ckpt_gligen_style)
 
 
 
 
 
 
 
 
72
 
73
 
74
  def load_clip_model():
 
2
  import torch
3
  import argparse
4
  from omegaconf import OmegaConf
5
+ from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
6
 
7
  import json
8
  import numpy as np
 
34
  args = parse_option()
35
 
36
 
37
+ def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
38
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
39
  return torch.load(cache_file, map_location='cpu')
40
 
41
  def load_ckpt_config_from_hf(modality):
42
+ ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
43
+ config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
44
  return ckpt, config
45
 
46
 
47
+ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
48
+ pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
49
  config = OmegaConf.create( config["_content"] ) # config used in training
50
  config.update( vars(args) )
51
+ config.model['params']['is_inpaint'] = is_inpaint
52
+ config.model['params']['is_style'] = is_style
 
53
 
54
+ if common_instances is None:
55
+ common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
56
+ common_instances = load_common_ckpt(config, common_ckpt)
57
 
58
+ loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
 
 
 
 
 
 
59
 
60
+ return loaded_model_list, common_instances
61
 
62
+
63
+ loaded_model_list, common_instances = ckpt_load_helper(
64
+ 'gligen-generation-text-box',
65
+ is_inpaint=False, is_style=False, common_instances=None
66
+ )
67
+
68
+ loaded_model_list_inpaint = ckpt_load_helper(
69
+ 'gligen-inpainting-text-box',
70
+ is_inpaint=True, is_style=False, common_instances=common_instances
71
+ )[0]
72
+
73
+ loaded_model_list_style = ckpt_load_helper(
74
+ 'gligen-generation-text-image-box',
75
+ is_inpaint=False, is_style=True, common_instances=common_instances
76
+ )[0]
77
 
78
 
79
  def load_clip_model():
gligen/task_grounded_generation.py CHANGED
@@ -65,21 +65,25 @@ def draw_box(img, locations):
65
  draw.rectangle([box[0]*WW, box[1]*HH, box[2]*WW, box[3]*HH], outline =colors[bid % len(colors)], width=5)
66
  return img
67
 
68
- def load_ckpt(config, state_dict):
69
- model = instantiate_from_config(config.model).to(device).eval()
70
  autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
71
  text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
72
  diffusion = instantiate_from_config(config.diffusion).to(device)
73
 
74
- autoencoder.load_state_dict( state_dict["autoencoder"] )
75
- text_encoder.load_state_dict( state_dict["text_encoder"] )
76
- diffusion.load_state_dict( state_dict["diffusion"] )
 
 
 
 
 
77
 
78
  model.load_state_dict(state_dict['model'])
79
  set_alpha_scale(model, config.alpha_scale)
80
  print("ckpt is loaded")
81
 
82
- return model, autoencoder, text_encoder, diffusion
83
 
84
 
85
 
 
65
  draw.rectangle([box[0]*WW, box[1]*HH, box[2]*WW, box[3]*HH], outline =colors[bid % len(colors)], width=5)
66
  return img
67
 
68
+ def load_common_ckpt(config, common_ckpt):
 
69
  autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
70
  text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
71
  diffusion = instantiate_from_config(config.diffusion).to(device)
72
 
73
+ autoencoder.load_state_dict( common_ckpt["autoencoder"] )
74
+ text_encoder.load_state_dict( common_ckpt["text_encoder"] )
75
+ diffusion.load_state_dict( common_ckpt["diffusion"] )
76
+
77
+ return [autoencoder, text_encoder, diffusion]
78
+
79
+ def load_ckpt(config, state_dict, common_instances):
80
+ model = instantiate_from_config(config.model).to(device).eval()
81
 
82
  model.load_state_dict(state_dict['model'])
83
  set_alpha_scale(model, config.alpha_scale)
84
  print("ckpt is loaded")
85
 
86
+ return [model] + common_instances
87
 
88
 
89