supersolar commited on
Commit
7c9838f
1 Parent(s): 42c15b4

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +10 -10
infer.py CHANGED
@@ -16,6 +16,7 @@ from utils.seed_all import seed_all
16
 
17
  from contextlib import nullcontext
18
  import cv2
 
19
 
20
  check_min_version('0.28.0.dev0')
21
 
@@ -53,14 +54,14 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
53
  ).images[0]
54
 
55
  # Post-process the prediction
56
- # 在 infer_pipe 函数中
57
  if task_name == 'depth':
58
  output_npy = pred.mean(axis=-1)
59
- # 修改为输出灰度图
60
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
61
  else:
62
  output_npy = pred
63
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
64
  return output_color
65
 
66
  def lotus_video(input_video, task_name, seed, device):
@@ -98,7 +99,7 @@ def lotus_video(input_video, task_name, seed, device):
98
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
99
 
100
  output_g = []
101
- for frame in frames:
102
  if torch.backends.mps.is_available():
103
  autocast_ctx = nullcontext()
104
  else:
@@ -122,16 +123,15 @@ def lotus_video(input_video, task_name, seed, device):
122
  task_emb=task_emb,
123
  ).images[0]
124
  # Post-process the prediction
125
- # 在 lotus_video 函数中
126
  if task_name == 'depth':
127
- output_npy_g = pred_g.mean(axis=-1)
128
  # 修改为输出灰度图
129
- output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8), mode='L')
130
  else:
131
  output_npy_g = pred_g
132
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
133
- output_g.append(output_color_g)
134
-
135
 
136
  return output_g
137
 
@@ -307,14 +307,14 @@ def main():
307
 
308
  # Post-process the prediction
309
  save_file_name = os.path.basename(test_images[i])[:-4]
310
- # infer_pipe 函数中
311
- if task_name == 'depth':
312
  output_npy = pred.mean(axis=-1)
313
  # 修改为输出灰度图
314
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
315
  else:
316
  output_npy = pred
317
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
318
  output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
319
  np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
320
 
 
16
 
17
  from contextlib import nullcontext
18
  import cv2
19
+ from tqdm import tqdm # 添加这一行以导入 tqdm
20
 
21
  check_min_version('0.28.0.dev0')
22
 
 
54
  ).images[0]
55
 
56
  # Post-process the prediction
 
57
  if task_name == 'depth':
58
  output_npy = pred.mean(axis=-1)
59
+ # 修改为输出灰度图
60
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
61
  else:
62
  output_npy = pred
63
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
64
+
65
  return output_color
66
 
67
  def lotus_video(input_video, task_name, seed, device):
 
99
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
100
 
101
  output_g = []
102
+ for frame in tqdm(frames, desc="Processing frames"): # 使用 tqdm 包裹 frames 列表
103
  if torch.backends.mps.is_available():
104
  autocast_ctx = nullcontext()
105
  else:
 
123
  task_emb=task_emb,
124
  ).images[0]
125
  # Post-process the prediction
 
126
  if task_name == 'depth':
127
+ output_npy = pred.mean(axis=-1)
128
  # 修改为输出灰度图
129
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
130
  else:
131
  output_npy_g = pred_g
132
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
133
+
134
+ output_g.append(output_color_g)
135
 
136
  return output_g
137
 
 
307
 
308
  # Post-process the prediction
309
  save_file_name = os.path.basename(test_images[i])[:-4]
310
+ if args.task_name == 'depth':
 
311
  output_npy = pred.mean(axis=-1)
312
  # 修改为输出灰度图
313
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
314
  else:
315
  output_npy = pred
316
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
317
+
318
  output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
319
  np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
320