lnyan commited on
Commit
113ecf3
1 Parent(s): e181bbd

Add stage2 support

Browse files
Files changed (2) hide show
  1. app.py +9 -8
  2. model.py +52 -49
app.py CHANGED
@@ -8,7 +8,8 @@ from model import AppModel
8
 
9
  DESCRIPTION = '''# <a href="https://github.com/THUDM/CogVideo">CogVideo</a>
10
 
11
- Currently, this Space only supports the first stage of the CogVideo pipeline due to hardware limitations.
 
12
 
13
  The model accepts only Chinese as input.
14
  By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) will be used as input.
@@ -19,7 +20,7 @@ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge
19
 
20
 
21
  def main():
22
- only_first_stage = True
23
  model = AppModel(only_first_stage)
24
 
25
  with gr.Blocks(css='style.css') as demo:
@@ -36,9 +37,9 @@ def main():
36
  step=1,
37
  value=1234,
38
  label='Seed')
39
- only_first_stage = gr.Checkbox(
40
  label='Only First Stage',
41
- value=only_first_stage,
42
  visible=not only_first_stage)
43
  image_prompt = gr.Image(type="filepath",
44
  label="Image Prompt",
@@ -53,10 +54,10 @@ def main():
53
  result_video = gr.Video(show_label=False)
54
 
55
  examples = gr.Examples(
56
- examples=[['骑滑板的皮卡丘', False, 1234, True,None],
57
- ['a cat playing chess', True, 1253, True,None]],
58
  fn=model.run_with_translation,
59
- inputs=[text, translate, seed, only_first_stage,image_prompt],
60
  outputs=[translated_text, result_video],
61
  cache_examples=True)
62
 
@@ -68,7 +69,7 @@ def main():
68
  text,
69
  translate,
70
  seed,
71
- only_first_stage,
72
  image_prompt
73
  ],
74
  outputs=[translated_text, result_video])
 
8
 
9
  DESCRIPTION = '''# <a href="https://github.com/THUDM/CogVideo">CogVideo</a>
10
 
11
+ This Space supports the first stage and the second stage (better quality) of the CogVideo pipeline.
12
+ Note that the second stage of CogVideo is **much slower**.
13
 
14
  The model accepts only Chinese as input.
15
  By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) will be used as input.
 
20
 
21
 
22
  def main():
23
+ only_first_stage = False
24
  model = AppModel(only_first_stage)
25
 
26
  with gr.Blocks(css='style.css') as demo:
 
37
  step=1,
38
  value=1234,
39
  label='Seed')
40
+ run_only_first_stage = gr.Checkbox(
41
  label='Only First Stage',
42
+ value=True,
43
  visible=not only_first_stage)
44
  image_prompt = gr.Image(type="filepath",
45
  label="Image Prompt",
 
54
  result_video = gr.Video(show_label=False)
55
 
56
  examples = gr.Examples(
57
+ examples=[['骑滑板的皮卡丘', False, 1234, True, None],
58
+ ['a cat playing chess', True, 1253, False, None]],
59
  fn=model.run_with_translation,
60
+ inputs=[text, translate, seed, run_only_first_stage, image_prompt],
61
  outputs=[translated_text, result_video],
62
  cache_examples=True)
63
 
 
69
  text,
70
  translate,
71
  seed,
72
+ run_only_first_stage,
73
  image_prompt
74
  ],
75
  outputs=[translated_text, result_video])
model.py CHANGED
@@ -62,8 +62,8 @@ if os.getenv('SYSTEM') == 'spaces':
62
 
63
  download_and_extract_icetk_models()
64
  download_and_extract_cogvideo_models('cogvideo-stage1.zip')
65
- #download_and_extract_cogvideo_models('cogvideo-stage2.zip')
66
- #download_and_extract_cogview2_models('cogview2-dsr.zip')
67
 
68
  os.environ['SAT_HOME'] = '/home/user/app/pretrained'
69
 
@@ -299,7 +299,8 @@ def my_filling_sequence(
299
  # initialize generation
300
  counter = context_length - 1 # Last fixed index is ``counter''
301
  index = 0 # Next forward starting index, also the length of cache.
302
- mems_buffers_on_GPU = False
 
303
  mems_indexs = [0, 0]
304
  mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
305
  5 * 400 + 74]
@@ -308,7 +309,8 @@ def my_filling_sequence(
308
  batch_size,
309
  mem_len,
310
  args.hidden_size * 2,
311
- dtype=next(model.parameters()).dtype)
 
312
  for mem_len in mems_len
313
  ]
314
 
@@ -320,13 +322,13 @@ def my_filling_sequence(
320
  batch_size,
321
  mem_len,
322
  args.hidden_size * 2,
323
- dtype=next(model.parameters()).dtype)
 
324
  for mem_len in mems_len
325
  ]
326
  guider_mems_indexs = [0, 0]
327
  guider_mems = None
328
 
329
- torch.cuda.empty_cache()
330
  # step-by-step generation
331
  while counter < len(seq[0]) - 1:
332
  # we have generated counter+1 tokens
@@ -448,34 +450,34 @@ def my_filling_sequence(
448
  ]
449
  guider_logits = guider_logits_all
450
  else:
451
- if not mems_buffers_on_GPU:
452
- if not mode_stage1:
453
- torch.cuda.empty_cache()
454
- for idx, mem in enumerate(mems):
455
- mems[idx] = mem.to(next(model.parameters()).device)
456
- if guider_seq is not None:
457
- for idx, mem in enumerate(guider_mems):
458
- guider_mems[idx] = mem.to(
459
- next(model.parameters()).device)
460
- else:
461
- torch.cuda.empty_cache()
462
- for idx, mem_buffer in enumerate(mems_buffers):
463
- mems_buffers[idx] = mem_buffer.to(
464
- next(model.parameters()).device)
465
- mems = [
466
- mems_buffers[id][:, :, :mems_indexs[id]]
467
- for id in range(2)
468
- ]
469
- if guider_seq is not None:
470
- for idx, guider_mem_buffer in enumerate(
471
- guider_mems_buffers):
472
- guider_mems_buffers[idx] = guider_mem_buffer.to(
473
- next(model.parameters()).device)
474
- guider_mems = [
475
- guider_mems_buffers[id]
476
- [:, :, :guider_mems_indexs[id]] for id in range(2)
477
- ]
478
- mems_buffers_on_GPU = True
479
 
480
  logits, *output_per_layers = model(
481
  input_tokens[:, index:],
@@ -513,17 +515,17 @@ def my_filling_sequence(
513
  o['mem_kv'][0] for o in guider_output_per_layers
514
  ], [o['mem_kv'][1] for o in guider_output_per_layers]
515
 
516
- if not mems_buffers_on_GPU:
517
- torch.cuda.empty_cache()
518
- for idx, mem_buffer in enumerate(mems_buffers):
519
- mems_buffers[idx] = mem_buffer.to(
520
- next(model.parameters()).device)
521
- if guider_seq is not None:
522
- for idx, guider_mem_buffer in enumerate(
523
- guider_mems_buffers):
524
- guider_mems_buffers[idx] = guider_mem_buffer.to(
525
- next(model.parameters()).device)
526
- mems_buffers_on_GPU = True
527
 
528
  mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
529
  mems_buffers, mems_indexs,
@@ -677,7 +679,7 @@ def get_default_args() -> argparse.Namespace:
677
  '--batch-size',
678
  '1',
679
  '--max-inference-batch-size',
680
- '8',
681
  ]
682
  args = get_args(args_list)
683
  args = argparse.Namespace(**vars(args), **vars(known))
@@ -779,7 +781,7 @@ class Model:
779
  path = auto_create('cogview2-dsr', path=None)
780
  dsr = DirectSuperResolution(self.args,
781
  path,
782
- max_bz=12,
783
  onCUDA=False)
784
  else:
785
  dsr = None
@@ -1184,7 +1186,8 @@ class Model:
1184
  else:
1185
  self.args.stage_1 = False
1186
  self.args.both_stages = True
1187
-
 
1188
  parent_given_tokens, res = self.process_stage1(
1189
  self.model_stage1,
1190
  text,
@@ -1231,7 +1234,7 @@ class AppModel(Model):
1231
 
1232
  def run_with_translation(
1233
  self, text: str, translate: bool, seed: int,
1234
- only_first_stage: bool,image_prompt: None) -> tuple[str | None, str | None]:
1235
 
1236
  logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=},{image_prompt=}')
1237
  if translate:
 
62
 
63
  download_and_extract_icetk_models()
64
  download_and_extract_cogvideo_models('cogvideo-stage1.zip')
65
+ download_and_extract_cogvideo_models('cogvideo-stage2.zip')
66
+ download_and_extract_cogview2_models('cogview2-dsr.zip')
67
 
68
  os.environ['SAT_HOME'] = '/home/user/app/pretrained'
69
 
 
299
  # initialize generation
300
  counter = context_length - 1 # Last fixed index is ``counter''
301
  index = 0 # Next forward starting index, also the length of cache.
302
+ # mems_buffers_on_GPU = False
303
+ torch.cuda.empty_cache()
304
  mems_indexs = [0, 0]
305
  mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
306
  5 * 400 + 74]
 
309
  batch_size,
310
  mem_len,
311
  args.hidden_size * 2,
312
+ dtype=next(model.parameters()).dtype,
313
+ device=next(model.parameters()).device)
314
  for mem_len in mems_len
315
  ]
316
 
 
322
  batch_size,
323
  mem_len,
324
  args.hidden_size * 2,
325
+ dtype=next(model.parameters()).dtype,
326
+ device=next(model.parameters()).device)
327
  for mem_len in mems_len
328
  ]
329
  guider_mems_indexs = [0, 0]
330
  guider_mems = None
331
 
 
332
  # step-by-step generation
333
  while counter < len(seq[0]) - 1:
334
  # we have generated counter+1 tokens
 
450
  ]
451
  guider_logits = guider_logits_all
452
  else:
453
+ # if not mems_buffers_on_GPU:
454
+ # if not mode_stage1:
455
+ # torch.cuda.empty_cache()
456
+ # for idx, mem in enumerate(mems):
457
+ # mems[idx] = mem.to(next(model.parameters()).device)
458
+ # if guider_seq is not None:
459
+ # for idx, mem in enumerate(guider_mems):
460
+ # guider_mems[idx] = mem.to(
461
+ # next(model.parameters()).device)
462
+ # else:
463
+ # torch.cuda.empty_cache()
464
+ # for idx, mem_buffer in enumerate(mems_buffers):
465
+ # mems_buffers[idx] = mem_buffer.to(
466
+ # next(model.parameters()).device)
467
+ # mems = [
468
+ # mems_buffers[id][:, :, :mems_indexs[id]]
469
+ # for id in range(2)
470
+ # ]
471
+ # if guider_seq is not None:
472
+ # for idx, guider_mem_buffer in enumerate(
473
+ # guider_mems_buffers):
474
+ # guider_mems_buffers[idx] = guider_mem_buffer.to(
475
+ # next(model.parameters()).device)
476
+ # guider_mems = [
477
+ # guider_mems_buffers[id]
478
+ # [:, :, :guider_mems_indexs[id]] for id in range(2)
479
+ # ]
480
+ # mems_buffers_on_GPU = True
481
 
482
  logits, *output_per_layers = model(
483
  input_tokens[:, index:],
 
515
  o['mem_kv'][0] for o in guider_output_per_layers
516
  ], [o['mem_kv'][1] for o in guider_output_per_layers]
517
 
518
+ # if not mems_buffers_on_GPU:
519
+ # torch.cuda.empty_cache()
520
+ # for idx, mem_buffer in enumerate(mems_buffers):
521
+ # mems_buffers[idx] = mem_buffer.to(
522
+ # next(model.parameters()).device)
523
+ # if guider_seq is not None:
524
+ # for idx, guider_mem_buffer in enumerate(
525
+ # guider_mems_buffers):
526
+ # guider_mems_buffers[idx] = guider_mem_buffer.to(
527
+ # next(model.parameters()).device)
528
+ # mems_buffers_on_GPU = True
529
 
530
  mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
531
  mems_buffers, mems_indexs,
 
679
  '--batch-size',
680
  '1',
681
  '--max-inference-batch-size',
682
+ '1',
683
  ]
684
  args = get_args(args_list)
685
  args = argparse.Namespace(**vars(args), **vars(known))
 
781
  path = auto_create('cogview2-dsr', path=None)
782
  dsr = DirectSuperResolution(self.args,
783
  path,
784
+ max_bz=4,
785
  onCUDA=False)
786
  else:
787
  dsr = None
 
1186
  else:
1187
  self.args.stage_1 = False
1188
  self.args.both_stages = True
1189
+
1190
+ torch.cuda.empty_cache()
1191
  parent_given_tokens, res = self.process_stage1(
1192
  self.model_stage1,
1193
  text,
 
1234
 
1235
  def run_with_translation(
1236
  self, text: str, translate: bool, seed: int,
1237
+ only_first_stage: bool, image_prompt: None) -> tuple[str | None, str | None]:
1238
 
1239
  logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=},{image_prompt=}')
1240
  if translate: