myniu commited on
Commit
903bfce
1 Parent(s): 15183dc
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -237,7 +237,7 @@ class Drag:
237
  frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
238
  sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
239
  mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
240
- cmp_flow = cmp.run(frames, sparse_optical_flow, mask) # [b*13, 2, 256, 256]
241
 
242
  if brush_mask is not None:
243
  brush_mask = torch.from_numpy(brush_mask) / 255.
@@ -323,7 +323,7 @@ class Drag:
323
 
324
  controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
325
 
326
- val_output = pipeline(
327
  input_first_frame_pil,
328
  input_first_frame_pil,
329
  controlnet_flow,
 
237
  frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
238
  sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
239
  mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
240
+ cmp_flow = self.cmp.run(frames, sparse_optical_flow, mask) # [b*13, 2, 256, 256]
241
 
242
  if brush_mask is not None:
243
  brush_mask = torch.from_numpy(brush_mask) / 255.
 
323
 
324
  controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
325
 
326
+ val_output = self.pipeline(
327
  input_first_frame_pil,
328
  input_first_frame_pil,
329
  controlnet_flow,