Jiatao Gu committed
Commit 94ada0b
Parent(s): dfe4abb

add code from the original repo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +25 -0
  2. README.md +66 -28
  3. app.py +213 -0
  4. conf/config.yaml +49 -0
  5. conf/hydra/local.yaml +16 -0
  6. conf/model/default.yaml +35 -0
  7. conf/model/stylenerf_afhq.yaml +108 -0
  8. conf/model/stylenerf_cars.yaml +108 -0
  9. conf/model/stylenerf_cars_debug.yaml +105 -0
  10. conf/model/stylenerf_ffhq.yaml +108 -0
  11. conf/model/stylenerf_ffhq_ae.yaml +118 -0
  12. conf/model/stylenerf_ffhq_ae_basic.yaml +110 -0
  13. conf/model/stylenerf_ffhq_debug.yaml +103 -0
  14. conf/model/stylenerf_ffhq_eg3d.yaml +100 -0
  15. conf/model/stylenerf_ffhq_warped_depth.yaml +97 -0
  16. conf/spec/cifar.yaml +13 -0
  17. conf/spec/nerf32.yaml +14 -0
  18. conf/spec/paper1024.yaml +14 -0
  19. conf/spec/paper256.yaml +14 -0
  20. conf/spec/paper512.yaml +12 -0
  21. conf/spec/stylegan2.yaml +14 -0
  22. dnnlib/__init__.py +9 -0
  23. dnnlib/camera.py +687 -0
  24. dnnlib/filters.py +81 -0
  25. dnnlib/geometry.py +406 -0
  26. dnnlib/util.py +531 -0
  27. generate.py +202 -0
  28. gui_utils/__init__.py +9 -0
  29. gui_utils/gl_utils.py +374 -0
  30. gui_utils/glfw_window.py +229 -0
  31. gui_utils/imgui_utils.py +169 -0
  32. gui_utils/imgui_window.py +103 -0
  33. gui_utils/text_utils.py +123 -0
  34. launcher.py +189 -0
  35. legacy.py +320 -0
  36. metrics/__init__.py +9 -0
  37. metrics/frechet_inception_distance.py +40 -0
  38. metrics/inception_score.py +38 -0
  39. metrics/kernel_inception_distance.py +46 -0
  40. metrics/metric_main.py +152 -0
  41. metrics/metric_utils.py +305 -0
  42. metrics/perceptual_path_length.py +131 -0
  43. metrics/precision_recall.py +62 -0
  44. renderer.py +322 -0
  45. requirements.txt +29 -0
  46. run_train.py +398 -0
  47. torch_utils/__init__.py +9 -0
  48. torch_utils/custom_ops.py +157 -0
  49. torch_utils/distributed_utils.py +213 -0
  50. torch_utils/misc.py +303 -0
.gitignore ADDED
@@ -0,0 +1,25 @@
__pycache__/
.cache/
datasets/
outputs/
out/
out2/
debug/
checkpoint/
*.zip
*.npy
core
history/
tools/*
tools
eval_outputs/
pretrained/
.nfs00d20000091cb3390001ead3
scripts/research/

.idea
.vscode
.github
.ipynb_checkpoints/
_screenshots/
flagged
README.md CHANGED
@@ -1,37 +1,75 @@
- ---
- title: StyleNeRF
- emoji: 🏃
- colorFrom: blue
- colorTo: indigo
- sdk: gradio
- app_file: app.py
- pinned: false
- ---

- # Configuration

- `title`: _string_
- Display title for the Space

- `emoji`: _string_
- Space emoji (emoji-only character allowed)

- `colorFrom`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

- `colorTo`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

- `sdk`: _string_
- Can be either `gradio` or `streamlit`

- `sdk_version` : _string_
- Only applicable for `streamlit` SDK.
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.

- `app_file`: _string_
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
- Path is relative to the root of the repository.

- `pinned`: _boolean_
- Whether the Space stays on top of your list.
# StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis

![Random Sample](./docs/random_sample.jpg)

**StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis**<br>
Jiatao Gu, Lingjie Liu, Peng Wang, Christian Theobalt<br>
### [Project Page](http://jiataogu.me/style_nerf) | [Video](http://jiataogu.me/style_nerf) | [Paper](https://arxiv.org/abs/2110.08985) | [Data](#dataset)<br>

Abstract: *We propose StyleNeRF, a 3D-aware generative model for photo-realistic high-resolution image synthesis with high multi-view consistency, which can be trained on unstructured 2D images. Existing approaches either cannot synthesize high-resolution images with fine details or yield noticeable 3D-inconsistent artifacts. In addition, many of them lack control over style attributes and explicit 3D camera poses. StyleNeRF integrates the neural radiance field (NeRF) into a style-based generator to tackle the aforementioned challenges, i.e., improving rendering efficiency and 3D consistency for high-resolution image generation. We perform volume rendering only to produce a low-resolution feature map and progressively apply upsampling in 2D to address the first issue. To mitigate the inconsistencies caused by 2D upsampling, we propose multiple designs, including a better upsampler and a new regularization loss. With these designs, StyleNeRF can synthesize high-resolution images at interactive rates while preserving 3D consistency at high quality. StyleNeRF also enables control of camera poses and different levels of styles, which can generalize to unseen views. It also supports challenging tasks, including zoom-in and out, style mixing, inversion, and semantic editing.*

## Requirements
The codebase is tested on:
* Python 3.7
* PyTorch 1.7.1
* 8 Nvidia GPUs (Tesla V100 32GB) with CUDA 11.0

Install the additional Python libraries with:

```
pip install -r requirements.txt
```

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch for additional software/hardware requirements.

## Dataset
We follow the same dataset format as StyleGAN2-ADA, which can be either an image folder or a zipped archive.

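The zip layout is the one produced by the StyleGAN2-ADA data tools. As a hedged sketch (`dataset_tool.py` belongs to the NVlabs stylegan2-ada-pytorch codebase referenced above, not to this commit, and all paths are placeholders), preparing an FFHQ zip could look like:

```bash
# Hypothetical dataset preparation using dataset_tool.py from stylegan2-ada-pytorch;
# --source/--dest/--width/--height are that tool's options, paths are placeholders.
python dataset_tool.py --source=/path/to/ffhq/images1024x1024 \
    --dest=./datasets/ffhq_512.zip --width=512 --height=512
```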
## Train a new StyleNeRF model
```bash
python run_train.py outdir=${OUTDIR} data=${DATASET} spec=paper512 model=stylenerf_ffhq
```
It will automatically detect all usable GPUs.

Please check the configuration files under `conf/model` and `conf/spec`; you can always add your own model config. For more details on using Hydra configuration, please follow https://hydra.cc/docs/intro/. An illustrative override example is sketched below.

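For illustration only (the values are placeholders rather than recommended settings), Hydra overrides of keys defined in `conf/config.yaml` and the selected config groups look like:

```bash
# Sketch of Hydra-style overrides: choose config groups and override individual keys
# such as gpus, snap, and spec.mb, all of which are defined in the conf/ tree.
python run_train.py outdir=outputs/ffhq_512 data=./datasets/ffhq_512.zip \
    spec=paper512 model=stylenerf_ffhq gpus=8 spec.mb=64 snap=50
```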
## Render the pretrained model
```bash
python generate.py --outdir=${OUTDIR} --trunc=0.7 --seeds=${SEEDS} --network=${CHECKPOINT_PATH} --render-program="rotation_camera"
```
It supports different rotation trajectories for rendering new videos.

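As a concrete sketch, the FFHQ-512 checkpoint that `app.py` below fetches from the Hugging Face Hub can be reused here (output directory and seeds are arbitrary placeholders):

```bash
# Download the demo checkpoint referenced in app.py, then render a camera rotation with it.
CKPT=$(python -c "from huggingface_hub import hf_hub_download; \
print(hf_hub_download('thomagram/stylenerf-ffhq-config-basic', 'ffhq_512.pkl'))")
python generate.py --outdir=out/rotation --trunc=0.7 --seeds=1,2,3 \
    --network=$CKPT --render-program="rotation_camera"
```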
## Run a demo page
```bash
python web_demo.py 21111
```
By default it runs a Gradio-powered demo at https://localhost:21111
![Web demo](./docs/web_demo.gif)

## Run a GUI visualizer
```bash
python visualizer.py
```
An interactive application will show up for users to play with.
![GUI demo](./docs/gui_demo.gif)

## Citation

```
@inproceedings{
    gu2022stylenerf,
    title={StyleNeRF: A Style-based 3D Aware Generator for High-resolution Image Synthesis},
    author={Jiatao Gu and Lingjie Liu and Peng Wang and Christian Theobalt},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=iUuzzTMUw9K}
}
```

## License

Copyright &copy; Facebook, Inc. All Rights Reserved.

The majority of StyleNeRF is licensed under [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/); however, portions of this project are available under separate license terms: all code used or modified from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch) is covered by the [Nvidia Source Code License](https://nvlabs.github.io/stylegan2-ada-pytorch/license.html).

app.py ADDED
@@ -0,0 +1,213 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os, sys
os.system('pip install -r requirements.txt')

import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob

import cv2
import signal
from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download


device = torch.device('cuda')
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111


def handler(signum, frame):
    res = input("Ctrl-c was pressed. Do you really want to exit? y/n ")
    if res == 'y':
        gr.close_all()
        exit(1)

signal.signal(signal.SIGINT, handler)


def set_random_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)


def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512'):
    gen = model.synthesis
    range_u, range_v = gen.C.range_u, gen.C.range_v
    if not (('car' in model_name) or ('Car' in model_name)):  # TODO: hack, better option?
        yaw, pitch = 0.5 * yaw, 0.3 * pitch
        pitch = pitch + np.pi/2
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    else:
        u = (yaw + 1) / 2
        v = (pitch + 1) / 2
    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
    return cam


def check_name(model_name='FFHQ512'):
    """Gets model by name."""
    if model_name == 'FFHQ512':
        network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')

    # TODO: checkpoint to be updated!
    # elif model_name == 'FFHQ512v2':
    #     network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
    # elif model_name == 'AFHQ512':
    #     network_pkl = "./pretrained/afhq_512.pkl"
    # elif model_name == 'MetFaces512':
    #     network_pkl = "./pretrained/metfaces_512.pkl"
    # elif model_name == 'CompCars256':
    #     network_pkl = "./pretrained/cars_256.pkl"
    # elif model_name == 'FFHQ1024':
    #     network_pkl = "./pretrained/ffhq_1024.pkl"
    else:
        if os.path.isdir(model_name):
            network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
        else:
            network_pkl = model_name
    return network_pkl


def get_model(network_pkl, render_option=None):
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
        G = network['G_ema'].to(device)  # type: ignore

    with torch.no_grad():
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)

        print('compile and go through the initial image')
        G2 = G2.eval()
        init_z = torch.from_numpy(np.random.RandomState(0).rand(1, G2.z_dim)).to(device)
        init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
        dummy = G2(z=init_z, c=None, camera_matrices=init_cam, render_option=render_option, theta=0)
        res = dummy['img'].shape[-1]
        imgs = np.zeros((res, res//2, 3))
    return G2, res, imgs


global_states = list(get_model(check_name()))
wss = [None, None]

def proc_seed(history, seed):
    if isinstance(seed, str):
        seed = 0
    else:
        seed = int(seed)


def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
    history = history or {}
    seeds = []

    if model_find != "":
        model_name = model_find

    model_name = check_name(model_name)
    if model_name != history.get("model_name", None):
        model, res, imgs = get_model(model_name, render_option)
        global_states[0] = model
        global_states[1] = res
        global_states[2] = imgs

    model, res, imgs = global_states
    for idx, seed in enumerate([seed1, seed2]):
        if isinstance(seed, str):
            seed = 0
        else:
            seed = int(seed)

        if (seed != history.get(f'seed{idx}', -1)) or \
                (model_name != history.get("model_name", None)) or \
                (trunc != history.get("trunc", 0.7)) or \
                (wss[idx] is None):
            print(f'use seed {seed}')
            set_random_seed(seed)
            z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
            ws = model.mapping(z=z, c=None, truncation_psi=trunc)
            img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0), render_option=render_option)
            ws = ws.detach().cpu().numpy()
            img = img[0].permute(1,2,0).detach().cpu().numpy()

            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
                (res//2, res//2), cv2.INTER_AREA)
            wss[idx] = ws
        else:
            seed = history[f'seed{idx}']
        seeds += [seed]

        history[f'seed{idx}'] = seed
    history['trunc'] = trunc
    history['model_name'] = model_name

    set_random_seed(sum(seeds))

    # style mixing (?)
    ws1, ws2 = [torch.from_numpy(ws).to(device) for ws in wss]
    ws = ws1.clone()
    ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
    ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)

    # set visualization for other types of inputs.
    if early == 'Normal Map':
        render_option += ',normal,early'
    elif early == 'Gradient Map':
        render_option += ',gradient,early'

    start_t = time.time()
    with torch.no_grad():
        cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
        image = model.get_final_output(
            styles=ws, camera_matrices=cam,
            theta=roll * np.pi,
            render_option=render_option)
    end_t = time.time()

    image = image[0].permute(1,2,0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5

    if imgs.shape[0] == image.shape[0]:
        image = np.concatenate([imgs, image], 1)
    else:
        a = image.shape[0]
        b = int(imgs.shape[1] / imgs.shape[0] * a)
        print(f'resize {a} {b} {image.shape} {imgs.shape}')
        image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)

    print(f'rendering time = {end_t-start_t:.4f}s')
    image = (image * 255).astype('uint8')
    return image, history

model_name = gr.inputs.Dropdown(['FFHQ512'])  # 'FFHQ512v2', 'AFHQ512', 'MetFaces512', 'CompCars256', 'FFHQ1024'
model_find = gr.inputs.Textbox(label="checkpoint path", default="")
render_option = gr.inputs.Textbox(label="rendering options", default='steps:40')
trunc = gr.inputs.Slider(default=0.7, maximum=1.0, minimum=0.0, label='truncation trick')
seed1 = gr.inputs.Number(default=1, label="seed1")
seed2 = gr.inputs.Number(default=9, label="seed2")
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (geometry)")
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (appearance)")
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='intermediate output')
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="yaw")
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
css = ".output_image {height: 40rem !important; width: 100% !important;}"

gr.Interface(fn=f_synthesis,
             inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
             title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
             outputs=["image", "state"],
             layout='unaligned',
             css=css, theme='dark-huggingface',
             live=True).launch(server_port=port)
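Since `app.py` reads an optional port from `sys.argv[1]` (defaulting to 21111) and places the model on `torch.device('cuda')`, a local launch of this demo is simply:

```bash
# Run the Gradio demo defined above on port 21111 (requires a CUDA-capable GPU).
python app.py 21111
```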
conf/config.yaml ADDED
@@ -0,0 +1,49 @@
defaults:
  - _self_
  - model: default
  - spec: paper512

# general options
outdir: ~
dry_run: False
debug: False
resume_run: ~

snap: 50          # Snapshot interval [default: 50 ticks]
imgsnap: 10
metrics: [ "fid50k_full" ]
seed: 2
num_fp16_res: 4
auto: False

# dataset
data: ~
resolution: ~
cond: False
subset: ~         # Train with only N images: <int>, default = all
mirror: False

# discriminator augmentation
aug: noaug
p: ~
target: ~
augpipe: ~

# transfer learning
resume: ~
freezed: ~

# performance options
fp32: False
nhwc: False
allow_tf32: False
nobench: False
workers: 3

launcher: "spawn"
partition: ~
comment: ~
gpus: ~           # Number of GPUs to use [default: 1]
port: ~
nodes: ~
timeout: ~
conf/hydra/local.yaml ADDED
@@ -0,0 +1,16 @@
sweep:
  dir: /checkpoint/${env:USER}/space/gan/${env:PREFIX}/${hydra.job.name}
  subdir: ${hydra.job.num}
launcher:
  submitit_folder: ${hydra.sweep.dir}
  timeout_min: 4320
  cpus_per_task: 64
  gpus_per_node: 8
  tasks_per_node: 1
  mem_gb: 400
  nodes: 1
  name: ${env:PREFIX}_${hydra.job.config_name}
  # partition: devlab,learnlab,learnfair,scavenge
  # constraint: volta32gb
  # max_num_timeout: 30
  # exclude: learnfair1381,learnfair5192,learnfair2304
conf/model/default.yaml ADDED
@@ -0,0 +1,35 @@
# @package _group_
name: default

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    num_fp16_res: ${num_fp16_res}
    channel_base: ${spec.fmaps}
    channel_max: 512
    conv_clamp: 256
    architecture: skip

D_kwargs:
  class_name: "training.networks.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: resnet

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
conf/model/stylenerf_afhq.yaml ADDED
@@ -0,0 +1,108 @@
# @package _group_
name: stylenerf_afhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim_bg: 32
    z_dim: 0
    resolution_vol: 32
    resolution_start: 32
    rgb_out_dim: 256

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: False
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_bg_samples: 4
      n_ray_samples: 14
      abs_sigma: False
      hierarchical: True
      no_background: False

    foreground_kwargs:
      positional_encoding: "normal"
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    background_kwargs:
      positional_encoding: "normal"
      hidden_size: 64
      n_blocks: 4
      downscale_p_by: 1
      skips: []
      inverse_sphere: True
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    upsampler_kwargs:
      channel_base: ${model.G_kwargs.synthesis_kwargs.channel_base}
      channel_max: ${model.G_kwargs.synthesis_kwargs.channel_max}
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 16
    reg_full: True

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_cars.yaml ADDED
@@ -0,0 +1,108 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "pixelshuffle"

    z_dim_bg: 32
    z_dim: 0
    resolution_vol: 32
    resolution_start: 32
    rgb_out_dim: 256

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-3.141592653589793, 3.141592653589793]
      range_radius: [1.0, 1.0]
      depth_range: [0.8, 1.2]
      fov: 16
      gaussian_camera: False
      angular_camera: True
      depth_transform: ~
      dists_normalized: False
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_bg_samples: 4
      n_ray_samples: 16
      abs_sigma: False
      hierarchical: True
      no_background: False

    foreground_kwargs:
      positional_encoding: "normal"
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    background_kwargs:
      positional_encoding: "normal"
      hidden_size: 64
      n_blocks: 4
      downscale_p_by: 1
      skips: []
      inverse_sphere: True
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    upsampler_kwargs:
      channel_base: ${model.G_kwargs.synthesis_kwargs.channel_base}
      channel_max: ${model.G_kwargs.synthesis_kwargs.channel_max}
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 0
    reg_full: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_cars_debug.yaml ADDED
@@ -0,0 +1,105 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "pixelshuffle"

    z_dim_bg: 0
    z_dim: 0
    resolution_vol: 128
    resolution_start: 32
    rgb_out_dim: 256

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-3.141592653589793, 3.141592653589793]
      range_radius: [1.0, 1.0]
      depth_range: [0.8, 1.2]
      fov: 16
      gaussian_camera: False
      angular_camera: True
      depth_transform: ~
      dists_normalized: False
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_bg_samples: 0
      n_ray_samples: 32
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: False
      add_rgb: True
      use_viewdirs: False
      n_blocks: 0

    input_kwargs:
      output_mode: 'tri_plane_reshape'
      input_mode: 'random'
      in_res: 4
      out_res: 256
      out_dim: 32

    upsampler_kwargs:
      channel_base: ${model.G_kwargs.synthesis_kwargs.channel_base}
      channel_max: ${model.G_kwargs.synthesis_kwargs.channel_max}
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 0
    reg_full: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_ffhq.yaml ADDED
@@ -0,0 +1,108 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim_bg: 32
    z_dim: 0
    resolution_vol: 32
    resolution_start: 32
    rgb_out_dim: 256

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: False
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_bg_samples: 4
      n_ray_samples: 14
      abs_sigma: False
      hierarchical: True
      no_background: False

    foreground_kwargs:
      positional_encoding: "normal"
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    background_kwargs:
      positional_encoding: "normal"
      hidden_size: 64
      n_blocks: 4
      downscale_p_by: 1
      skips: []
      inverse_sphere: True
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    upsampler_kwargs:
      channel_base: ${model.G_kwargs.synthesis_kwargs.channel_base}
      channel_max: ${model.G_kwargs.synthesis_kwargs.channel_max}
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 16
    reg_full: True

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_ffhq_ae.yaml ADDED
@@ -0,0 +1,118 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim: 0
    resolution_vol: 128
    resolution_start: 128
    rgb_out_dim: 32

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: True
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_ray_samples: 32
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: False
      use_viewdirs: False
      add_rgb: True
      n_blocks: 0

    input_kwargs:
      output_mode: 'tri_plane_reshape'
      input_mode: 'random'
      in_res: 4
      out_res: 256
      out_dim: 32

    upsampler_kwargs:
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 0
    reg_full: False

  encoder_kwargs:
    class_name: "training.stylenerf.Encoder"
    num_fp16_res: ${num_fp16_res}
    channel_base: ${spec.fmaps}
    channel_max: 512
    conv_clamp: 256
    architecture: skip
    progressive: ${..synthesis_kwargs.progressive}
    lowres_head: ${..synthesis_kwargs.resolution_start}
    upsample_type: "bilinear"
    model_kwargs:
      output_mode: "W+"
      predict_camera: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip

  predict_camera: True

  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_ffhq_ae_basic.yaml ADDED
@@ -0,0 +1,110 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim: 0
    resolution_vol: 32
    resolution_start: 32
    rgb_out_dim: 32

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: True
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_ray_samples: 32
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: False
      use_viewdirs: False
      add_rgb: True

    upsampler_kwargs:
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 0
    reg_full: False

  encoder_kwargs:
    class_name: "training.stylenerf.Encoder"
    num_fp16_res: ${num_fp16_res}
    channel_base: ${spec.fmaps}
    channel_max: 512
    conv_clamp: 256
    architecture: skip
    progressive: ${..synthesis_kwargs.progressive}
    lowres_head: ${..synthesis_kwargs.resolution_start}
    upsample_type: "bilinear"
    model_kwargs:
      output_mode: "W+"
      predict_camera: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip

  predict_camera: True

  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_ffhq_debug.yaml ADDED
@@ -0,0 +1,103 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim: 0
    resolution_vol: 128
    resolution_start: 128
    rgb_out_dim: 32

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: True
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_ray_samples: 32
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: False
      use_viewdirs: False
      add_rgb: True
      n_blocks: 0

    input_kwargs:
      output_mode: 'tri_plane_reshape'
      input_mode: 'random'
      in_res: 4
      out_res: 256
      out_dim: 32
      keep_posenc: -1
      keep_nerf_latents: False

    upsampler_kwargs:
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 0
    reg_full: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/model/stylenerf_ffhq_eg3d.yaml ADDED
@@ -0,0 +1,100 @@
# @package _group_
name: stylenerf_ffhq

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 512
    conv_clamp: 256
    kernel_size: 3
    architecture: skip

    z_dim: 0
    resolution_vol: 128
    resolution_start: 32
    rgb_out_dim: 32

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 1.12]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: ~
      dists_normalized: True
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_ray_samples: 32
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: False
      use_viewdirs: False
      add_rgb: True
      n_blocks: 0

    input_kwargs:
      output_mode: 'tri_plane_reshape'
      input_mode: 'random'
      in_res: 4
      out_res: 256
      out_dim: 32

    upsampler_kwargs:
      no_2d_renderer: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: False
    prog_nerf_only: True

    # regularization
    n_reg_samples: 0
    reg_full: False

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: False
  dual_input_res: 128
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [0,5000]
conf/model/stylenerf_ffhq_warped_depth.yaml ADDED
@@ -0,0 +1,97 @@
# @package _group_
name: stylenerf_ffhq_warped_depth

G_kwargs:
  class_name: "training.networks.Generator"
  z_dim: 512
  w_dim: 512

  mapping_kwargs:
    num_layers: ${spec.map}

  synthesis_kwargs:
    # global settings
    num_fp16_res: ${num_fp16_res}
    channel_base: 1
    channel_max: 1024
    conv_clamp: 256
    kernel_size: 1
    architecture: skip
    upsample_mode: "nn_cat"

    z_dim_bg: 0
    z_dim: 0
    resolution_vol: 32
    resolution_start: 32
    rgb_out_dim: 256

    use_noise: False
    module_name: "training.stylenerf.NeRFSynthesisNetwork"
    no_bbox: True
    margin: 0
    magnitude_ema_beta: 0.999

    camera_kwargs:
      range_v: [1.4157963267948965, 1.7257963267948966]
      range_u: [-0.3, 0.3]
      range_radius: [1.0, 1.0]
      depth_range: [0.88, 3.2]
      fov: 12
      gaussian_camera: True
      angular_camera: True
      depth_transform: InverseWarp
      dists_normalized: True
      ray_align_corner: False
      bg_start: 0.5

    renderer_kwargs:
      n_bg_samples: 0
      n_ray_samples: 48
      abs_sigma: False
      hierarchical: True
      no_background: True

    foreground_kwargs:
      positional_encoding: "normal"
      downscale_p_by: 1
      use_style: "StyleGAN2"
      predict_rgb: True
      use_viewdirs: False

    upsampler_kwargs:
      channel_base: ${model.G_kwargs.synthesis_kwargs.channel_base}
      channel_max: ${model.G_kwargs.synthesis_kwargs.channel_max}
      no_2d_renderer: False
      no_residual_img: False
      block_reses: ~
      shared_rgb_style: False
      upsample_type: "bilinear"

    progressive: True

    # regularization
    n_reg_samples: 16
    reg_full: True

D_kwargs:
  class_name: "training.stylenerf.Discriminator"
  epilogue_kwargs:
    mbstd_group_size: ${spec.mbstd}

  num_fp16_res: ${num_fp16_res}
  channel_base: ${spec.fmaps}
  channel_max: 512
  conv_clamp: 256
  architecture: skip
  progressive: ${model.G_kwargs.synthesis_kwargs.progressive}
  lowres_head: ${model.G_kwargs.synthesis_kwargs.resolution_start}
  upsample_type: "bilinear"
  resize_real_early: True

# loss kwargs
loss_kwargs:
  pl_batch_shrink: 2
  pl_decay: 0.01
  pl_weight: 2
  style_mixing_prob: 0.9
  curriculum: [500,5000]
conf/spec/cifar.yaml ADDED
@@ -0,0 +1,13 @@
# @package _group_

name: cifar
ref_gpus: 2
kimg: 100000
mb: 64
mbstd: 32
fmaps: 1
lrate: 0.0025
gamma: 0.01
ema: 500
ramp: 0.05
map: 2
conf/spec/nerf32.yaml ADDED
@@ -0,0 +1,14 @@
# @package _group_

name: nerf32
ref_gpus: 8
kimg: 25000
mb: 64
mbstd: 8
fmaps: 0.5
lrate: 0.0025
lrate_disc: 0.0025
gamma: 0.003
ema: 20
ramp: ~
map: 8
conf/spec/paper1024.yaml ADDED
@@ -0,0 +1,14 @@
# @package _group_

name: paper1024
ref_gpus: 8
kimg: 25000
mb: 32
mbstd: 4
fmaps: 1
lrate: 0.002
lrate_disc: 0.002
gamma: 2
ema: 10
ramp: ~
map: 8
conf/spec/paper256.yaml ADDED
@@ -0,0 +1,14 @@
# @package _group_

name: paper256
ref_gpus: 8
kimg: 25000
mb: 64
mbstd: 8
fmaps: 0.5
lrate: 0.0025
lrate_disc: 0.0025
gamma: 0.5
ema: 20
ramp: ~
map: 8
conf/spec/paper512.yaml ADDED
@@ -0,0 +1,12 @@
name: paper512
ref_gpus: 8
kimg: 25000
mb: 64
mbstd: 8
fmaps: 1
lrate: 0.0025
lrate_disc: 0.0025
gamma: 0.5
ema: 20
ramp: ~
map: 8
conf/spec/stylegan2.yaml ADDED
@@ -0,0 +1,14 @@
# @package _group_

name: stylegan2
ref_gpus: 8
kimg: 25000
mb: 32
mbstd: 4
fmaps: 1
lrate: 0.002
lrate_disc: 0.0025
gamma: 10
ema: 10
ramp: ~
map: 8
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
dnnlib/camera.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+ import numpy as np
5
+ from numpy.lib.function_base import angle
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import math
9
+
10
+ from scipy.spatial.transform import Rotation as Rot
11
+ HUGE_NUMBER = 1e10
12
+ TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
13
+
14
+
15
+ def get_camera_mat(fov=49.13, invert=True):
16
+ # fov = 2 * arctan(sensor / (2 * focal))
17
+ # focal = (sensor / 2) * 1 / (tan(0.5 * fov))
18
+ # in our case, sensor = 2 as pixels are in [-1, 1]
19
+ focal = 1. / np.tan(0.5 * fov * np.pi/180.)
20
+ focal = focal.astype(np.float32)
21
+ mat = torch.tensor([
22
+ [focal, 0., 0., 0.],
23
+ [0., focal, 0., 0.],
24
+ [0., 0., 1, 0.],
25
+ [0., 0., 0., 1.]
26
+ ]).reshape(1, 4, 4)
27
+ if invert:
28
+ mat = torch.inverse(mat)
29
+ return mat
30
+
31
+
32
+ def get_random_pose(range_u, range_v, range_radius, batch_size=32,
33
+ invert=False, gaussian=False, angular=False):
34
+ loc, (u, v) = sample_on_sphere(range_u, range_v, size=(batch_size), gaussian=gaussian, angular=angular)
35
+ radius = range_radius[0] + torch.rand(batch_size) * (range_radius[1] - range_radius[0])
36
+ loc = loc * radius.unsqueeze(-1)
37
+ R = look_at(loc)
38
+ RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
39
+ RT[:, :3, :3] = R
40
+ RT[:, :3, -1] = loc
41
+
42
+ if invert:
43
+ RT = torch.inverse(RT)
44
+
45
+ def N(a, range_a):
46
+ if range_a[0] == range_a[1]:
47
+ return a * 0
48
+ return (a - range_a[0]) / (range_a[1] - range_a[0])
49
+
50
+ val_u, val_v, val_r = N(u, range_u), N(v, range_v), N(radius, range_radius)
51
+ return RT, (val_u, val_v, val_r)
52
+
53
+
54
+ def get_camera_pose(range_u, range_v, range_r, val_u=0.5, val_v=0.5, val_r=0.5,
55
+ batch_size=32, invert=False, gaussian=False, angular=False):
56
+ r0, rr = range_r[0], range_r[1] - range_r[0]
57
+ r = r0 + val_r * rr
58
+ if not gaussian:
59
+ u0, ur = range_u[0], range_u[1] - range_u[0]
60
+ v0, vr = range_v[0], range_v[1] - range_v[0]
61
+ u = u0 + val_u * ur
62
+ v = v0 + val_v * vr
63
+ else:
64
+ mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
65
+ vu, vv = mean_u - range_u[0], mean_v - range_v[0]
66
+ u = mean_u + vu * val_u
67
+ v = mean_v + vv * val_v
68
+
69
+ loc, _ = sample_on_sphere((u, u), (v, v), size=(batch_size), angular=angular)
70
+ radius = torch.ones(batch_size) * r
71
+ loc = loc * radius.unsqueeze(-1)
72
+ R = look_at(loc)
73
+ RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
74
+ RT[:, :3, :3] = R
75
+ RT[:, :3, -1] = loc
76
+
77
+ if invert:
78
+ RT = torch.inverse(RT)
79
+ return RT
80
+
81
+
82
+ def get_camera_pose_v2(range_u, range_v, range_r, mode, invert=False, gaussian=False, angular=False):
83
+ r0, rr = range_r[0], range_r[1] - range_r[0]
84
+ val_u, val_v = mode[:,0], mode[:,1]
85
+ val_r = torch.ones_like(val_u) * 0.5
86
+ if not gaussian:
87
+ u0, ur = range_u[0], range_u[1] - range_u[0]
88
+ v0, vr = range_v[0], range_v[1] - range_v[0]
89
+ u = u0 + val_u * ur
90
+ v = v0 + val_v * vr
91
+ else:
92
+ mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
93
+ vu, vv = mean_u - range_u[0], mean_v - range_v[0]
94
+ u = mean_u + vu * val_u
95
+ v = mean_v + vv * val_v
96
+
97
+ loc = to_sphere(u, v, angular)
98
+ radius = r0 + val_r * rr
99
+ loc = loc * radius.unsqueeze(-1)
100
+ R = look_at(loc)
101
+ RT = torch.eye(4).to(R.device).reshape(1, 4, 4).repeat(R.size(0), 1, 1)
102
+ RT[:, :3, :3] = R
103
+ RT[:, :3, -1] = loc
104
+
105
+ if invert:
106
+ RT = torch.inverse(RT)
107
+ return RT, (val_u, val_v, val_r)
108
+
109
+
110
+ def to_sphere(u, v, angular=False):
111
+ T = torch if isinstance(u, torch.Tensor) else np
112
+ if not angular:
113
+ theta = 2 * math.pi * u
114
+ phi = T.arccos(1 - 2 * v)
115
+ else:
116
+ theta, phi = u, v
117
+
118
+ cx = T.sin(phi) * T.cos(theta)
119
+ cy = T.sin(phi) * T.sin(theta)
120
+ cz = T.cos(phi)
121
+ return T.stack([cx, cy, cz], -1)
122
+
123
+
124
+ def sample_on_sphere(range_u=(0, 1), range_v=(0, 1), size=(1,),
125
+ to_pytorch=True, gaussian=False, angular=False):
126
+ if not gaussian:
127
+ u = np.random.uniform(*range_u, size=size)
128
+ v = np.random.uniform(*range_v, size=size)
129
+ else:
130
+ mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
131
+ var_u, var_v = mean_u - range_u[0], mean_v - range_v[0]
132
+ u = np.random.normal(size=size) * var_u + mean_u
133
+ v = np.random.normal(size=size) * var_v + mean_v
134
+
135
+ sample = to_sphere(u, v, angular)
136
+ if to_pytorch:
137
+ sample = torch.tensor(sample).float()
138
+ u, v = torch.tensor(u).float(), torch.tensor(v).float()
139
+
140
+ return sample, (u, v)
141
+
142
+
143
+ def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,
144
+ to_pytorch=True):
145
+ if not isinstance(eye, torch.Tensor):
146
+ # this is the original code from GRAF
147
+ at = at.astype(float).reshape(1, 3)
148
+ up = up.astype(float).reshape(1, 3)
149
+ eye = eye.reshape(-1, 3)
150
+ up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
151
+ eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)
152
+ z_axis = eye - at
153
+ z_axis /= np.max(np.stack([np.linalg.norm(z_axis,
154
+ axis=1, keepdims=True), eps]))
155
+ x_axis = np.cross(up, z_axis)
156
+ x_axis /= np.max(np.stack([np.linalg.norm(x_axis,
157
+ axis=1, keepdims=True), eps]))
158
+ y_axis = np.cross(z_axis, x_axis)
159
+ y_axis /= np.max(np.stack([np.linalg.norm(y_axis,
160
+ axis=1, keepdims=True), eps]))
161
+ r_mat = np.concatenate(
162
+ (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(
163
+ -1, 3, 1)), axis=2)
164
+ if to_pytorch:
165
+ r_mat = torch.tensor(r_mat).float()
166
+ else:
167
+
168
+ def normalize(x, axis=-1, order=2):
169
+ l2 = x.norm(p=order, dim=axis, keepdim=True).clamp(min=1e-8)
170
+ return x / l2
171
+
172
+ at, up = torch.from_numpy(at).float().to(eye.device), torch.from_numpy(up).float().to(eye.device)
173
+ z_axis = normalize(eye - at[None, :])
174
+ x_axis = normalize(torch.cross(up[None,:].expand_as(z_axis), z_axis, dim=-1))
175
+ y_axis = normalize(torch.cross(z_axis, x_axis, dim=-1))
176
+ r_mat = torch.stack([x_axis, y_axis, z_axis], dim=-1)
177
+
178
+ return r_mat
179
+
180
+
181
+ def get_rotation_matrix(axis='z', value=0., batch_size=32):
182
+ r = Rot.from_euler(axis, value * 2 * np.pi).as_dcm()
183
+ r = torch.from_numpy(r).reshape(1, 3, 3).repeat(batch_size, 1, 1)
184
+ return r
185
+
186
+
187
+ def get_corner_rays(corner_pixels, camera_matrices, res):
188
+ assert (res + 1) * (res + 1) == corner_pixels.size(1)
189
+ batch_size = camera_matrices[0].size(0)
190
+ rays, origins, _ = get_camera_rays(camera_matrices, corner_pixels)
191
+ corner_rays = torch.cat([rays, torch.cross(origins, rays, dim=-1)], -1)
192
+ corner_rays = corner_rays.reshape(batch_size, res+1, res+1, 6).permute(0,3,1,2)
193
+ corner_rays = torch.cat([corner_rays[..., :-1, :-1], corner_rays[..., 1:, :-1], corner_rays[..., 1:, 1:], corner_rays[..., :-1, 1:]], 1)
194
+ return corner_rays
195
+
196
+
197
+ def arange_pixels(
198
+ resolution=(128, 128),
199
+ batch_size=1,
200
+ subsample_to=None,
201
+ invert_y_axis=False,
202
+ margin=0,
203
+ corner_aligned=True,
204
+ jitter=None
205
+ ):
206
+ ''' Arranges pixels for given resolution in range image_range.
207
+
208
+ The function returns the unscaled pixel locations as integers and the
209
+ scaled float values.
210
+
211
+ Args:
212
+ resolution (tuple): image resolution
213
+ batch_size (int): batch size
214
+ subsample_to (int): if integer and > 0, the points are randomly
215
+ subsampled to this value
216
+ '''
217
+ h, w = resolution
218
+ n_points = resolution[0] * resolution[1]
219
+ uh = 1 if corner_aligned else 1 - (1 / h)
220
+ uw = 1 if corner_aligned else 1 - (1 / w)
221
+ if margin > 0:
222
+ uh = uh + (2 / h) * margin
223
+ uw = uw + (2 / w) * margin
224
+ w, h = w + margin * 2, h + margin * 2
225
+
226
+ x, y = torch.linspace(-uw, uw, w), torch.linspace(-uh, uh, h)
227
+ if jitter is not None:
228
+ dx = (torch.ones_like(x).uniform_() - 0.5) * 2 / w * jitter
229
+ dy = (torch.ones_like(y).uniform_() - 0.5) * 2 / h * jitter
230
+ x, y = x + dx, y + dy
231
+ x, y = torch.meshgrid(x, y)
232
+ pixel_scaled = torch.stack([x, y], -1).permute(1,0,2).reshape(1, -1, 2).repeat(batch_size, 1, 1)
233
+
234
+ # Subsample points if subsample_to is not None and > 0
235
+ if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points):
236
+ idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
237
+ replace=False)
238
+ pixel_scaled = pixel_scaled[:, idx]
239
+
240
+ if invert_y_axis:
241
+ pixel_scaled[..., -1] *= -1.
242
+
243
+ return pixel_scaled
244
+
245
+
246
+ def to_pytorch(tensor, return_type=False):
247
+ ''' Converts input tensor to pytorch.
248
+
249
+ Args:
250
+ tensor (tensor): Numpy or Pytorch tensor
251
+ return_type (bool): whether to return input type
252
+ '''
253
+ is_numpy = False
254
+ if type(tensor) == np.ndarray:
255
+ tensor = torch.from_numpy(tensor)
256
+ is_numpy = True
257
+ tensor = tensor.clone()
258
+ if return_type:
259
+ return tensor, is_numpy
260
+ return tensor
261
+
262
+
263
+ def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,
264
+ invert=True, use_absolute_depth=True):
265
+ ''' Transforms pixel positions p with given depth value d to world coordinates.
266
+
267
+ Args:
268
+ pixels (tensor): pixel tensor of size B x N x 2
269
+ depth (tensor): depth tensor of size B x N x 1
270
+ camera_mat (tensor): camera matrix
271
+ world_mat (tensor): world matrix
272
+ scale_mat (tensor): scale matrix
273
+ invert (bool): whether to invert matrices (default: true)
274
+ '''
275
+ assert(pixels.shape[-1] == 2)
276
+ if scale_mat is None:
277
+ scale_mat = torch.eye(4).unsqueeze(0).repeat(
278
+ camera_mat.shape[0], 1, 1).to(camera_mat.device)
279
+
280
+ # Convert to pytorch
281
+ pixels, is_numpy = to_pytorch(pixels, True)
282
+ depth = to_pytorch(depth)
283
+ camera_mat = to_pytorch(camera_mat)
284
+ world_mat = to_pytorch(world_mat)
285
+ scale_mat = to_pytorch(scale_mat)
286
+
287
+ # Invert camera matrices
288
+ if invert:
289
+ camera_mat = torch.inverse(camera_mat)
290
+ world_mat = torch.inverse(world_mat)
291
+ scale_mat = torch.inverse(scale_mat)
292
+
293
+ # Transform pixels to homogeneous coordinates
294
+ pixels = pixels.permute(0, 2, 1)
295
+ pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)
296
+
297
+ # Project pixels into camera space
298
+ if use_absolute_depth:
299
+ pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()
300
+ pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)
301
+ else:
302
+ pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)
303
+
304
+ # Transform pixels to world space
305
+ p_world = scale_mat @ world_mat @ camera_mat @ pixels
306
+
307
+ # Transform p_world back to 3D coordinates
308
+ p_world = p_world[:, :3].permute(0, 2, 1)
309
+
310
+ if is_numpy:
311
+ p_world = p_world.numpy()
312
+ return p_world
313
+
314
+
315
+ def transform_to_camera_space(p_world, world_mat, camera_mat=None, scale_mat=None):
316
+ ''' Transforms world points to camera space.
317
+ Args:
318
+ p_world (tensor): world points tensor of size B x N x 3
319
+ camera_mat (tensor): camera matrix
320
+ world_mat (tensor): world matrix
321
+ scale_mat (tensor): scale matrix
322
+ '''
323
+ batch_size, n_p, _ = p_world.shape
324
+ device = p_world.device
325
+
326
+ # Transform world points to homogeneous coordinates
327
+ p_world = torch.cat([p_world, torch.ones(
328
+ batch_size, n_p, 1).to(device)], dim=-1).permute(0, 2, 1)
329
+
330
+ # Apply matrices to transform p_world to camera space
331
+ if scale_mat is None:
332
+ if camera_mat is None:
333
+ p_cam = world_mat @ p_world
334
+ else:
335
+ p_cam = camera_mat @ world_mat @ p_world
336
+ else:
337
+ p_cam = camera_mat @ world_mat @ scale_mat @ p_world
338
+
339
+ # Transform points back to 3D coordinates
340
+ p_cam = p_cam[:, :3].permute(0, 2, 1)
341
+ return p_cam
342
+
343
+
344
+ def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,
345
+ invert=False):
346
+ ''' Transforms origin (camera location) to world coordinates.
347
+
348
+ Args:
349
+ n_points (int): how often the transformed origin is repeated in the
350
+ form (batch_size, n_points, 3)
351
+ camera_mat (tensor): camera matrix
352
+ world_mat (tensor): world matrix
353
+ scale_mat (tensor): scale matrix
354
+ invert (bool): whether to invert the matrices (default: false)
355
+ '''
356
+ batch_size = camera_mat.shape[0]
357
+ device = camera_mat.device
358
+ # Create origin in homogeneous coordinates
359
+ p = torch.zeros(batch_size, 4, n_points).to(device)
360
+ p[:, -1] = 1.
361
+
362
+ if scale_mat is None:
363
+ scale_mat = torch.eye(4).unsqueeze(
364
+ 0).repeat(batch_size, 1, 1).to(device)
365
+
366
+ # Invert matrices
367
+ if invert:
368
+ camera_mat = torch.inverse(camera_mat)
369
+ world_mat = torch.inverse(world_mat)
370
+ scale_mat = torch.inverse(scale_mat)
371
+
372
+ # Apply transformation
373
+ p_world = scale_mat @ world_mat @ camera_mat @ p
374
+
375
+ # Transform points back to 3D coordinates
376
+ p_world = p_world[:, :3].permute(0, 2, 1)
377
+ return p_world
378
+
379
+
380
+ def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,
381
+ invert=False, negative_depth=True):
382
+ ''' Transforms points on image plane to world coordinates.
383
+
384
+ In contrast to transform_to_world, no depth value is needed as points on
385
+ the image plane have a fixed depth of 1.
386
+
387
+ Args:
388
+ image_points (tensor): image points tensor of size B x N x 2
389
+ camera_mat (tensor): camera matrix
390
+ world_mat (tensor): world matrix
391
+ scale_mat (tensor): scale matrix
392
+ invert (bool): whether to invert matrices
393
+ '''
394
+ batch_size, n_pts, dim = image_points.shape
395
+ assert(dim == 2)
396
+ device = image_points.device
397
+ d_image = torch.ones(batch_size, n_pts, 1).to(device)
398
+ if negative_depth:
399
+ d_image *= -1.
400
+ return transform_to_world(image_points, d_image, camera_mat, world_mat,
401
+ scale_mat, invert=invert)
402
+
403
+
404
+ def image_points_to_camera(image_points, camera_mat,
405
+ invert=False, negative_depth=True, use_absolute_depth=True):
406
+ batch_size, n_pts, dim = image_points.shape
407
+ assert(dim == 2)
408
+ device = image_points.device
409
+ d_image = torch.ones(batch_size, n_pts, 1).to(device)
410
+ if negative_depth:
411
+ d_image *= -1.
412
+
413
+ # Convert to pytorch
414
+ pixels, is_numpy = to_pytorch(image_points, True)
415
+ depth = to_pytorch(d_image)
416
+ camera_mat = to_pytorch(camera_mat)
417
+
418
+ # Invert camera matrices
419
+ if invert:
420
+ camera_mat = torch.inverse(camera_mat)
421
+
422
+ # Transform pixels to homogeneous coordinates
423
+ pixels = pixels.permute(0, 2, 1)
424
+ pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)
425
+
426
+ # Project pixels into camera space
427
+ if use_absolute_depth:
428
+ pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()
429
+ pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)
430
+ else:
431
+ pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)
432
+
433
+ # Transform pixels to camera space
434
+ p_camera = camera_mat @ pixels
435
+
436
+ # Transform p_camera back to 3D coordinates
437
+ p_camera = p_camera[:, :3].permute(0, 2, 1)
438
+
439
+ if is_numpy:
440
+ p_camera = p_camera.numpy()
441
+ return p_camera
442
+
443
+
444
+ def camera_points_to_image(camera_points, camera_mat,
445
+ invert=False, negative_depth=True, use_absolute_depth=True):
446
+ batch_size, n_pts, dim = camera_points.shape
447
+ assert(dim == 3)
448
+ device = camera_points.device
449
+
450
+ # Convert to pytorch
451
+ p_camera, is_numpy = to_pytorch(camera_points, True)
452
+ camera_mat = to_pytorch(camera_mat)
453
+
454
+ # Invert camera matrices
455
+ if invert:
456
+ camera_mat = torch.inverse(camera_mat)
457
+
458
+ # Project camera-space points to pixel coordinates
459
+ p_camera = p_camera.permute(0, 2, 1) # B x 3 x N
460
+ pixels = camera_mat[:, :3, :3] @ p_camera
461
+
462
+ assert use_absolute_depth and negative_depth
463
+ pixels, p_depths = pixels[:, :2], pixels[:, 2:3]
464
+ p_depths = -p_depths # negative depth
465
+ pixels = pixels / p_depths
466
+
467
+ pixels = pixels.permute(0, 2, 1)
468
+ if is_numpy:
469
+ pixels = pixels.numpy()
470
+ return pixels
471
+
472
+
473
+ def angular_interpolation(res, camera_mat):
474
+ batch_size = camera_mat.shape[0]
475
+ device = camera_mat.device
476
+ input_rays = image_points_to_camera(arange_pixels((res, res), batch_size,
477
+ invert_y_axis=True).to(device), camera_mat)
478
+ output_rays = image_points_to_camera(arange_pixels((res * 2, res * 2), batch_size,
479
+ invert_y_axis=True).to(device), camera_mat)
480
+ input_rays = input_rays / input_rays.norm(dim=-1, keepdim=True)
481
+ output_rays = output_rays / output_rays.norm(dim=-1, keepdim=True)
482
+
483
+ def dir2sph(v):
484
+ u = (v[..., :2] ** 2).sum(-1).sqrt()
485
+ theta = torch.atan2(u, v[..., 2]) / math.pi
486
+ phi = torch.atan2(v[..., 1], v[..., 0]) / math.pi
487
+ return torch.stack([theta, phi], 1)
488
+
489
+ input_rays = dir2sph(input_rays).reshape(batch_size, 2, res, res)
490
+ output_rays = dir2sph(output_rays).reshape(batch_size, 2, res * 2, res * 2)
491
+ return input_rays
492
+
493
+
494
+ def interpolate_sphere(z1, z2, t):
495
+ p = (z1 * z2).sum(dim=-1, keepdim=True)
496
+ p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
497
+ p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
498
+ omega = torch.acos(p)
499
+ s1 = torch.sin((1-t)*omega)/torch.sin(omega)
500
+ s2 = torch.sin(t*omega)/torch.sin(omega)
501
+ z = s1 * z1 + s2 * z2
502
+ return z
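`interpolate_sphere` is a spherical linear interpolation (slerp) between latent codes; a self-contained sketch of the same formula, assuming unit-norm inputs:

```python
import torch
import torch.nn.functional as F

def slerp(z1, z2, t):
    # Same computation as interpolate_sphere above.
    p = (z1 * z2).sum(-1, keepdim=True)
    p = p / (z1.norm(dim=-1, keepdim=True) * z2.norm(dim=-1, keepdim=True))
    omega = torch.acos(p.clamp(-1 + 1e-7, 1 - 1e-7))
    return (torch.sin((1 - t) * omega) * z1 + torch.sin(t * omega) * z2) / torch.sin(omega)

z1 = F.normalize(torch.randn(1, 512), dim=-1)
z2 = F.normalize(torch.randn(1, 512), dim=-1)
print(slerp(z1, z2, 0.5).norm(dim=-1))  # ~1.0: the interpolant stays on the unit sphere
```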
503
+
504
+
505
+ def get_camera_rays(camera_matrices, pixels=None, res=None, margin=0):
506
+ device = camera_matrices[0].device
507
+ batch_size = camera_matrices[0].shape[0]
508
+ if pixels is None:
509
+ assert res is not None
510
+ pixels = arange_pixels((res, res), batch_size, invert_y_axis=True, margin=margin).to(device)
511
+ n_points = pixels.size(1)
512
+ pixels_world = image_points_to_world(
513
+ pixels, camera_mat=camera_matrices[0],
514
+ world_mat=camera_matrices[1])
515
+ camera_world = origin_to_world(
516
+ n_points, camera_mat=camera_matrices[0],
517
+ world_mat=camera_matrices[1])
518
+ ray_vector = pixels_world - camera_world
519
+ ray_vector = ray_vector / ray_vector.norm(dim=-1, keepdim=True)
520
+ return ray_vector, camera_world, pixels_world
521
+
522
+
523
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
524
+ """
525
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
526
+ using Gram--Schmidt orthogonalization per Section B of [1].
527
+ Args:
528
+ d6: 6D rotation representation, of size (*, 6)
529
+
530
+ Returns:
531
+ batch of rotation matrices of size (*, 3, 3)
532
+
533
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
534
+ On the Continuity of Rotation Representations in Neural Networks.
535
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
536
+ Retrieved from http://arxiv.org/abs/1812.07035
537
+ """
538
+
539
+ a1, a2 = d6[..., :3], d6[..., 3:]
540
+ b1 = F.normalize(a1, dim=-1)
541
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
542
+ b2 = F.normalize(b2, dim=-1)
543
+ b3 = torch.cross(b1, b2, dim=-1)
544
+ return torch.stack((b1, b2, b3), dim=-2)
545
+
546
+
547
+ def camera_9d_to_16d(d9):
548
+ d6, translation = d9[..., :6], d9[..., 6:]
549
+ rotation = rotation_6d_to_matrix(d6)
550
+ RT = torch.eye(4).to(device=d9.device, dtype=d9.dtype).reshape(
551
+ 1, 4, 4).repeat(d6.size(0), 1, 1)
552
+ RT[:, :3, :3] = rotation
553
+ RT[:, :3, -1] = translation
554
+ return RT.reshape(-1, 16)
555
+
556
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
557
+ """
558
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
559
+ by dropping the last row. Note that 6D representation is not unique.
560
+ Args:
561
+ matrix: batch of rotation matrices of size (*, 3, 3)
562
+
563
+ Returns:
564
+ 6D rotation representation, of size (*, 6)
565
+
566
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
567
+ On the Continuity of Rotation Representations in Neural Networks.
568
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
569
+ Retrieved from http://arxiv.org/abs/1812.07035
570
+ """
571
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
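A quick round-trip check between the two rotation representations (import path assumed to be `dnnlib.camera`): Gram-Schmidt recovers the dropped third row exactly for proper rotations.

```python
import torch
from dnnlib.camera import rotation_6d_to_matrix, matrix_to_rotation_6d  # assumed import path

d6 = torch.randn(4, 6)
R = rotation_6d_to_matrix(d6)
I = torch.eye(3).expand(4, 3, 3)
print(torch.allclose(R @ R.transpose(-1, -2), I, atol=1e-5))                          # True: rows are orthonormal
print(torch.allclose(rotation_6d_to_matrix(matrix_to_rotation_6d(R)), R, atol=1e-5))  # True: round trip is lossless
```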
572
+
573
+
574
+ def depth2pts_outside(ray_o, ray_d, depth):
575
+ '''
576
+ ray_o, ray_d: [..., 3]
577
+ depth: [...]; inverse of distance to sphere origin
578
+ '''
579
+ # note: d1 becomes negative if this mid point is behind camera
580
+ d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
581
+ p_mid = ray_o + d1.unsqueeze(-1) * ray_d
582
+ p_mid_norm = torch.norm(p_mid, dim=-1)
583
+ ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
584
+ d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
585
+ p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d
586
+
587
+ rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
588
+ rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
589
+ phi = torch.asin(p_mid_norm)
590
+ theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
591
+ rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
592
+
593
+ # now rotate p_sphere
594
+ # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
595
+ p_sphere_new = p_sphere * torch.cos(rot_angle) + \
596
+ torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
597
+ rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))
598
+ p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
599
+ pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
600
+
601
+ # now calculate conventional depth
602
+ depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
603
+ return pts, depth_real
604
+
605
+
606
+ def intersect_sphere(ray_o, ray_d, radius=1):
607
+ '''
608
+ ray_o, ray_d: [..., 3]
609
+ compute the depth of the intersection point between this ray and unit sphere
610
+ '''
611
+ # note: d1 becomes negative if this mid point is behind camera
612
+ d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
613
+ p = ray_o + d1.unsqueeze(-1) * ray_d
614
+ # consider the case where the ray does not intersect the sphere
615
+ ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
616
+ d2 = radius ** 2 - torch.sum(p * p, dim=-1)
617
+ mask = (d2 > 0)
618
+ d2 = torch.sqrt(d2.clamp(min=1e-6)) * ray_d_cos
619
+ d1, d2 = d1.unsqueeze(-1), d2.unsqueeze(-1)
620
+ depth_range = [d1 - d2, d1 + d2]
621
+ return depth_range, mask
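A worked example for `intersect_sphere` (assumed importable from `dnnlib.camera`): a camera 2 units from the origin looking straight at the unit sphere should enter at depth 1 and exit at depth 3.

```python
import torch
from dnnlib.camera import intersect_sphere  # assumed import path

ray_o = torch.tensor([[0., 0., 2.]])   # camera position
ray_d = torch.tensor([[0., 0., -1.]])  # looking at the origin
(near, far), hit = intersect_sphere(ray_o, ray_d, radius=1)
print(hit)        # tensor([True])
print(near, far)  # tensor([[1.]]) tensor([[3.]])
```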
622
+
623
+
624
+ def normalize(x, axis=-1, order=2):
625
+ if isinstance(x, torch.Tensor):
626
+ l2 = x.norm(p=order, dim=axis, keepdim=True)
627
+ return x / (l2 + 1e-8), l2
628
+
629
+ else:
630
+ l2 = np.linalg.norm(x, order, axis)
631
+ l2 = np.expand_dims(l2, axis)
632
+ l2[l2==0] = 1
633
+ return x / l2, l2
634
+
635
+
636
+ def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
637
+ """
638
+ Sample @N_importance samples from @bins with distribution defined by @weights.
639
+ Inputs:
640
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
641
+ weights: (N_rays, N_samples_)
642
+ N_importance: the number of samples to draw from the distribution
643
+ det: deterministic or not
644
+ eps: a small number to prevent division by zero
645
+ Outputs:
646
+ samples: the sampled samples
647
+ Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py
648
+ """
649
+ N_rays, N_samples_ = weights.shape
650
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
651
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
652
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
653
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
654
+ # padded to 0~1 inclusive
655
+
656
+ if det:
657
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
658
+ u = u.expand(N_rays, N_importance)
659
+ else:
660
+ u = torch.rand(N_rays, N_importance, device=bins.device)
661
+ u = u.contiguous()
662
+
663
+ inds = torch.searchsorted(cdf, u)
664
+ below = torch.clamp_min(inds-1, 0)
665
+ above = torch.clamp_max(inds, N_samples_)
666
+
667
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
668
+ cdf_g = torch.gather(cdf, 1, inds_sampled)
669
+ cdf_g = cdf_g.view(N_rays, N_importance, 2)
670
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
671
+
672
+ denom = cdf_g[...,1]-cdf_g[...,0]
673
+ denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
674
+ # anyway, therefore any value for it is fine (set to 1 here)
675
+
676
+ samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
677
+ return samples
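`sample_pdf` is the standard NeRF inverse-CDF importance sampler. A small deterministic sketch, assuming the import path `dnnlib.camera`: when all coarse weight sits in one interval, the fine samples concentrate there.

```python
import torch
from dnnlib.camera import sample_pdf  # assumed import path

bins = torch.linspace(0, 1, 9).expand(2, 9)  # 8 depth intervals on 2 rays
weights = torch.zeros(2, 8)
weights[:, 3] = 1.0                          # all coarse density in the 4th interval [0.375, 0.5]
fine = sample_pdf(bins, weights, N_importance=16, det=True)
print(fine.shape)                                               # torch.Size([2, 16])
print(((fine >= 0.375) & (fine <= 0.5)).float().mean().item())  # most samples land inside that interval
```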
678
+
679
+
680
+ def normalization_inverse_sqrt_dist_centered(x_in_world, view_cell_center, max_depth):
681
+ localized = x_in_world - view_cell_center
682
+ local = torch.sqrt(torch.linalg.norm(localized, dim=-1))
683
+ res = localized / (math.sqrt(max_depth) * local[..., None])
684
+ return res
685
+
686
+
687
+ ######################################################################################
dnnlib/filters.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+ import math
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ def kaiser_attenuation(n_taps, f_h, sr):
12
+ df = (2 * f_h) / (sr / 2)
13
+ return 2.285 * (n_taps - 1) * math.pi * df + 7.95
14
+
15
+
16
+ def kaiser_beta(n_taps, f_h, sr):
17
+ atten = kaiser_attenuation(n_taps, f_h, sr)
18
+
19
+ if atten > 50:
20
+ return 0.1102 * (atten - 8.7)
21
+
22
+ elif 50 >= atten >= 21:
23
+ return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21)
24
+
25
+ else:
26
+ return 0.0
27
+
28
+ def sinc(x, eps=1e-10):
29
+ y = torch.sin(math.pi * x) / (math.pi * x + eps)
30
+ y = y.masked_fill(x.eq(0), 1.0)
31
+ return y
32
+
33
+
34
+ def kaiser_window(n_taps, f_h, sr):
35
+ beta = kaiser_beta(n_taps, f_h, sr)
36
+ ind = torch.arange(n_taps) - (n_taps - 1) / 2
37
+ return torch.i0(beta * torch.sqrt(1 - ((2 * ind) / (n_taps - 1)) ** 2)) / torch.i0(
38
+ torch.tensor(beta)
39
+ )
40
+
41
+
42
+ def lowpass_filter(n_taps, cutoff, band_half, sr):
43
+ window = kaiser_window(n_taps, band_half, sr)
44
+ ind = torch.arange(n_taps) - (n_taps - 1) / 2
45
+ lowpass = 2 * cutoff / sr * sinc(2 * cutoff / sr * ind) * window
46
+ return lowpass
47
+
48
+
49
+ def filter_parameters(
50
+ n_layer,
51
+ n_critical,
52
+ sr_max,
53
+ cutoff_0,
54
+ cutoff_n,
55
+ stopband_0,
56
+ stopband_n
57
+ ):
58
+ cutoffs = []
59
+ stopbands = []
60
+ srs = []
61
+ band_halfs = []
62
+
63
+ for i in range(n_layer):
64
+ f_c = cutoff_0 * (cutoff_n / cutoff_0) ** min(i / (n_layer - n_critical), 1)
65
+ f_t = stopband_0 * (stopband_n / stopband_0) ** min(
66
+ i / (n_layer - n_critical), 1
67
+ )
68
+ s_i = 2 ** math.ceil(math.log(min(2 * f_t, sr_max), 2))
69
+ f_h = max(f_t, s_i / 2) - f_c
70
+
71
+ cutoffs.append(f_c)
72
+ stopbands.append(f_t)
73
+ srs.append(s_i)
74
+ band_halfs.append(f_h)
75
+
76
+ return {
77
+ "cutoffs": cutoffs,
78
+ "stopbands": stopbands,
79
+ "srs": srs,
80
+ "band_halfs": band_halfs,
81
+ }
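These helpers follow the alias-free (StyleGAN3-style) recipe of per-layer Kaiser low-pass filters. A sketch of how they can be combined, with illustrative (not the repo's) parameter values and the import path assumed to be `dnnlib.filters`:

```python
from dnnlib.filters import filter_parameters, lowpass_filter  # assumed import path

params = filter_parameters(
    n_layer=14, n_critical=2, sr_max=512,
    cutoff_0=2, cutoff_n=256, stopband_0=4, stopband_n=512,  # illustrative values
)
kernels = [
    lowpass_filter(n_taps=13, cutoff=c, band_half=b, sr=s)
    for c, b, s in zip(params["cutoffs"], params["band_halfs"], params["srs"])
]
print(len(kernels), kernels[0].shape)  # 14 torch.Size([13]): one 1D kernel per synthesis layer
```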
dnnlib/geometry.py ADDED
@@ -0,0 +1,406 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import math
7
+ import random
8
+ import numpy as np
9
+
10
+
11
+ def positional_encoding(p, size, pe='normal', use_pos=False):
12
+ if pe == 'gauss':
13
+ p_transformed = np.pi * p @ size
14
+ p_transformed = torch.cat(
15
+ [torch.sin(p_transformed), torch.cos(p_transformed)], dim=-1)
16
+ else:
17
+ p_transformed = torch.cat([torch.cat(
18
+ [torch.sin((2 ** i) * np.pi * p),
19
+ torch.cos((2 ** i) * np.pi * p)],
20
+ dim=-1) for i in range(size)], dim=-1)
21
+ if use_pos:
22
+ p_transformed = torch.cat([p_transformed, p], -1)
23
+ return p_transformed
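A shape check for the Fourier-feature encoding above (import path assumed `dnnlib.geometry`): with the default `'normal'` mode each input dimension expands into `2 * size` sinusoids, optionally keeping the raw coordinates.

```python
import torch
from dnnlib.geometry import positional_encoding  # assumed import path

p = torch.rand(2, 1024, 3)  # 3D sample points
enc = positional_encoding(p, size=10, pe='normal', use_pos=True)
print(enc.shape)            # torch.Size([2, 1024, 63]) = 3 * 2 * 10 + 3
```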
24
+
25
+
26
+ def upsample(img_nerf, size, filter=None):
27
+ up = size // img_nerf.size(-1)
28
+ if up <= 1:
29
+ return img_nerf
30
+
31
+ if filter is not None:
32
+ from torch_utils.ops import upfirdn2d
33
+ for _ in range(int(math.log2(up))):
34
+ img_nerf = upfirdn2d.downsample2d(img_nerf, filter, up=2)
35
+ else:
36
+ img_nerf = F.interpolate(img_nerf, (size, size), mode='bilinear', align_corners=False)
37
+ return img_nerf
38
+
39
+
40
+ def downsample(img0, size, filter=None):
41
+ down = img0.size(-1) // size
42
+ if down <= 1:
43
+ return img0
44
+
45
+ if filter is not None:
46
+ from torch_utils.ops import upfirdn2d
47
+ for _ in range(int(math.log2(down))):
48
+ img0 = upfirdn2d.downsample2d(img0, filter, down=2)
49
+ else:
50
+ img0 = F.interpolate(img0, (size, size), mode='bilinear', align_corners=False)
51
+ return img0
52
+
53
+
54
+ def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Normalize vector lengths.
57
+ """
58
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
59
+
60
+
61
+ def repeat_vecs(vecs, n, dim=0):
62
+ return torch.stack(n*[vecs], dim=dim)
63
+
64
+
65
+ def get_grids(H, W, device, align=True):
66
+ ch = 1 if align else 1 - (1 / H)
67
+ cw = 1 if align else 1 - (1 / W)
68
+ x, y = torch.meshgrid(torch.linspace(-cw, cw, W, device=device),
69
+ torch.linspace(ch, -ch, H, device=device))
70
+ return torch.stack([x, y], -1)
71
+
72
+
73
+ def local_ensemble(pi, po, resolution):
74
+ ii = range(resolution)
75
+ ia = torch.tensor([max((i - 1)//2, 0) for i in ii]).long()
76
+ ib = torch.tensor([min((i + 1)//2, resolution//2-1) for i in ii]).long()
77
+
78
+ ul = torch.meshgrid(ia, ia)
79
+ ur = torch.meshgrid(ia, ib)
80
+ ll = torch.meshgrid(ib, ia)
81
+ lr = torch.meshgrid(ib, ib)
82
+
83
+ d_ul, p_ul = po - pi[ul], torch.stack(ul, -1)
84
+ d_ur, p_ur = po - pi[ur], torch.stack(ur, -1)
85
+ d_ll, p_ll = po - pi[ll], torch.stack(ll, -1)
86
+ d_lr, p_lr = po - pi[lr], torch.stack(lr, -1)
87
+
88
+ c_ul = d_ul.prod(dim=-1).abs()
89
+ c_ur = d_ur.prod(dim=-1).abs()
90
+ c_ll = d_ll.prod(dim=-1).abs()
91
+ c_lr = d_lr.prod(dim=-1).abs()
92
+
93
+ D = torch.stack([d_ul, d_ur, d_ll, d_lr], 0)
94
+ P = torch.stack([p_ul, p_ur, p_ll, p_lr], 0)
95
+ C = torch.stack([c_ul, c_ur, c_ll, c_lr], 0)
96
+ C = C / C.sum(dim=0, keepdim=True)
97
+ return D, P, C
98
+
99
+
100
+ def get_initial_rays_trig(num_steps, fov, resolution, ray_start, ray_end, device='cpu'):
101
+ """Returns sample points, z_vals, ray directions in camera space."""
102
+
103
+ W, H = resolution
104
+ # Create full screen NDC (-1 to +1) coords [x, y, 0, 1].
105
+ # Y is flipped to follow image memory layouts.
106
+ x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device),
107
+ torch.linspace(1, -1, H, device=device))
108
+ x = x.T.flatten()
109
+ y = y.T.flatten()
110
+ z = -torch.ones_like(x, device=device) / math.tan((2 * math.pi * fov / 360)/2)
111
+
112
+ rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1))
113
+
114
+ z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W*H, 1, 1)
115
+ points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals
116
+ return points, z_vals, rays_d_cam
117
+
118
+
119
+ def sample_camera_positions(
120
+ device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1,
121
+ horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'):
122
+ """
123
+ Samples n random locations along a sphere of radius r.
124
+ Uses a gaussian distribution for pitch and yaw
125
+ """
126
+ if mode == 'uniform':
127
+ theta = (torch.rand((n, 1),device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean
128
+ phi = (torch.rand((n, 1),device=device) - 0.5) * 2 * vertical_stddev + vertical_mean
129
+
130
+ elif mode == 'normal' or mode == 'gaussian':
131
+ theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
132
+ phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
133
+
134
+ elif mode == 'hybrid':
135
+ if random.random() < 0.5:
136
+ theta = (torch.rand((n, 1),device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean
137
+ phi = (torch.rand((n, 1),device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean
138
+ else:
139
+ theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
140
+ phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
141
+ else:
142
+ phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean
143
+ theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean
144
+
145
+ phi = torch.clamp(phi, 1e-5, math.pi - 1e-5)
146
+
147
+ output_points = torch.zeros((n, 3), device=device)
148
+
149
+ output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta)
150
+ output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta)
151
+ output_points[:, 1:2] = r*torch.cos(phi)
152
+
153
+ return output_points, phi, theta
154
+
155
+
156
+ def perturb_points(points, z_vals, ray_directions, device):
157
+ distance_between_points = z_vals[:,:,1:2,:] - z_vals[:,:,0:1,:]
158
+ offset = (torch.rand(z_vals.shape, device=device)-0.5) * distance_between_points
159
+ z_vals = z_vals + offset
160
+ points = points + offset * ray_directions.unsqueeze(2)
161
+ return points, z_vals
162
+
163
+
164
+ def create_cam2world_matrix(forward_vector, origin, device=None):
165
+ """Takes in the direction the camera is pointing and the camera origin and returns a world2cam matrix."""
166
+
167
+ forward_vector = normalize_vecs(forward_vector)
168
+ up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector)
169
+ left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
170
+ up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1))
171
+
172
+ rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
173
+ rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1)
174
+
175
+ translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
176
+ translation_matrix[:, :3, 3] = origin
177
+
178
+ cam2world = translation_matrix @ rotation_matrix
179
+
180
+ return cam2world
181
+
182
+
183
+ def transform_sampled_points(
184
+ points, z_vals, ray_directions, device,
185
+ h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5,
186
+ v_mean=math.pi * 0.5, mode='normal'):
187
+ """
188
+ points: batch_size x total_pixels x num_steps x 3
189
+ z_vals: batch_size x total_pixels x num_steps
190
+ """
191
+ n, num_rays, num_steps, channels = points.shape
192
+ points, z_vals = perturb_points(points, z_vals, ray_directions, device)
193
+ camera_origin, pitch, yaw = sample_camera_positions(
194
+ n=points.shape[0], r=1,
195
+ horizontal_stddev=h_stddev, vertical_stddev=v_stddev,
196
+ horizontal_mean=h_mean, vertical_mean=v_mean,
197
+ device=device, mode=mode)
198
+ forward_vector = normalize_vecs(-camera_origin)
199
+ cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)
200
+
201
+ points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device)
202
+ points_homogeneous[:, :, :, :3] = points
203
+
204
+ # should be n x 4 x 4 , n x r^2 x num_steps x 4
205
+ transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, num_steps, 4)
206
+ transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, 3)
207
+
208
+ homogeneous_origins = torch.zeros((n, 4, num_rays), device=device)
209
+ homogeneous_origins[:, 3, :] = 1
210
+
211
+ transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3]
212
+ return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
213
+
214
+
215
+ def integration(
216
+ rgb_sigma, z_vals, device, noise_std=0.5,
217
+ last_back=False, white_back=False, clamp_mode=None, fill_mode=None):
218
+
219
+ rgbs = rgb_sigma[..., :3]
220
+ sigmas = rgb_sigma[..., 3:]
221
+
222
+ deltas = z_vals[..., 1:, :] - z_vals[..., :-1, :]
223
+ delta_inf = 1e10 * torch.ones_like(deltas[..., :1, :])
224
+ deltas = torch.cat([deltas, delta_inf], -2)
225
+
226
+ if noise_std > 0:
227
+ noise = torch.randn(sigmas.shape, device=device) * noise_std
228
+ else:
229
+ noise = 0
230
+
231
+ if clamp_mode == 'softplus':
232
+ alphas = 1 - torch.exp(-deltas * (F.softplus(sigmas + noise)))
233
+ elif clamp_mode == 'relu':
234
+ alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise)))
235
+ else:
236
+ raise "Need to choose clamp mode"
237
+
238
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1, :]), 1-alphas + 1e-10], -2)
239
+ weights = alphas * torch.cumprod(alphas_shifted, -2)[..., :-1, :]
240
+ weights_sum = weights.sum(-2)
241
+
242
+ if last_back:
243
+ weights[..., -1, :] += (1 - weights_sum)
244
+
245
+ rgb_final = torch.sum(weights * rgbs, -2)
246
+ depth_final = torch.sum(weights * z_vals, -2)
247
+
248
+ if white_back:
249
+ rgb_final = rgb_final + 1-weights_sum
250
+
251
+ if fill_mode == 'debug':
252
+ rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device)
253
+ elif fill_mode == 'weight':
254
+ rgb_final = weights_sum.expand_as(rgb_final)
255
+
256
+ return rgb_final, depth_final, weights
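A small smoke test for the volume-rendering integration above, assuming the import path `dnnlib.geometry`. The shapes are the main thing to get right: `rgb_sigma` carries RGB plus raw density per depth sample, and `z_vals` keeps a trailing singleton dimension.

```python
import torch
from dnnlib.geometry import integration  # assumed import path

B, n_rays, n_steps = 1, 64, 12
rgb_sigma = torch.rand(B, n_rays, n_steps, 4)  # RGB + raw density per depth sample
z_vals = torch.linspace(0.8, 1.2, n_steps).reshape(1, 1, n_steps, 1).expand(B, n_rays, n_steps, 1)
rgb, depth, weights = integration(rgb_sigma, z_vals, device='cpu', noise_std=0.0, clamp_mode='relu')
print(rgb.shape, depth.shape, weights.shape)
# torch.Size([1, 64, 3]) torch.Size([1, 64, 1]) torch.Size([1, 64, 12, 1])
```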
257
+
258
+
259
+ def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
260
+ # return numpy array of forwarded sigma value
261
+ bound = (nerf.depth_range[1] - nerf.depth_range[0]) * 0.5
262
+ X = torch.linspace(-bound, bound, resolution).split(block_resolution)
263
+
264
+ sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)
265
+
266
+ for xi, xs in enumerate(X):
267
+ for yi, ys in enumerate(X):
268
+ for zi, zs in enumerate(X):
269
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
270
+ pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device) # B, H, H, H, C
271
+ block_shape = [1, len(xs), len(ys), len(zs)]
272
+ feat_out, sigma_out = nerf.fg_nerf.forward_style2(pts, None, block_shape, ws=styles)
273
+ sigma_np[xi * block_resolution: xi * block_resolution + len(xs), \
274
+ yi * block_resolution: yi * block_resolution + len(ys), \
275
+ zi * block_resolution: zi * block_resolution + len(zs)] = sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()
276
+
277
+ return sigma_np, bound
278
+
279
+
280
+ def extract_geometry(nerf, styles, resolution, threshold):
281
+ import mcubes
282
+
283
+ print('threshold: {}'.format(threshold))
284
+ u, bound = get_sigma_field_np(nerf, styles, resolution)
285
+ vertices, triangles = mcubes.marching_cubes(u, threshold)
286
+ b_min_np = np.array([-bound, -bound, -bound])
287
+ b_max_np = np.array([ bound, bound, bound])
288
+
289
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
290
+ return vertices.astype('float32'), triangles
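`extract_geometry` depends on PyMCubes and a trained NeRF; the marching-cubes step itself can be sketched on a synthetic density field (self-contained, no network required):

```python
import numpy as np
import mcubes  # PyMCubes, the same dependency used by extract_geometry

# Synthetic density: a solid ball of radius 0.3 inside a [-0.5, 0.5]^3 grid.
res, bound = 64, 0.5
grid = np.linspace(-bound, bound, res)
x, y, z = np.meshgrid(grid, grid, grid, indexing='ij')
sigma = (np.sqrt(x ** 2 + y ** 2 + z ** 2) < 0.3).astype(np.float32)

vertices, triangles = mcubes.marching_cubes(sigma, 0.5)
vertices = vertices / (res - 1.0) * 2 * bound - bound  # voxel indices -> world units, as above
print(vertices.shape, triangles.shape)                 # roughly spherical mesh
```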
291
+
292
+
293
+ def render_mesh(meshes, camera_matrices, render_noise=True):
294
+ from pytorch3d.renderer import (
295
+ FoVPerspectiveCameras, look_at_view_transform,
296
+ RasterizationSettings, BlendParams,
297
+ MeshRenderer, MeshRasterizer, HardPhongShader, TexturesVertex
298
+ )
299
+ from pytorch3d.ops import interpolate_face_attributes
300
+ from pytorch3d.structures.meshes import Meshes
301
+
302
+ intrinsics, poses, _, _ = camera_matrices
303
+ device = poses.device
304
+ c2w = torch.matmul(poses, torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0], device=device))[None, :, :]) # Different camera model...
305
+ w2c = torch.inverse(c2w)
306
+ R = c2w[:, :3, :3]
307
+ T = w2c[:, :3, 3] # Note: R is taken from cam2world while T is taken from world2cam, matching PyTorch3D's camera convention
308
+ focal = intrinsics[0, 0, 0]
309
+ fov = torch.arctan(focal) * 2.0 / np.pi * 180
310
+
311
+
312
+ colors = []
313
+ offset = 1
314
+ for res, (mesh, face_vert_noise) in meshes.items():
315
+ raster_settings = RasterizationSettings(
316
+ image_size=res,
317
+ blur_radius=0.0,
318
+ faces_per_pixel=1,
319
+ )
320
+ mesh = Meshes(
321
+ verts=[torch.from_numpy(mesh.vertices).float().to(device)],
322
+ faces=[torch.from_numpy(mesh.faces).long().to(device)])
323
+
324
+ _colors = []
325
+ for i in range(len(poses)):
326
+ cameras = FoVPerspectiveCameras(device=device, R=R[i: i+1], T=T[i: i+1], fov=fov)
327
+ rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
328
+ pix_to_face, zbuf, bary_coord, dists = rasterizer(mesh)
329
+ color = interpolate_face_attributes(pix_to_face, bary_coord, face_vert_noise).squeeze()
330
+
331
+ # hack
332
+ color[offset:, offset:] = color[:-offset, :-offset]
333
+ _colors += [color]
334
+ color = torch.stack(_colors, 0).permute(0,3,1,2)
335
+ colors += [color]
336
+ offset *= 2
337
+ return colors
338
+
339
+
340
+ def rotate_vects(v, theta):
341
+ theta = theta / math.pi * 2
342
+ theta = theta + (theta < 0).type_as(theta) * 4
343
+ v = v.reshape(v.size(0), v.size(1) // 4, 4, v.size(2), v.size(3))
344
+ vs = []
345
+ order = [0,2,3,1] # Not working
346
+ iorder = [0,3,1,2] # Not working
347
+ for b in range(len(v)):
348
+ if (theta[b] - 0) < 1e-6:
349
+ u, l = 0, 0
350
+ elif (theta[b] - 1) < 1e-6:
351
+ u, l = 0, 1
352
+ elif (theta[b] - 2) < 1e-6:
353
+ u, l = 0, 2
354
+ elif (theta[b] - 3) < 1e-6:
355
+ u, l = 0, 3
356
+ else:
357
+ u, l = math.modf(theta[b])
358
+ l, r = int(l), int(l + 1) % 4
359
+ vv = v[b, :, order] # 0 -> 1 -> 3 -> 2
360
+ vl = torch.cat([vv[:, l:], vv[:, :l]], 1)
361
+ if u > 0:
362
+ vr = torch.cat([vv[:, r:], vv[:, :r]], 1)
363
+ vv = vl * (1-u) + vr * u
364
+ else:
365
+ vv = vl
366
+ vs.append(vv[:, iorder])
367
+ v = torch.stack(vs, 0)
368
+ v = v.reshape(v.size(0), -1, v.size(-2), v.size(-1))
369
+ return v
370
+
371
+
372
+ def generate_option_outputs(render_option):
373
+ # output debugging outputs (not used in normal rendering process)
374
+ if ('depth' in render_option.split(',')):
375
+ img = camera_world[:, :1] + fg_depth_map * ray_vector
376
+ img = reformat(img, tgt_res)
377
+
378
+ if 'gradient' in render_option.split(','):
379
+ points = (camera_world[:,:,None]+di[:,:,:,None]*ray_vector[:,:,None]).reshape(
380
+ batch_size, tgt_res, tgt_res, di.size(-1), 3)
381
+ with torch.enable_grad():
382
+ gradients = self.fg_nerf.forward_style2(
383
+ points, None, [batch_size, tgt_res, di.size(-1), tgt_res], get_normal=True,
384
+ ws=styles, z_shape=z_shape_obj, z_app=z_app_obj).reshape(
385
+ batch_size, di.size(-1), 3, tgt_res * tgt_res).permute(0,3,1,2)
386
+ avg_grads = (gradients * fg_weights.unsqueeze(-1)).sum(-2)
387
+ normal = reformat(normalize(avg_grads, axis=2)[0], tgt_res)
388
+ img = normal
389
+
390
+ if 'value' in render_option.split(','):
391
+ fg_feat = fg_feat[:,:,3:].norm(dim=-1,keepdim=True)
392
+ img = reformat(fg_feat.repeat(1,1,3), tgt_res) / fg_feat.max() * 2 - 1
393
+
394
+ if 'opacity' in render_option.split(','):
395
+ opacity = bg_lambda.unsqueeze(-1).repeat(1,1,3) * 2 - 1
396
+ img = reformat(opacity, tgt_res)
397
+
398
+ if 'normal' in render_option.split(','):
399
+ shift_l, shift_r = img[:,:,2:,:], img[:,:,:-2,:]
400
+ shift_u, shift_d = img[:,:,:,2:], img[:,:,:,:-2]
401
+ diff_hor = normalize(shift_r - shift_l, axis=1)[0][:, :, :, 1:-1]
402
+ diff_ver = normalize(shift_u - shift_d, axis=1)[0][:, :, 1:-1, :]
403
+ normal = torch.cross(diff_hor, diff_ver, dim=1)
404
+ img = normalize(normal, axis=1)[0]
405
+
406
+ return {'full_out': (None, img), 'reg_loss': {}}
dnnlib/util.py ADDED
@@ -0,0 +1,531 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Miscellaneous utility classes and functions."""
12
+
13
+ import ctypes
14
+ import fnmatch
15
+ import importlib
16
+ import inspect
17
+ import numpy as np
18
+ import os
19
+ import shutil
20
+ import sys
21
+ import types
22
+ import io
23
+ import pickle
24
+ import re
25
+ import requests
26
+ import html
27
+ import hashlib
28
+ import glob
29
+ import tempfile
30
+ import urllib
31
+ import urllib.request
32
+ import uuid
33
+ import torch
34
+
35
+ from distutils.util import strtobool
36
+ from typing import Any, List, Tuple, Union
37
+
38
+
39
+ # Util classes
40
+ # ------------------------------------------------------------------------------------------
41
+
42
+
43
+ class EasyDict(dict):
44
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
45
+
46
+ def __getattr__(self, name: str) -> Any:
47
+ try:
48
+ return self[name]
49
+ except KeyError:
50
+ raise AttributeError(name)
51
+
52
+ def __setattr__(self, name: str, value: Any) -> None:
53
+ self[name] = value
54
+
55
+ def __delattr__(self, name: str) -> None:
56
+ del self[name]
57
+
58
+
59
+ class Logger(object):
60
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
61
+
62
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
63
+ self.file = None
64
+
65
+ if file_name is not None:
66
+ self.file = open(file_name, file_mode)
67
+
68
+ self.should_flush = should_flush
69
+ self.stdout = sys.stdout
70
+ self.stderr = sys.stderr
71
+
72
+ sys.stdout = self
73
+ sys.stderr = self
74
+
75
+ def __enter__(self) -> "Logger":
76
+ return self
77
+
78
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
79
+ self.close()
80
+
81
+ def write(self, text: Union[str, bytes]) -> None:
82
+ """Write text to stdout (and a file) and optionally flush."""
83
+ if isinstance(text, bytes):
84
+ text = text.decode()
85
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
86
+ return
87
+
88
+ if self.file is not None:
89
+ self.file.write(text)
90
+
91
+ self.stdout.write(text)
92
+
93
+ if self.should_flush:
94
+ self.flush()
95
+
96
+ def flush(self) -> None:
97
+ """Flush written text to both stdout and a file, if open."""
98
+ if self.file is not None:
99
+ self.file.flush()
100
+
101
+ self.stdout.flush()
102
+
103
+ def close(self) -> None:
104
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
105
+ self.flush()
106
+
107
+ # if using multiple loggers, prevent closing in wrong order
108
+ if sys.stdout is self:
109
+ sys.stdout = self.stdout
110
+ if sys.stderr is self:
111
+ sys.stderr = self.stderr
112
+
113
+ if self.file is not None:
114
+ self.file.close()
115
+ self.file = None
116
+
117
+
118
+ # Cache directories
119
+ # ------------------------------------------------------------------------------------------
120
+
121
+ _dnnlib_cache_dir = None
122
+
123
+ def set_cache_dir(path: str) -> None:
124
+ global _dnnlib_cache_dir
125
+ _dnnlib_cache_dir = path
126
+
127
+ def make_cache_dir_path(*paths: str) -> str:
128
+ if _dnnlib_cache_dir is not None:
129
+ return os.path.join(_dnnlib_cache_dir, *paths)
130
+ if 'DNNLIB_CACHE_DIR' in os.environ:
131
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
132
+ if 'HOME' in os.environ:
133
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
134
+ if 'USERPROFILE' in os.environ:
135
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
136
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
137
+
138
+ # Small util functions
139
+ # ------------------------------------------------------------------------------------------
140
+
141
+
142
+ def format_time(seconds: Union[int, float]) -> str:
143
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
144
+ s = int(np.rint(seconds))
145
+
146
+ if s < 60:
147
+ return "{0}s".format(s)
148
+ elif s < 60 * 60:
149
+ return "{0}m {1:02}s".format(s // 60, s % 60)
150
+ elif s < 24 * 60 * 60:
151
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
152
+ else:
153
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
154
+
155
+
156
+ def ask_yes_no(question: str) -> bool:
157
+ """Ask the user the question until the user inputs a valid answer."""
158
+ while True:
159
+ try:
160
+ print("{0} [y/n]".format(question))
161
+ return strtobool(input().lower())
162
+ except ValueError:
163
+ pass
164
+
165
+
166
+ def tuple_product(t: Tuple) -> Any:
167
+ """Calculate the product of the tuple elements."""
168
+ result = 1
169
+
170
+ for v in t:
171
+ result *= v
172
+
173
+ return result
174
+
175
+
176
+ _str_to_ctype = {
177
+ "uint8": ctypes.c_ubyte,
178
+ "uint16": ctypes.c_uint16,
179
+ "uint32": ctypes.c_uint32,
180
+ "uint64": ctypes.c_uint64,
181
+ "int8": ctypes.c_byte,
182
+ "int16": ctypes.c_int16,
183
+ "int32": ctypes.c_int32,
184
+ "int64": ctypes.c_int64,
185
+ "float32": ctypes.c_float,
186
+ "float64": ctypes.c_double
187
+ }
188
+
189
+
190
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
191
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
192
+ type_str = None
193
+
194
+ if isinstance(type_obj, str):
195
+ type_str = type_obj
196
+ elif hasattr(type_obj, "__name__"):
197
+ type_str = type_obj.__name__
198
+ elif hasattr(type_obj, "name"):
199
+ type_str = type_obj.name
200
+ else:
201
+ raise RuntimeError("Cannot infer type name from input")
202
+
203
+ assert type_str in _str_to_ctype.keys()
204
+
205
+ my_dtype = np.dtype(type_str)
206
+ my_ctype = _str_to_ctype[type_str]
207
+
208
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
209
+
210
+ return my_dtype, my_ctype
211
+
212
+
213
+ def is_pickleable(obj: Any) -> bool:
214
+ try:
215
+ with io.BytesIO() as stream:
216
+ pickle.dump(obj, stream)
217
+ return True
218
+ except:
219
+ return False
220
+
221
+
222
+ # Functionality to import modules/objects by name, and call functions by name
223
+ # ------------------------------------------------------------------------------------------
224
+
225
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
226
+ """Searches for the underlying module behind the name to some python object.
227
+ Returns the module and the object name (original name with module part removed)."""
228
+
229
+ # allow convenience shorthands, substitute them by full names
230
+ obj_name = re.sub("^np.", "numpy.", obj_name)
231
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
232
+
233
+ # list alternatives for (module_name, local_obj_name)
234
+ parts = obj_name.split(".")
235
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
236
+
237
+ # try each alternative in turn
238
+ for module_name, local_obj_name in name_pairs:
239
+ try:
240
+ module = importlib.import_module(module_name) # may raise ImportError
241
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
242
+ return module, local_obj_name
243
+ except:
244
+ pass
245
+
246
+ # maybe some of the modules themselves contain errors?
247
+ for module_name, _local_obj_name in name_pairs:
248
+ try:
249
+ importlib.import_module(module_name) # may raise ImportError
250
+ except ImportError:
251
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
252
+ raise
253
+
254
+ # maybe the requested attribute is missing?
255
+ for module_name, local_obj_name in name_pairs:
256
+ try:
257
+ module = importlib.import_module(module_name) # may raise ImportError
258
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
259
+ except ImportError:
260
+ pass
261
+
262
+ # we are out of luck, but we have no idea why
263
+ raise ImportError(obj_name)
264
+
265
+
266
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
267
+ """Traverses the object name and returns the last (rightmost) python object."""
268
+ if obj_name == '':
269
+ return module
270
+ obj = module
271
+ for part in obj_name.split("."):
272
+ obj = getattr(obj, part)
273
+ return obj
274
+
275
+
276
+ def get_obj_by_name(name: str) -> Any:
277
+ """Finds the python object with the given name."""
278
+ module, obj_name = get_module_from_obj_name(name)
279
+ return get_obj_from_module(module, obj_name)
280
+
281
+
282
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
283
+ """Finds the python object with the given name and calls it as a function."""
284
+ assert func_name is not None
285
+ func_obj = get_obj_by_name(func_name)
286
+ assert callable(func_obj)
287
+ return func_obj(*args, **kwargs)
288
+
289
+
290
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
291
+ """Finds the python class with the given name and constructs it with the given arguments."""
292
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
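`construct_class_by_name` resolves a fully-qualified name and instantiates it, which is how string-valued config entries become objects. A minimal sketch; the optimizer choice here is purely illustrative and the import path is assumed to be `dnnlib.util`:

```python
import torch
from dnnlib.util import construct_class_by_name  # assumed import path

w = torch.nn.Parameter(torch.zeros(1))
opt = construct_class_by_name(class_name='torch.optim.Adam', params=[w], lr=1e-3)
print(type(opt).__name__)  # Adam
```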
293
+
294
+
295
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
296
+ """Get the directory path of the module containing the given object name."""
297
+ module, _ = get_module_from_obj_name(obj_name)
298
+ return os.path.dirname(inspect.getfile(module))
299
+
300
+
301
+ def is_top_level_function(obj: Any) -> bool:
302
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
303
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
304
+
305
+
306
+ def get_top_level_function_name(obj: Any) -> str:
307
+ """Return the fully-qualified name of a top-level function."""
308
+ assert is_top_level_function(obj)
309
+ module = obj.__module__
310
+ if module == '__main__':
311
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
312
+ return module + "." + obj.__name__
313
+
314
+
315
+ # File system helpers
316
+ # ------------------------------------------------------------------------------------------
317
+
318
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
319
+ """List all files recursively in a given directory while ignoring given file and directory names.
320
+ Returns list of tuples containing both absolute and relative paths."""
321
+ assert os.path.isdir(dir_path)
322
+ base_name = os.path.basename(os.path.normpath(dir_path))
323
+
324
+ if ignores is None:
325
+ ignores = []
326
+
327
+ result = []
328
+
329
+ for root, dirs, files in os.walk(dir_path, topdown=True):
330
+ for ignore_ in ignores:
331
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
332
+
333
+ # dirs need to be edited in-place
334
+ for d in dirs_to_remove:
335
+ dirs.remove(d)
336
+
337
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
338
+
339
+ absolute_paths = [os.path.join(root, f) for f in files]
340
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
341
+
342
+ if add_base_to_relative:
343
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
344
+
345
+ assert len(absolute_paths) == len(relative_paths)
346
+ result += zip(absolute_paths, relative_paths)
347
+
348
+ return result
349
+
350
+
351
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
352
+ """Takes in a list of tuples of (src, dst) paths and copies files.
353
+ Will create all necessary directories."""
354
+ for file in files:
355
+ target_dir_name = os.path.dirname(file[1])
356
+
357
+ # will create all intermediate-level directories
358
+ if not os.path.exists(target_dir_name):
359
+ os.makedirs(target_dir_name)
360
+
361
+ shutil.copyfile(file[0], file[1])
362
+
363
+
364
+ # URL helpers
365
+ # ------------------------------------------------------------------------------------------
366
+
367
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
368
+ """Determine whether the given object is a valid URL string."""
369
+ if not isinstance(obj, str) or not "://" in obj:
370
+ return False
371
+ if allow_file_urls and obj.startswith('file://'):
372
+ return True
373
+ try:
374
+ res = requests.compat.urlparse(obj)
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
378
+ if not res.scheme or not res.netloc or not "." in res.netloc:
379
+ return False
380
+ except:
381
+ return False
382
+ return True
383
+
384
+
385
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
386
+ """Download the given URL and return a binary-mode file object to access the data."""
387
+ assert num_attempts >= 1
388
+ assert not (return_filename and (not cache))
389
+
390
+ # Doesn't look like an URL scheme so interpret it as a local filename.
391
+ if not re.match('^[a-z]+://', url):
392
+ return url if return_filename else open(url, "rb")
393
+
394
+ # Handle file URLs. This code handles unusual file:// patterns that
395
+ # arise on Windows:
396
+ #
397
+ # file:///c:/foo.txt
398
+ #
399
+ # which would translate to a local '/c:/foo.txt' filename that's
400
+ # invalid. Drop the forward slash for such pathnames.
401
+ #
402
+ # If you touch this code path, you should test it on both Linux and
403
+ # Windows.
404
+ #
405
+ # Some internet resources suggest using urllib.request.url2pathname() but
406
+ # but that converts forward slashes to backslashes and this causes
407
+ # its own set of problems.
408
+ if url.startswith('file://'):
409
+ filename = urllib.parse.urlparse(url).path
410
+ if re.match(r'^/[a-zA-Z]:', filename):
411
+ filename = filename[1:]
412
+ return filename if return_filename else open(filename, "rb")
413
+
414
+ assert is_url(url)
415
+
416
+ # Lookup from cache.
417
+ if cache_dir is None:
418
+ cache_dir = make_cache_dir_path('downloads')
419
+
420
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
421
+ if cache:
422
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
423
+ if len(cache_files) == 1:
424
+ filename = cache_files[0]
425
+ return filename if return_filename else open(filename, "rb")
426
+
427
+ # Download.
428
+ url_name = None
429
+ url_data = None
430
+ with requests.Session() as session:
431
+ if verbose:
432
+ print("Downloading %s ..." % url, end="", flush=True)
433
+ for attempts_left in reversed(range(num_attempts)):
434
+ try:
435
+ with session.get(url) as res:
436
+ res.raise_for_status()
437
+ if len(res.content) == 0:
438
+ raise IOError("No data received")
439
+
440
+ if len(res.content) < 8192:
441
+ content_str = res.content.decode("utf-8")
442
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
443
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
444
+ if len(links) == 1:
445
+ url = requests.compat.urljoin(url, links[0])
446
+ raise IOError("Google Drive virus checker nag")
447
+ if "Google Drive - Quota exceeded" in content_str:
448
+ raise IOError("Google Drive download quota exceeded -- please try again later")
449
+
450
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
451
+ url_name = match[1] if match else url
452
+ url_data = res.content
453
+ if verbose:
454
+ print(" done")
455
+ break
456
+ except KeyboardInterrupt:
457
+ raise
458
+ except:
459
+ if not attempts_left:
460
+ if verbose:
461
+ print(" failed")
462
+ raise
463
+ if verbose:
464
+ print(".", end="", flush=True)
465
+
466
+ # Save to cache.
467
+ if cache:
468
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
469
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
470
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
471
+ os.makedirs(cache_dir, exist_ok=True)
472
+ with open(temp_file, "wb") as f:
473
+ f.write(url_data)
474
+ os.replace(temp_file, cache_file) # atomic
475
+ if return_filename:
476
+ return cache_file
477
+
478
+ # Return data as file object.
479
+ assert not return_filename
480
+ return io.BytesIO(url_data)
481
+
482
+
483
+ def dividable(n, k=2):
484
+ if k == 2:
485
+ for i in range(int(np.sqrt(n)), 0, -1):
486
+ if n % i == 0:
487
+ break
488
+ return i, n // i
489
+ elif k == 3:
490
+ for i in range(int(float(n) ** (1/3)), 0, -1):
491
+ if n % i == 0:
492
+ b, c = dividable(n // i, 2)
493
+ return i, b, c
494
+ else:
495
+ raise NotImplementedError
496
+
497
+
498
+ def visualize_feature_map(x, scale=1.0, mask=None, loc=None):
499
+ B, C, H, W = x.size()
500
+ lh, lw = dividable(C)
501
+ x = x.reshape(B, lh, lw, H, W).permute(0,1,3,2,4)
502
+ # loc = [(3,1), (6,3), (4,0)]
503
+ # loc = [(4,0), (0,7), (6,2)]
504
+ loc = [(3, 11), (5,3), (3,9)]  # hard-coded channel locations for debugging; overrides the loc argument
505
+ # loc = [(1,3), (5,3), (7,4)]
506
+ # loc = [(0,5), (10,0), (3,14)]
507
+ if loc is None:
508
+ x = x.reshape(B, 1, lh*H, lw*W).repeat(1,3,1,1)
509
+ else:
510
+ x = [x[:, l[0], :, l[1]] for l in loc]
511
+ x = torch.stack(x, 1)
512
+ x = x / x.norm(dim=1, keepdim=True)
513
+ x = x / scale
514
+ return x
515
+
516
+
517
+ def hash_func(x, res, T):
518
+ d = x.size(-1)
519
+ assert d <= 3
520
+
521
+ h = x[..., 0]
522
+ if res ** d < T:
523
+ f = [1, res, res * res]
524
+ for i in range(1, d):
525
+ h += x[..., i] * f[i]
526
+ else:
527
+ f = [1, 19349663, 83492791]
528
+ for i in range(1, d):
529
+ h = torch.bitwise_xor(h, x[..., i] * f[i])
530
+ h = h % T
531
+ return h
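`hash_func` is a spatial hash over integer grid coordinates: direct indexing when the dense grid fits in the table, XOR with large primes otherwise. A quick check, assuming the import path `dnnlib.util`:

```python
import torch
from dnnlib.util import hash_func  # assumed import path

coords = torch.randint(0, 128, (1024, 3))    # integer voxel coordinates
idx = hash_func(coords, res=128, T=2 ** 14)  # 128**3 > 2**14, so the XOR branch is used
print(idx.shape, int(idx.min()) >= 0, int(idx.max()) < 2 ** 14)  # torch.Size([1024]) True True
```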
generate.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Generate images using pretrained network pickle."""
12
+
13
+ import os
14
+ import re
15
+ import time
16
+ import glob
17
+ from typing import List, Optional
18
+
19
+ import click
20
+ import dnnlib
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import imageio
25
+ import legacy
26
+ from renderer import Renderer
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ def num_range(s: str) -> List[int]:
31
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
32
+
33
+ range_re = re.compile(r'^(\d+)-(\d+)$')
34
+ m = range_re.match(s)
35
+ if m:
36
+ return list(range(int(m.group(1)), int(m.group(2))+1))
37
+ vals = s.split(',')
38
+ return [int(x) for x in vals]
39
+
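A quick behaviour check for the two accepted `--seeds` forms (relies on `num_range` defined just above):

```python
# num_range() accepts an inclusive range or an explicit comma-separated list.
assert num_range('0-3') == [0, 1, 2, 3]
assert num_range('1,3,12') == [1, 3, 12]
```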
40
+ #----------------------------------------------------------------------------
41
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
42
+
43
+ @click.command()
44
+ @click.pass_context
45
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
46
+ @click.option('--seeds', type=num_range, help='List of random seeds')
47
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
48
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
49
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
50
+ @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
51
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
52
+ @click.option('--render-program', default=None, show_default=True)
53
+ @click.option('--render-option', default=None, type=str, help="e.g. up_256, camera, depth")
54
+ @click.option('--n_steps', default=8, type=int, help="number of steps for each seed")
55
+ @click.option('--no-video', default=False)
56
+ @click.option('--relative_range_u_scale', default=1.0, type=float, help="relative scale on top of the original range u")
57
+ def generate_images(
58
+ ctx: click.Context,
59
+ network_pkl: str,
60
+ seeds: Optional[List[int]],
61
+ truncation_psi: float,
62
+ noise_mode: str,
63
+ outdir: str,
64
+ class_idx: Optional[int],
65
+ projected_w: Optional[str],
66
+ render_program=None,
67
+ render_option=None,
68
+ n_steps=8,
69
+ no_video=False,
70
+ relative_range_u_scale=1.0
71
+ ):
72
+
73
+
74
+ device = torch.device('cuda')
75
+ if os.path.isdir(network_pkl):
76
+ network_pkl = sorted(glob.glob(network_pkl + '/*.pkl'))[-1]
77
+ print('Loading networks from "%s"...' % network_pkl)
78
+
79
+ with dnnlib.util.open_url(network_pkl) as f:
80
+ network = legacy.load_network_pkl(f)
81
+ G = network['G_ema'].to(device) # type: ignore
82
+ D = network['D'].to(device)
83
+ # from fairseq import pdb;pdb.set_trace()
84
+ os.makedirs(outdir, exist_ok=True)
85
+
86
+ # Labels.
87
+ label = torch.zeros([1, G.c_dim], device=device)
88
+ if G.c_dim != 0:
89
+ if class_idx is None:
90
+ ctx.fail('Must specify class label with --class when using a conditional network')
91
+ label[:, class_idx] = 1
92
+ else:
93
+ if class_idx is not None:
94
+ print('warn: --class=lbl ignored when running on an unconditional network')
95
+
96
+ # avoid persistent classes...
97
+ from training.networks import Generator
98
+ # from training.stylenerf import Discriminator
99
+ from torch_utils import misc
100
+ with torch.no_grad():
101
+ G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
102
+ misc.copy_params_and_buffers(G, G2, require_all=False)
103
+ # D2 = Discriminator(*D.init_args, **D.init_kwargs).to(device)
104
+ # misc.copy_params_and_buffers(D, D2, require_all=False)
105
+ G2 = Renderer(G2, D, program=render_program)
106
+
107
+ # Generate images.
108
+ all_imgs = []
109
+
110
+ def stack_imgs(imgs):
111
+ img = torch.stack(imgs, dim=2)
112
+ return img.reshape(img.size(0) * img.size(1), img.size(2) * img.size(3), 3)
113
+
114
+ def proc_img(img):
115
+ return (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu()
116
+
117
+ if projected_w is not None:
118
+ ws = np.load(projected_w)
119
+ ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
120
+ img = G2(styles=ws, truncation_psi=truncation_psi, noise_mode=noise_mode, render_option=render_option)
121
+ assert isinstance(img, List)
122
+ imgs = [proc_img(i) for i in img]
123
+ all_imgs += [imgs]
124
+
125
+ else:
126
+ for seed_idx, seed in enumerate(seeds):
127
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
128
+ G2.set_random_seed(seed)
129
+ z = torch.from_numpy(np.random.RandomState(seed).randn(2, G.z_dim)).to(device)
130
+ relative_range_u = [0.5 - 0.5 * relative_range_u_scale, 0.5 + 0.5 * relative_range_u_scale]
131
+ outputs = G2(
132
+ z=z,
133
+ c=label,
134
+ truncation_psi=truncation_psi,
135
+ noise_mode=noise_mode,
136
+ render_option=render_option,
137
+ n_steps=n_steps,
138
+ relative_range_u=relative_range_u,
139
+ return_cameras=True)
140
+ if isinstance(outputs, tuple):
141
+ img, cameras = outputs
142
+ else:
143
+ img = outputs
144
+
145
+ if isinstance(img, List):
146
+ imgs = [proc_img(i) for i in img]
147
+ if not no_video:
148
+ all_imgs += [imgs]
149
+
150
+ curr_out_dir = os.path.join(outdir, 'seed_{:0>6d}'.format(seed))
151
+ os.makedirs(curr_out_dir, exist_ok=True)
152
+
153
+ if (render_option is not None) and ("gen_ibrnet_metadata" in render_option):
154
+ intrinsics = []
155
+ poses = []
156
+ _, H, W, _ = imgs[0].shape
157
+ for i, camera in enumerate(cameras):
158
+ intri, pose, _, _ = camera
159
+ focal = (H - 1) * 0.5 / intri[0, 0, 0].item()
160
+ intri = np.diag([focal, focal, 1.0, 1.0]).astype(np.float32)
161
+ intri[0, 2], intri[1, 2] = (W - 1) * 0.5, (H - 1) * 0.5
162
+
163
+ pose = pose.squeeze().detach().cpu().numpy() @ np.diag([1, -1, -1, 1]).astype(np.float32)
164
+ intrinsics.append(intri)
165
+ poses.append(pose)
166
+
167
+ intrinsics = np.stack(intrinsics, axis=0)
168
+ poses = np.stack(poses, axis=0)
169
+
170
+ np.savez(os.path.join(curr_out_dir, 'cameras.npz'), intrinsics=intrinsics, poses=poses)
171
+ with open(os.path.join(curr_out_dir, 'meta.conf'), 'w') as f:
172
+ f.write('depth_range = {}\ntest_hold_out = {}\nheight = {}\nwidth = {}'.
173
+ format(G2.generator.synthesis.depth_range, 2, H, W))
174
+
175
+ img_dir = os.path.join(curr_out_dir, 'images_raw')
176
+ os.makedirs(img_dir, exist_ok=True)
177
+ for step, img in enumerate(imgs):
178
+ PIL.Image.fromarray(img[0].detach().cpu().numpy(), 'RGB').save(f'{img_dir}/{step:03d}.png')
179
+
180
+ else:
181
+ img = proc_img(img)[0]
182
+ PIL.Image.fromarray(img.numpy(), 'RGB').save(f'{outdir}/seed_{seed:0>6d}.png')
183
+
184
+ if len(all_imgs) > 0 and (not no_video):
185
+ # write to video
186
+ timestamp = time.strftime('%Y%m%d.%H%M%S',time.localtime(time.time()))
187
+ seeds = ','.join([str(s) for s in seeds]) if seeds is not None else 'projected'
188
+ network_pkl = network_pkl.split('/')[-1].split('.')[0]
189
+ all_imgs = [stack_imgs([a[k] for a in all_imgs]).numpy() for k in range(len(all_imgs[0]))]
190
+ imageio.mimwrite(f'{outdir}/{network_pkl}_{timestamp}_{seeds}.mp4', all_imgs, fps=30, quality=8)
191
+ outdir = f'{outdir}/{network_pkl}_{timestamp}_{seeds}'
192
+ os.makedirs(outdir, exist_ok=True)
193
+ for step, img in enumerate(all_imgs):
194
+ PIL.Image.fromarray(img, 'RGB').save(f'{outdir}/{step:04d}.png')
195
+
196
+
197
+ #----------------------------------------------------------------------------
198
+
199
+ if __name__ == "__main__":
200
+ generate_images() # pylint: disable=no-value-for-parameter
201
+
202
+ #----------------------------------------------------------------------------
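For orientation, the same loading path can also be driven programmatically; the sketch below mirrors the calls made inside `generate_images()` above. The checkpoint path is hypothetical, and a CUDA device plus the repository root on `PYTHONPATH` are assumed.

```python
# Programmatic sketch mirroring generate_images() above (checkpoint path is hypothetical).
import numpy as np
import torch

import dnnlib
import legacy
from renderer import Renderer
from training.networks import Generator
from torch_utils import misc

device = torch.device('cuda')
network_pkl = 'pretrained/ffhq_512.pkl'            # hypothetical local checkpoint

with dnnlib.util.open_url(network_pkl) as f:
    network = legacy.load_network_pkl(f)
G, D = network['G_ema'].to(device), network['D'].to(device)

# Re-instantiate the generator to avoid persistent classes, as in the script above.
with torch.no_grad():
    G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
    misc.copy_params_and_buffers(G, G2, require_all=False)
G2 = Renderer(G2, D, program=None)

z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)
c = torch.zeros([1, G.c_dim], device=device)
img = G2(z=z, c=c, truncation_psi=0.7, noise_mode='const')  # tensor or list of frames, handled as above
```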
gui_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
gui_utils/gl_utils.py ADDED
@@ -0,0 +1,374 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import functools
11
+ import contextlib
12
+ import numpy as np
13
+ import OpenGL.GL as gl
14
+ import OpenGL.GL.ARB.texture_float
15
+ import dnnlib
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def init_egl():
20
+ assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
21
+ import OpenGL.EGL as egl
22
+ import ctypes
23
+
24
+ # Initialize EGL.
25
+ display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
26
+ assert display != egl.EGL_NO_DISPLAY
27
+ major = ctypes.c_int32()
28
+ minor = ctypes.c_int32()
29
+ ok = egl.eglInitialize(display, major, minor)
30
+ assert ok
31
+ assert major.value * 10 + minor.value >= 14
32
+
33
+ # Choose config.
34
+ config_attribs = [
35
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
36
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
37
+ egl.EGL_NONE
38
+ ]
39
+ configs = (ctypes.c_int32 * 1)()
40
+ num_configs = ctypes.c_int32()
41
+ ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
42
+ assert ok
43
+ assert num_configs.value == 1
44
+ config = configs[0]
45
+
46
+ # Create dummy pbuffer surface.
47
+ surface_attribs = [
48
+ egl.EGL_WIDTH, 1,
49
+ egl.EGL_HEIGHT, 1,
50
+ egl.EGL_NONE
51
+ ]
52
+ surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
53
+ assert surface != egl.EGL_NO_SURFACE
54
+
55
+ # Setup GL context.
56
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
57
+ assert ok
58
+ context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
59
+ assert context != egl.EGL_NO_CONTEXT
60
+ ok = egl.eglMakeCurrent(display, surface, surface, context)
61
+ assert ok
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ _texture_formats = {
66
+ ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
67
+ ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
68
+ ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
69
+ ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
70
+ ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
71
+ ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
72
+ ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
73
+ ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
74
+ }
75
+
76
+ def get_texture_format(dtype, channels):
77
+ return _texture_formats[(np.dtype(dtype).name, int(channels))]
78
+
79
+ #----------------------------------------------------------------------------
80
+
81
+ def prepare_texture_data(image):
82
+ image = np.asarray(image)
83
+ if image.ndim == 2:
84
+ image = image[:, :, np.newaxis]
85
+ if image.dtype.name == 'float64':
86
+ image = image.astype('float32')
87
+ return image
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
+ def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
92
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
93
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
94
+ align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
95
+ image = prepare_texture_data(image)
96
+ height, width, channels = image.shape
97
+ size = zoom * [width, height]
98
+ pos = pos - size * align
99
+ if rint:
100
+ pos = np.rint(pos)
101
+ fmt = get_texture_format(image.dtype, channels)
102
+
103
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
104
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
105
+ gl.glRasterPos2f(pos[0], pos[1])
106
+ gl.glPixelZoom(zoom[0], -zoom[1])
107
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
108
+ gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
109
+ gl.glPopClientAttrib()
110
+ gl.glPopAttrib()
111
+
112
+ #----------------------------------------------------------------------------
113
+
114
+ def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
115
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
116
+ dtype = np.dtype(dtype)
117
+ fmt = get_texture_format(dtype, channels)
118
+ image = np.empty([height, width, channels], dtype=dtype)
119
+
120
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
121
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
122
+ gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
123
+ gl.glPopClientAttrib()
124
+ return np.flipud(image)
125
+
126
+ #----------------------------------------------------------------------------
127
+
128
+ class Texture:
129
+ def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
130
+ self.gl_id = None
131
+ self.bilinear = bilinear
132
+ self.mipmap = mipmap
133
+
134
+ # Determine size and dtype.
135
+ if image is not None:
136
+ image = prepare_texture_data(image)
137
+ self.height, self.width, self.channels = image.shape
138
+ self.dtype = image.dtype
139
+ else:
140
+ assert width is not None and height is not None
141
+ self.width = width
142
+ self.height = height
143
+ self.channels = channels if channels is not None else 3
144
+ self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
145
+
146
+ # Validate size and dtype.
147
+ assert isinstance(self.width, int) and self.width >= 0
148
+ assert isinstance(self.height, int) and self.height >= 0
149
+ assert isinstance(self.channels, int) and self.channels >= 1
150
+ assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
151
+
152
+ # Create texture object.
153
+ self.gl_id = gl.glGenTextures(1)
154
+ with self.bind():
155
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
156
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
157
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
158
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
159
+ self.update(image)
160
+
161
+ def delete(self):
162
+ if self.gl_id is not None:
163
+ gl.glDeleteTextures([self.gl_id])
164
+ self.gl_id = None
165
+
166
+ def __del__(self):
167
+ try:
168
+ self.delete()
169
+ except:
170
+ pass
171
+
172
+ @contextlib.contextmanager
173
+ def bind(self):
174
+ prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
175
+ gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
176
+ yield
177
+ gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
178
+
179
+ def update(self, image):
180
+ if image is not None:
181
+ image = prepare_texture_data(image)
182
+ assert self.is_compatible(image=image)
183
+ with self.bind():
184
+ fmt = get_texture_format(self.dtype, self.channels)
185
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
186
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
187
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
188
+ if self.mipmap:
189
+ gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
190
+ gl.glPopClientAttrib()
191
+
192
+ def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
193
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
194
+ size = zoom * [self.width, self.height]
195
+ with self.bind():
196
+ gl.glPushAttrib(gl.GL_ENABLE_BIT)
197
+ gl.glEnable(gl.GL_TEXTURE_2D)
198
+ draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
199
+ gl.glPopAttrib()
200
+
201
+ def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
202
+ if image is not None:
203
+ if image.ndim != 3:
204
+ return False
205
+ ih, iw, ic = image.shape
206
+ if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
207
+ return False
208
+ if width is not None and self.width != width:
209
+ return False
210
+ if height is not None and self.height != height:
211
+ return False
212
+ if channels is not None and self.channels != channels:
213
+ return False
214
+ if dtype is not None and self.dtype != dtype:
215
+ return False
216
+ return True
217
+
218
+ #----------------------------------------------------------------------------
219
+
220
+ class Framebuffer:
221
+ def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
222
+ self.texture = texture
223
+ self.gl_id = None
224
+ self.gl_color = None
225
+ self.gl_depth_stencil = None
226
+ self.msaa = msaa
227
+
228
+ # Determine size and dtype.
229
+ if texture is not None:
230
+ assert isinstance(self.texture, Texture)
231
+ self.width = texture.width
232
+ self.height = texture.height
233
+ self.channels = texture.channels
234
+ self.dtype = texture.dtype
235
+ else:
236
+ assert width is not None and height is not None
237
+ self.width = width
238
+ self.height = height
239
+ self.channels = channels if channels is not None else 4
240
+ self.dtype = np.dtype(dtype) if dtype is not None else np.float32
241
+
242
+ # Validate size and dtype.
243
+ assert isinstance(self.width, int) and self.width >= 0
244
+ assert isinstance(self.height, int) and self.height >= 0
245
+ assert isinstance(self.channels, int) and self.channels >= 1
246
+ assert width is None or width == self.width
247
+ assert height is None or height == self.height
248
+ assert channels is None or channels == self.channels
249
+ assert dtype is None or dtype == self.dtype
250
+
251
+ # Create framebuffer object.
252
+ self.gl_id = gl.glGenFramebuffers(1)
253
+ with self.bind():
254
+
255
+ # Setup color buffer.
256
+ if self.texture is not None:
257
+ assert self.msaa == 0
258
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
259
+ else:
260
+ fmt = get_texture_format(self.dtype, self.channels)
261
+ self.gl_color = gl.glGenRenderbuffers(1)
262
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
263
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
264
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
265
+
266
+ # Setup depth/stencil buffer.
267
+ self.gl_depth_stencil = gl.glGenRenderbuffers(1)
268
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
269
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
270
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
271
+
272
+ def delete(self):
273
+ if self.gl_id is not None:
274
+ gl.glDeleteFramebuffers([self.gl_id])
275
+ self.gl_id = None
276
+ if self.gl_color is not None:
277
+ gl.glDeleteRenderbuffers(1, [self.gl_color])
278
+ self.gl_color = None
279
+ if self.gl_depth_stencil is not None:
280
+ gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
281
+ self.gl_depth_stencil = None
282
+
283
+ def __del__(self):
284
+ try:
285
+ self.delete()
286
+ except:
287
+ pass
288
+
289
+ @contextlib.contextmanager
290
+ def bind(self):
291
+ prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
292
+ prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
293
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
294
+ if self.width is not None and self.height is not None:
295
+ gl.glViewport(0, 0, self.width, self.height)
296
+ yield
297
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
298
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
299
+
300
+ def blit(self, dst=None):
301
+ assert dst is None or isinstance(dst, Framebuffer)
302
+ with self.bind():
303
+ gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id)  # Framebuffer stores its GL handle in gl_id
304
+ gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
305
+
306
+ #----------------------------------------------------------------------------
307
+
308
+ def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
309
+ assert vertices.ndim == 2 and vertices.shape[1] == 2
310
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
311
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
312
+ color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
313
+ alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
314
+
315
+ gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
316
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
317
+ gl.glMatrixMode(gl.GL_MODELVIEW)
318
+ gl.glPushMatrix()
319
+
320
+ gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
321
+ gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
322
+ gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
323
+ gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
324
+ gl.glTranslate(pos[0], pos[1], 0)
325
+ gl.glScale(size[0], size[1], 1)
326
+ gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
327
+ gl.glDrawArrays(mode, 0, vertices.shape[0])
328
+
329
+ gl.glPopMatrix()
330
+ gl.glPopAttrib()
331
+ gl.glPopClientAttrib()
332
+
333
+ #----------------------------------------------------------------------------
334
+
335
+ def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
336
+ assert pos2 is None or size is None
337
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
338
+ pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
339
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
340
+ size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
341
+ pos = pos - size * align
342
+ if rint:
343
+ pos = np.rint(pos)
344
+ rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
345
+ rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
346
+ if np.min(rounding) == 0:
347
+ rounding *= 0
348
+ vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
349
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
350
+
351
+ @functools.lru_cache(maxsize=10000)
352
+ def _setup_rect(rx, ry):
353
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
354
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
355
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
356
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
357
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
358
+ return v.astype('float32')
359
+
360
+ #----------------------------------------------------------------------------
361
+
362
+ def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
363
+ hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
364
+ vertices = _setup_circle(float(hole))
365
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
366
+
367
+ @functools.lru_cache(maxsize=10000)
368
+ def _setup_circle(hole):
369
+ t = np.linspace(0, np.pi * 2, 128)
370
+ s = np.sin(t); c = np.cos(t)
371
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
372
+ return v.astype('float32')
373
+
374
+ #----------------------------------------------------------------------------
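A minimal offscreen sketch of how these helpers compose: create a headless EGL context, upload an image as a `Texture`, draw it into a `Framebuffer`, and read the pixels back. It assumes an EGL-capable GPU and sets `PYOPENGL_PLATFORM=egl` before OpenGL is imported, as `generate.py` does; the projection setup copies the convention used in `GlfwWindow.begin_frame` below.

```python
# Offscreen gl_utils sketch (assumes an EGL-capable GPU).
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'   # must be set before OpenGL is imported

import numpy as np
import OpenGL.GL as gl
from gui_utils import gl_utils

gl_utils.init_egl()                                        # headless GL context

image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
tex = gl_utils.Texture(image=image)                        # upload to the GPU
fbo = gl_utils.Framebuffer(width=64, height=64, channels=3, dtype='uint8')

with fbo.bind():                                           # render offscreen
    # Pixel-space projection, same convention as GlfwWindow.begin_frame().
    gl.glMatrixMode(gl.GL_PROJECTION); gl.glLoadIdentity()
    gl.glTranslate(-1, 1, 0); gl.glScale(2 / 64, -2 / 64, 1)
    gl.glMatrixMode(gl.GL_MODELVIEW); gl.glLoadIdentity()
    tex.draw(pos=0, zoom=1)
    out = gl_utils.read_pixels(64, 64)                     # back to numpy

print(out.shape, out.dtype)                                # (64, 64, 3) uint8
```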
gui_utils/glfw_window.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import time
10
+ import glfw
11
+ import OpenGL.GL as gl
12
+ from . import gl_utils
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ class GlfwWindow: # pylint: disable=too-many-public-methods
17
+ def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
18
+ self._glfw_window = None
19
+ self._drawing_frame = False
20
+ self._frame_start_time = None
21
+ self._frame_delta = 0
22
+ self._fps_limit = None
23
+ self._vsync = None
24
+ self._skip_frames = 0
25
+ self._deferred_show = deferred_show
26
+ self._close_on_esc = close_on_esc
27
+ self._esc_pressed = False
28
+ self._drag_and_drop_paths = None
29
+ self._capture_next_frame = False
30
+ self._captured_frame = None
31
+
32
+ # Create window.
33
+ glfw.init()
34
+ glfw.window_hint(glfw.VISIBLE, False)
35
+ self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
36
+ self._attach_glfw_callbacks()
37
+ self.make_context_current()
38
+
39
+ # Adjust window.
40
+ self.set_vsync(False)
41
+ self.set_window_size(window_width, window_height)
42
+ if not self._deferred_show:
43
+ glfw.show_window(self._glfw_window)
44
+
45
+ def close(self):
46
+ if self._drawing_frame:
47
+ self.end_frame()
48
+ if self._glfw_window is not None:
49
+ glfw.destroy_window(self._glfw_window)
50
+ self._glfw_window = None
51
+ #glfw.terminate() # Commented out to play it nice with other glfw clients.
52
+
53
+ def __del__(self):
54
+ try:
55
+ self.close()
56
+ except:
57
+ pass
58
+
59
+ @property
60
+ def window_width(self):
61
+ return self.content_width
62
+
63
+ @property
64
+ def window_height(self):
65
+ return self.content_height + self.title_bar_height
66
+
67
+ @property
68
+ def content_width(self):
69
+ width, _height = glfw.get_window_size(self._glfw_window)
70
+ return width
71
+
72
+ @property
73
+ def content_height(self):
74
+ _width, height = glfw.get_window_size(self._glfw_window)
75
+ return height
76
+
77
+ @property
78
+ def title_bar_height(self):
79
+ _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
80
+ return top
81
+
82
+ @property
83
+ def monitor_width(self):
84
+ _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
85
+ return width
86
+
87
+ @property
88
+ def monitor_height(self):
89
+ _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
90
+ return height
91
+
92
+ @property
93
+ def frame_delta(self):
94
+ return self._frame_delta
95
+
96
+ def set_title(self, title):
97
+ glfw.set_window_title(self._glfw_window, title)
98
+
99
+ def set_window_size(self, width, height):
100
+ width = min(width, self.monitor_width)
101
+ height = min(height, self.monitor_height)
102
+ glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
103
+ if width == self.monitor_width and height == self.monitor_height:
104
+ self.maximize()
105
+
106
+ def set_content_size(self, width, height):
107
+ self.set_window_size(width, height + self.title_bar_height)
108
+
109
+ def maximize(self):
110
+ glfw.maximize_window(self._glfw_window)
111
+
112
+ def set_position(self, x, y):
113
+ glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
114
+
115
+ def center(self):
116
+ self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
117
+
118
+ def set_vsync(self, vsync):
119
+ vsync = bool(vsync)
120
+ if vsync != self._vsync:
121
+ glfw.swap_interval(1 if vsync else 0)
122
+ self._vsync = vsync
123
+
124
+ def set_fps_limit(self, fps_limit):
125
+ self._fps_limit = int(fps_limit)
126
+
127
+ def should_close(self):
128
+ return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
129
+
130
+ def skip_frame(self):
131
+ self.skip_frames(1)
132
+
133
+ def skip_frames(self, num): # Do not update window for the next N frames.
134
+ self._skip_frames = max(self._skip_frames, int(num))
135
+
136
+ def is_skipping_frames(self):
137
+ return self._skip_frames > 0
138
+
139
+ def capture_next_frame(self):
140
+ self._capture_next_frame = True
141
+
142
+ def pop_captured_frame(self):
143
+ frame = self._captured_frame
144
+ self._captured_frame = None
145
+ return frame
146
+
147
+ def pop_drag_and_drop_paths(self):
148
+ paths = self._drag_and_drop_paths
149
+ self._drag_and_drop_paths = None
150
+ return paths
151
+
152
+ def draw_frame(self): # To be overridden by subclass.
153
+ self.begin_frame()
154
+ # Rendering code goes here.
155
+ self.end_frame()
156
+
157
+ def make_context_current(self):
158
+ if self._glfw_window is not None:
159
+ glfw.make_context_current(self._glfw_window)
160
+
161
+ def begin_frame(self):
162
+ # End previous frame.
163
+ if self._drawing_frame:
164
+ self.end_frame()
165
+
166
+ # Apply FPS limit.
167
+ if self._frame_start_time is not None and self._fps_limit is not None:
168
+ delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
169
+ if delay > 0:
170
+ time.sleep(delay)
171
+ cur_time = time.perf_counter()
172
+ if self._frame_start_time is not None:
173
+ self._frame_delta = cur_time - self._frame_start_time
174
+ self._frame_start_time = cur_time
175
+
176
+ # Process events.
177
+ glfw.poll_events()
178
+
179
+ # Begin frame.
180
+ self._drawing_frame = True
181
+ self.make_context_current()
182
+
183
+ # Initialize GL state.
184
+ gl.glViewport(0, 0, self.content_width, self.content_height)
185
+ gl.glMatrixMode(gl.GL_PROJECTION)
186
+ gl.glLoadIdentity()
187
+ gl.glTranslate(-1, 1, 0)
188
+ gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
189
+ gl.glMatrixMode(gl.GL_MODELVIEW)
190
+ gl.glLoadIdentity()
191
+ gl.glEnable(gl.GL_BLEND)
192
+ gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
193
+
194
+ # Clear.
195
+ gl.glClearColor(0, 0, 0, 1)
196
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
197
+
198
+ def end_frame(self):
199
+ assert self._drawing_frame
200
+ self._drawing_frame = False
201
+
202
+ # Skip frames if requested.
203
+ if self._skip_frames > 0:
204
+ self._skip_frames -= 1
205
+ return
206
+
207
+ # Capture frame if requested.
208
+ if self._capture_next_frame:
209
+ self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
210
+ self._capture_next_frame = False
211
+
212
+ # Update window.
213
+ if self._deferred_show:
214
+ glfw.show_window(self._glfw_window)
215
+ self._deferred_show = False
216
+ glfw.swap_buffers(self._glfw_window)
217
+
218
+ def _attach_glfw_callbacks(self):
219
+ glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
220
+ glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
221
+
222
+ def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
223
+ if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
224
+ self._esc_pressed = True
225
+
226
+ def _glfw_drop_callback(self, _window, paths):
227
+ self._drag_and_drop_paths = paths
228
+
229
+ #----------------------------------------------------------------------------
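A minimal render loop on top of `GlfwWindow`, assuming a desktop session where GLFW can create a window (a sketch, not part of the repository):

```python
# Render-loop sketch: subclass GlfwWindow and draw a noise image each frame.
import numpy as np
from gui_utils import gl_utils
from gui_utils.glfw_window import GlfwWindow

class DemoWindow(GlfwWindow):
    def __init__(self):
        super().__init__(title='Demo', window_width=512, window_height=512)
        self.noise = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)

    def draw_frame(self):
        self.begin_frame()                        # poll events, set up GL state, clear
        gl_utils.draw_pixels(self.noise, pos=[128, 128])
        self.end_frame()                          # show the window (first frame) and swap buffers

window = DemoWindow()
window.set_fps_limit(60)
while not window.should_close():                  # ESC or window close ends the loop
    window.draw_frame()
window.close()
```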
gui_utils/imgui_utils.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import contextlib
10
+ import imgui
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
15
+ s = imgui.get_style()
16
+ s.window_padding = [spacing, spacing]
17
+ s.item_spacing = [spacing, spacing]
18
+ s.item_inner_spacing = [spacing, spacing]
19
+ s.columns_min_spacing = spacing
20
+ s.indent_spacing = indent
21
+ s.scrollbar_size = scrollbar
22
+ s.frame_padding = [4, 3]
23
+ s.window_border_size = 1
24
+ s.child_border_size = 1
25
+ s.popup_border_size = 1
26
+ s.frame_border_size = 1
27
+ s.window_rounding = 0
28
+ s.child_rounding = 0
29
+ s.popup_rounding = 3
30
+ s.frame_rounding = 3
31
+ s.scrollbar_rounding = 3
32
+ s.grab_rounding = 3
33
+
34
+ getattr(imgui, f'style_colors_{color_scheme}')(s)
35
+ c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
36
+ c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
37
+ s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ @contextlib.contextmanager
42
+ def grayed_out(cond=True):
43
+ if cond:
44
+ s = imgui.get_style()
45
+ text = s.colors[imgui.COLOR_TEXT_DISABLED]
46
+ grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
47
+ back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
48
+ imgui.push_style_color(imgui.COLOR_TEXT, *text)
49
+ imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
50
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
51
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
52
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
53
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
54
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
55
+ imgui.push_style_color(imgui.COLOR_BUTTON, *back)
56
+ imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
57
+ imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
58
+ imgui.push_style_color(imgui.COLOR_HEADER, *back)
59
+ imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
60
+ imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
61
+ imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
62
+ yield
63
+ imgui.pop_style_color(14)
64
+ else:
65
+ yield
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
+ @contextlib.contextmanager
70
+ def item_width(width=None):
71
+ if width is not None:
72
+ imgui.push_item_width(width)
73
+ yield
74
+ imgui.pop_item_width()
75
+ else:
76
+ yield
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ def scoped_by_object_id(method):
81
+ def decorator(self, *args, **kwargs):
82
+ imgui.push_id(str(id(self)))
83
+ res = method(self, *args, **kwargs)
84
+ imgui.pop_id()
85
+ return res
86
+ return decorator
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ def button(label, width=0, enabled=True):
91
+ with grayed_out(not enabled):
92
+ clicked = imgui.button(label, width=width)
93
+ clicked = clicked and enabled
94
+ return clicked
95
+
96
+ #----------------------------------------------------------------------------
97
+
98
+ def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
99
+ expanded = False
100
+ if show:
101
+ if default:
102
+ flags |= imgui.TREE_NODE_DEFAULT_OPEN
103
+ if not enabled:
104
+ flags |= imgui.TREE_NODE_LEAF
105
+ with grayed_out(not enabled):
106
+ expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
107
+ expanded = expanded and enabled
108
+ return expanded, visible
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
+ def popup_button(label, width=0, enabled=True):
113
+ if button(label, width, enabled):
114
+ imgui.open_popup(label)
115
+ opened = imgui.begin_popup(label)
116
+ return opened
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ def input_text(label, value, buffer_length, flags, width=None, help_text=''):
121
+ old_value = value
122
+ color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
123
+ if value == '':
124
+ color[-1] *= 0.5
125
+ with item_width(width):
126
+ imgui.push_style_color(imgui.COLOR_TEXT, *color)
127
+ value = value if value != '' else help_text
128
+ changed, value = imgui.input_text(label, value, buffer_length, flags)
129
+ value = value if value != help_text else ''
130
+ imgui.pop_style_color(1)
131
+ if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
132
+ changed = (value != old_value)
133
+ return changed, value
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ def drag_previous_control(enabled=True):
138
+ dragging = False
139
+ dx = 0
140
+ dy = 0
141
+ if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
142
+ if enabled:
143
+ dragging = True
144
+ dx, dy = imgui.get_mouse_drag_delta()
145
+ imgui.reset_mouse_drag_delta()
146
+ imgui.end_drag_drop_source()
147
+ return dragging, dx, dy
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
+ def drag_button(label, width=0, enabled=True):
152
+ clicked = button(label, width=width, enabled=enabled)
153
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
154
+ return clicked, dragging, dx, dy
155
+
156
+ #----------------------------------------------------------------------------
157
+
158
+ def drag_hidden_window(label, x, y, width, height, enabled=True):
159
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
160
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
161
+ imgui.set_next_window_position(x, y)
162
+ imgui.set_next_window_size(width, height)
163
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
164
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
165
+ imgui.end()
166
+ imgui.pop_style_color(2)
167
+ return dragging, dx, dy
168
+
169
+ #----------------------------------------------------------------------------
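These helpers all assume an active ImGui frame (for example between `ImguiWindow.begin_frame()` and `end_frame()` from the next file); a small sketch of typical use, with the widget labels and callback name being illustrative only:

```python
# imgui_utils sketch: a disabled-aware text box and button inside an active ImGui frame.
from gui_utils import imgui_utils

def draw_controls(enabled: bool, seed_text: str) -> str:
    # input_text() shows help_text in a faded colour while the value is empty.
    _changed, seed_text = imgui_utils.input_text('Seed', seed_text, 64, 0,
                                                 width=200, help_text='random seed')
    # button() greys itself out and reports no click while enabled=False.
    if imgui_utils.button('Generate', width=120, enabled=enabled):
        print('generate requested for seed', seed_text)
    return seed_text
```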
gui_utils/imgui_window.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import imgui
11
+ import imgui.integrations.glfw
12
+
13
+ from . import glfw_window
14
+ from . import imgui_utils
15
+ from . import text_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ class ImguiWindow(glfw_window.GlfwWindow):
20
+ def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
21
+ if font is None:
22
+ font = text_utils.get_default_font()
23
+ font_sizes = {int(size) for size in font_sizes}
24
+ super().__init__(title=title, **glfw_kwargs)
25
+
26
+ # Init fields.
27
+ self._imgui_context = None
28
+ self._imgui_renderer = None
29
+ self._imgui_fonts = None
30
+ self._cur_font_size = max(font_sizes)
31
+
32
+ # Delete leftover imgui.ini to avoid unexpected behavior.
33
+ if os.path.isfile('imgui.ini'):
34
+ os.remove('imgui.ini')
35
+
36
+ # Init ImGui.
37
+ self._imgui_context = imgui.create_context()
38
+ self._imgui_renderer = _GlfwRenderer(self._glfw_window)
39
+ self._attach_glfw_callbacks()
40
+ imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
41
+ imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
42
+ self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
43
+ self._imgui_renderer.refresh_font_texture()
44
+
45
+ def close(self):
46
+ self.make_context_current()
47
+ self._imgui_fonts = None
48
+ if self._imgui_renderer is not None:
49
+ self._imgui_renderer.shutdown()
50
+ self._imgui_renderer = None
51
+ if self._imgui_context is not None:
52
+ #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
53
+ self._imgui_context = None
54
+ super().close()
55
+
56
+ def _glfw_key_callback(self, *args):
57
+ super()._glfw_key_callback(*args)
58
+ self._imgui_renderer.keyboard_callback(*args)
59
+
60
+ @property
61
+ def font_size(self):
62
+ return self._cur_font_size
63
+
64
+ @property
65
+ def spacing(self):
66
+ return round(self._cur_font_size * 0.4)
67
+
68
+ def set_font_size(self, target): # Applied on next frame.
69
+ self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
70
+
71
+ def begin_frame(self):
72
+ # Begin glfw frame.
73
+ super().begin_frame()
74
+
75
+ # Process imgui events.
76
+ self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
77
+ if self.content_width > 0 and self.content_height > 0:
78
+ self._imgui_renderer.process_inputs()
79
+
80
+ # Begin imgui frame.
81
+ imgui.new_frame()
82
+ imgui.push_font(self._imgui_fonts[self._cur_font_size])
83
+ imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
84
+
85
+ def end_frame(self):
86
+ imgui.pop_font()
87
+ imgui.render()
88
+ imgui.end_frame()
89
+ self._imgui_renderer.render(imgui.get_draw_data())
90
+ super().end_frame()
91
+
92
+ #----------------------------------------------------------------------------
93
+ # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
94
+
95
+ class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
96
+ def __init__(self, *args, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self.mouse_wheel_multiplier = 1
99
+
100
+ def scroll_callback(self, window, x_offset, y_offset):
101
+ self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
102
+
103
+ #----------------------------------------------------------------------------
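Putting the two window classes together, a minimal ImGui application loop (assumes a desktop session; the default font is downloaded on first use via `dnnlib.util.open_url`):

```python
# Minimal ImGui app on top of ImguiWindow.
import imgui
from gui_utils.imgui_window import ImguiWindow

window = ImguiWindow(title='StyleNeRF demo', window_width=800, window_height=600)
clicks = 0
while not window.should_close():
    window.begin_frame()                  # glfw frame + imgui.new_frame() + default style
    imgui.begin('Controls')
    if imgui.button('Click me'):
        clicks += 1
    imgui.text(f'Clicked {clicks} times')
    imgui.end()
    window.end_frame()                    # imgui.render() + buffer swap
window.close()
```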
gui_utils/text_utils.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ from typing import Optional
11
+
12
+ import dnnlib
13
+ import numpy as np
14
+ import PIL.Image
15
+ import PIL.ImageFont
16
+ import scipy.ndimage
17
+
18
+ from . import gl_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def get_default_font():
23
+ url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
24
+ return dnnlib.util.open_url(url, return_filename=True)
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ @functools.lru_cache(maxsize=None)
29
+ def get_pil_font(font=None, size=32):
30
+ if font is None:
31
+ font = get_default_font()
32
+ return PIL.ImageFont.truetype(font=font, size=size)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def get_array(string, *, dropshadow_radius: int=None, **kwargs):
37
+ if dropshadow_radius is not None:
38
+ offset_x = int(np.ceil(dropshadow_radius*2/3))
39
+ offset_y = int(np.ceil(dropshadow_radius*2/3))
40
+ return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
41
+ else:
42
+ return _get_array_priv(string, **kwargs)
43
+
44
+ @functools.lru_cache(maxsize=10000)
45
+ def _get_array_priv(
46
+ string: str, *,
47
+ size: int = 32,
48
+ max_width: Optional[int]=None,
49
+ max_height: Optional[int]=None,
50
+ min_size=10,
51
+ shrink_coef=0.8,
52
+ dropshadow_radius: int=None,
53
+ offset_x: int=None,
54
+ offset_y: int=None,
55
+ **kwargs
56
+ ):
57
+ cur_size = size
58
+ array = None
59
+ while True:
60
+ if dropshadow_radius is not None:
61
+ # separate implementation for dropshadow text rendering
62
+ array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
63
+ else:
64
+ array = _get_array_impl(string, size=cur_size, **kwargs)
65
+ height, width, _ = array.shape
66
+ if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
67
+ break
68
+ cur_size = max(int(cur_size * shrink_coef), min_size)
69
+ return array
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ @functools.lru_cache(maxsize=10000)
74
+ def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
75
+ pil_font = get_pil_font(font=font, size=size)
76
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
77
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
78
+ width = max(line.shape[1] for line in lines)
79
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
80
+ line_spacing = line_pad if line_pad is not None else size // 2
81
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
82
+ mask = np.concatenate(lines, axis=0)
83
+ alpha = mask
84
+ if outline > 0:
85
+ mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
86
+ alpha = mask.astype(np.float32) / 255
87
+ alpha = scipy.ndimage.gaussian_filter(alpha, outline)
88
+ alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
89
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
90
+ alpha = np.maximum(alpha, mask)
91
+ return np.stack([mask, alpha], axis=-1)
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ @functools.lru_cache(maxsize=10000)
96
+ def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
97
+ assert (offset_x > 0) and (offset_y > 0)
98
+ pil_font = get_pil_font(font=font, size=size)
99
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
100
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
101
+ width = max(line.shape[1] for line in lines)
102
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
103
+ line_spacing = line_pad if line_pad is not None else size // 2
104
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
105
+ mask = np.concatenate(lines, axis=0)
106
+ alpha = mask
107
+
108
+ mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
109
+ alpha = mask.astype(np.float32) / 255
110
+ alpha = scipy.ndimage.gaussian_filter(alpha, radius)
111
+ alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
112
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
113
+ alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
114
+ alpha = np.maximum(alpha, mask)
115
+ return np.stack([mask, alpha], axis=-1)
116
+
117
+ #----------------------------------------------------------------------------
118
+
119
+ @functools.lru_cache(maxsize=10000)
120
+ def get_texture(string, bilinear=True, mipmap=True, **kwargs):
121
+ return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
122
+
123
+ #----------------------------------------------------------------------------
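`get_array` is the piece that can be exercised without a GL context; it rasterizes a (possibly multi-line) string into a mask/alpha bitmap. The first call downloads the default Open Sans font via `dnnlib.util.open_url`:

```python
# Rasterize text into an (H, W, 2) uint8 array: channel 0 = glyph mask, channel 1 = alpha/outline.
from gui_utils import text_utils

arr = text_utils.get_array('StyleNeRF\ndemo', size=32, outline=1)
print(arr.shape, arr.dtype)   # (H, W, 2) uint8; exact H and W depend on the font metrics
```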
launcher.py ADDED
@@ -0,0 +1,189 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+
4
+ import random, shlex, datetime
5
+ import os, sys, subprocess, shutil
6
+ from glob import iglob
7
+
8
+
9
+ def copy_all_python_files(
10
+ source, snapshot_main_dir, code_snapshot_hash, recurse_dirs="fairseq"
11
+ ):
12
+ """
13
+ Copies following files from source to destination:
14
+ a) all *.py files at direct source location.
15
+ b) all fairseq/*.py recursively (default); recurse through comma-separated recurse_dirs
16
+ """
17
+ os.makedirs(snapshot_main_dir, exist_ok=True)
18
+ destination = os.path.join(snapshot_main_dir, code_snapshot_hash)
19
+ assert not os.path.exists(destination), "Code snapshot: {0} alredy exists".format(
20
+ code_snapshot_hash
21
+ )
22
+ os.makedirs(destination)
23
+
24
+ def all_pys(recurse_dirs):
25
+ yield from iglob(os.path.join(source, "*.py"))
26
+ for d in recurse_dirs.split(","):
27
+ yield from iglob(os.path.join(source, d, "**/*.py"), recursive=True)
28
+ yield from iglob(os.path.join(source, d, "**/*.so"), recursive=True)
29
+ yield from iglob(os.path.join(source, d, "**/*.yaml"), recursive=True)
30
+
31
+ for filepath in all_pys(recurse_dirs):
32
+ directory, filename = os.path.split(filepath)
33
+ if directory:
34
+ os.makedirs(os.path.join(destination, directory), exist_ok=True)
35
+ shutil.copy2(
36
+ os.path.join(source, filepath), os.path.join(destination, filepath)
37
+ )
38
+ return destination
39
+
40
+ def launch_cluster(slurm_args, model_args):
41
+ # prepare
42
+ jobname = slurm_args.get('job-name', 'test')
43
+ if slurm_args.get('workplace') is not None:
44
+ os.makedirs(slurm_args.get('workplace'), exist_ok=True)
45
+ if slurm_args.get('workplace') is not None:
46
+ train_log = os.path.join(slurm_args['workplace'], 'train.%A.out')
47
+ train_stderr = os.path.join(slurm_args['workplace'], 'train.%A.stderr.%j')
48
+ else:
49
+ train_log = train_stderr = None
50
+ nodes, gpus = slurm_args.get('nodes', 1), slurm_args.get('gpus', 8)
51
+ if not slurm_args.get('local', False):
52
+ assert (train_log is not None) and (train_stderr is not None)
53
+ # parse slurm
54
+
55
+ destination = ""
56
+ # if slurm_args.get('workplace', None) is not None:
57
+ # # Currently hash is just the current time in ISO format.
58
+ # # Remove colons since they cannot be escaped in POSIX PATH env vars.
59
+ # code_snapshot_hash = datetime.datetime.now().isoformat().replace(":", "_")
60
+ # destination = copy_all_python_files(
61
+ # ".",
62
+ # os.path.join(slurm_args['workplace'], "slurm_snapshot_code"),
63
+ # code_snapshot_hash,
64
+ # 'fairseq',
65
+ # )
66
+ # os.environ["PYTHONPATH"] = destination + ":" + os.environ.get("PYTHONPATH", "")
67
+ # print('creat snapshot at {}'.format(destination))
68
+
69
+ train_cmd = ['python', os.path.join(destination, 'run_train.py'), ]
70
+ train_cmd.extend([f'gpus={nodes * gpus}'])
71
+ train_cmd.extend([f'port={get_random_port()}'])
72
+ train_cmd += model_args
73
+
74
+ base_srun_cmd = [
75
+ 'srun',
76
+ '--job-name', jobname,
77
+ '--output', train_log,
78
+ '--error', train_stderr,
79
+ '--open-mode', 'append',
80
+ '--unbuffered',
81
+ ]
82
+ srun_cmd = base_srun_cmd + train_cmd
83
+ srun_cmd_str = ' '.join(map(shlex.quote, srun_cmd))
84
+ srun_cmd_str = srun_cmd_str + ' &'
85
+
86
+ sbatch_cmd = [
87
+ 'sbatch',
88
+ '--job-name', jobname,
89
+ '--partition', slurm_args.get('partition', 'learnfair'),
90
+ '--gres', 'gpu:volta:{}'.format(gpus),
91
+ '--nodes', str(nodes),
92
+ '--ntasks-per-node', '1',
93
+ '--cpus-per-task', '20',
94
+ '--output', train_log,
95
+ '--error', train_stderr,
96
+ '--open-mode', 'append',
97
+ '--signal', 'B:USR1@180',
98
+ '--time', slurm_args.get('time', '4320'),
99
+ '--mem', slurm_args.get('mem', '500gb'),
100
+ '--exclusive',
101
+ '--exclude', 'learnfair5035,learnfair5289,learnfair5088,learnfair5028,learnfair5032,learnfair5033,learnfair5056,learnfair5098,learnfair5122,learnfair5124,learnfair5156,learnfair5036,learnfair5258,learnfair5205,learnfair5201,learnfair5240,learnfair5087,learnfair5119,learnfair5246,learnfair7474,learnfair7585,learnfair5150,learnfair5166,learnfair5215,learnfair5142,learnfair5070,learnfair5236,learnfair7523'
102
+ ]
103
+ if 'constraint' in slurm_args:
104
+ sbatch_cmd += ['-C', slurm_args.get('constraint')]
105
+ if 'comment' in slurm_args:
106
+ sbatch_cmd += ['--comment', slurm_args.get('comment')]
107
+
108
+ wrapped_cmd = requeue_support() + '\n' + srun_cmd_str + ' \n wait $! \n sleep 610 & \n wait $!'
109
+ sbatch_cmd += ['--wrap', wrapped_cmd]
110
+ sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd))
111
+
112
+ # start training
113
+ env = os.environ.copy()
114
+ env['OMP_NUM_THREADS'] = '2'
115
+ env['NCCL_SOCKET_IFNAME'] = ''
116
+
117
+ if env.get('SLURM_ARGS', None) is not None:
118
+ del env['SLURM_ARGS']
119
+
120
+ if nodes > 1:
121
+ env['NCCL_SOCKET_IFNAME'] = '^docker0,lo'
122
+ env['NCCL_DEBUG'] = 'INFO'
123
+
124
+ if slurm_args.get('dry-run', False):
125
+ print(sbatch_cmd_str)
126
+
127
+ elif slurm_args.get('local', False):
128
+ assert nodes == 1, 'distributed training cannot be combined with local'
129
+ if 'CUDA_VISIBLE_DEVICES' not in env:
130
+ env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(gpus)))
131
+ env['NCCL_DEBUG'] = 'INFO'
132
+
133
+ if train_log is not None:
134
+ train_proc = subprocess.Popen(train_cmd, env=env, stdout=subprocess.PIPE)
135
+ tee_proc = subprocess.Popen(['tee', '-a', train_log], stdin=train_proc.stdout)
136
+ train_proc.stdout.close()
137
+ train_proc.wait()
138
+ tee_proc.wait()
139
+ else:
140
+ train_proc = subprocess.Popen(train_cmd, env=env)
141
+ train_proc.wait()
142
+ else:
143
+ with open(train_log, 'a') as train_log_h:
144
+ print(f'running command: {sbatch_cmd_str}\n')
145
+ with subprocess.Popen(sbatch_cmd, stdout=subprocess.PIPE, env=env) as train_proc:
146
+ stdout = train_proc.stdout.read().decode('utf-8')
147
+ print(stdout, file=train_log_h)
148
+ try:
149
+ job_id = int(stdout.rstrip().split()[-1])
150
+ return job_id
151
+ except IndexError:
152
+ return None
153
+
154
+
155
+ def launch(slurm_args, model_args):
156
+ job_id = launch_cluster(slurm_args, model_args)
157
+ if job_id is not None:
158
+ print('Launched {}'.format(job_id))
159
+ else:
160
+ print('Failed.')
161
+
162
+
163
+ def requeue_support():
164
+ return """
165
+ trap_handler () {
166
+ echo "Caught signal: " $1
167
+ # SIGTERM must be bypassed
168
+ if [ "$1" = "TERM" ]; then
169
+ echo "bypass sigterm"
170
+ else
171
+ # Submit a new job to the queue
172
+ echo "Requeuing " $SLURM_JOB_ID
173
+ scontrol requeue $SLURM_JOB_ID
174
+ fi
175
+ }
176
+
177
+
178
+ # Install signal handler
179
+ trap 'trap_handler USR1' USR1
180
+ trap 'trap_handler TERM' TERM
181
+ """
182
+
183
+
184
+ def get_random_port():
185
+ old_state = random.getstate()
186
+ random.seed()
187
+ port = random.randint(10000, 20000)
188
+ random.setstate(old_state)
189
+ return port
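
Taken together, the routines above assemble one sbatch submission per run: get_random_port() draws a rendezvous port without disturbing any user-set RNG seed, the port and GPU count are passed to run_train.py as overrides, and requeue_support() wraps the job so a preempted run resubmits itself. A minimal sketch of the port helper (the override strings are hypothetical):

port = get_random_port()                      # reseeds from OS entropy, then restores the previous RNG state
assert 10000 <= port <= 20000                 # always falls in the range drawn above
train_cmd = ['python', 'run_train.py', 'gpus=8', f'port={port}']   # hypothetical single-node command
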
legacy.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # Add missing fields.
32
+ if 'training_set_kwargs' not in data:
33
+ data['training_set_kwargs'] = None
34
+ if 'augment_pipe' not in data:
35
+ data['augment_pipe'] = None
36
+
37
+ # Validate contents.
38
+ # assert isinstance(data['G'], torch.nn.Module)
39
+ # assert isinstance(data['D'], torch.nn.Module)
40
+ # assert isinstance(data['G_ema'], torch.nn.Module)
41
+ # assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
+ # assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
+
44
+ # Force FP16.
45
+ if force_fp16:
46
+ for key in ['G', 'D', 'G_ema']:
47
+ old = data[key]
48
+ kwargs = copy.deepcopy(old.init_kwargs)
49
+ if key.startswith('G'):
50
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
+ kwargs.synthesis_kwargs.num_fp16_res = 4
52
+ kwargs.synthesis_kwargs.conv_clamp = 256
53
+ if key.startswith('D'):
54
+ kwargs.num_fp16_res = 4
55
+ kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'dnnlib.tflib.network' and name == 'Network':
70
+ return _TFNetworkStub
71
+ return super().find_class(module, name)
72
+
73
+ #----------------------------------------------------------------------------
74
+
75
+ def _collect_tf_params(tf_net):
76
+ # pylint: disable=protected-access
77
+ tf_params = dict()
78
+ def recurse(prefix, tf_net):
79
+ for name, value in tf_net.variables:
80
+ tf_params[prefix + name] = value
81
+ for name, comp in tf_net.components.items():
82
+ recurse(prefix + name + '/', comp)
83
+ recurse('', tf_net)
84
+ return tf_params
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def _populate_module_params(module, *patterns):
89
+ for name, tensor in misc.named_params_and_buffers(module):
90
+ found = False
91
+ value = None
92
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
+ match = re.fullmatch(pattern, name)
94
+ if match:
95
+ found = True
96
+ if value_fn is not None:
97
+ value = value_fn(*match.groups())
98
+ break
99
+ try:
100
+ assert found
101
+ if value is not None:
102
+ tensor.copy_(torch.from_numpy(np.array(value)))
103
+ except:
104
+ print(name, list(tensor.shape))
105
+ raise
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def convert_tf_generator(tf_G):
110
+ if tf_G.version < 4:
111
+ raise ValueError('TensorFlow pickle version too low')
112
+
113
+ # Collect kwargs.
114
+ tf_kwargs = tf_G.static_kwargs
115
+ known_kwargs = set()
116
+ def kwarg(tf_name, default=None, none=None):
117
+ known_kwargs.add(tf_name)
118
+ val = tf_kwargs.get(tf_name, default)
119
+ return val if val is not None else none
120
+
121
+ # Convert kwargs.
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ mapping_kwargs = dnnlib.EasyDict(
129
+ num_layers = kwarg('mapping_layers', 8),
130
+ embed_features = kwarg('label_fmaps', None),
131
+ layer_features = kwarg('mapping_fmaps', None),
132
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
133
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
134
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135
+ ),
136
+ synthesis_kwargs = dnnlib.EasyDict(
137
+ channel_base = kwarg('fmap_base', 16384) * 2,
138
+ channel_max = kwarg('fmap_max', 512),
139
+ num_fp16_res = kwarg('num_fp16_res', 0),
140
+ conv_clamp = kwarg('conv_clamp', None),
141
+ architecture = kwarg('architecture', 'skip'),
142
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143
+ use_noise = kwarg('use_noise', True),
144
+ activation = kwarg('nonlinearity', 'lrelu'),
145
+ ),
146
+ )
147
+
148
+ # Check for unknown kwargs.
149
+ kwarg('truncation_psi')
150
+ kwarg('truncation_cutoff')
151
+ kwarg('style_mixing_prob')
152
+ kwarg('structure')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis_kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ from training import networks
169
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
170
+ # pylint: disable=unnecessary-lambda
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ )
203
+ return G
204
+
205
+ #----------------------------------------------------------------------------
206
+
207
+ def convert_tf_discriminator(tf_D):
208
+ if tf_D.version < 4:
209
+ raise ValueError('TensorFlow pickle version too low')
210
+
211
+ # Collect kwargs.
212
+ tf_kwargs = tf_D.static_kwargs
213
+ known_kwargs = set()
214
+ def kwarg(tf_name, default=None):
215
+ known_kwargs.add(tf_name)
216
+ return tf_kwargs.get(tf_name, default)
217
+
218
+ # Convert kwargs.
219
+ kwargs = dnnlib.EasyDict(
220
+ c_dim = kwarg('label_size', 0),
221
+ img_resolution = kwarg('resolution', 1024),
222
+ img_channels = kwarg('num_channels', 3),
223
+ architecture = kwarg('architecture', 'resnet'),
224
+ channel_base = kwarg('fmap_base', 16384) * 2,
225
+ channel_max = kwarg('fmap_max', 512),
226
+ num_fp16_res = kwarg('num_fp16_res', 0),
227
+ conv_clamp = kwarg('conv_clamp', None),
228
+ cmap_dim = kwarg('mapping_fmaps', None),
229
+ block_kwargs = dnnlib.EasyDict(
230
+ activation = kwarg('nonlinearity', 'lrelu'),
231
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232
+ freeze_layers = kwarg('freeze_layers', 0),
233
+ ),
234
+ mapping_kwargs = dnnlib.EasyDict(
235
+ num_layers = kwarg('mapping_layers', 0),
236
+ embed_features = kwarg('mapping_fmaps', None),
237
+ layer_features = kwarg('mapping_fmaps', None),
238
+ activation = kwarg('nonlinearity', 'lrelu'),
239
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
240
+ ),
241
+ epilogue_kwargs = dnnlib.EasyDict(
242
+ mbstd_group_size = kwarg('mbstd_group_size', None),
243
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
244
+ activation = kwarg('nonlinearity', 'lrelu'),
245
+ ),
246
+ )
247
+
248
+ # Check for unknown kwargs.
249
+ kwarg('structure')
250
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251
+ if len(unknown_kwargs) > 0:
252
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253
+
254
+ # Collect params.
255
+ tf_params = _collect_tf_params(tf_D)
256
+ for name, value in list(tf_params.items()):
257
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258
+ if match:
259
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
260
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261
+ kwargs.architecture = 'orig'
262
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263
+
264
+ # Convert params.
265
+ from training import networks
266
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267
+ # pylint: disable=unnecessary-lambda
268
+ _populate_module_params(D,
269
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284
+ r'.*\.resample_filter', None,
285
+ )
286
+ return D
287
+
288
+ #----------------------------------------------------------------------------
289
+
290
+ @click.command()
291
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294
+ def convert_network_pickle(source, dest, force_fp16):
295
+ """Convert legacy network pickle into the native PyTorch format.
296
+
297
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299
+
300
+ Example:
301
+
302
+ \b
303
+ python legacy.py \\
304
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305
+ --dest=stylegan2-cat-config-f.pkl
306
+ """
307
+ print(f'Loading "{source}"...')
308
+ with dnnlib.util.open_url(source) as f:
309
+ data = load_network_pkl(f, force_fp16=force_fp16)
310
+ print(f'Saving "{dest}"...')
311
+ with open(dest, 'wb') as f:
312
+ pickle.dump(data, f)
313
+ print('Done.')
314
+
315
+ #----------------------------------------------------------------------------
316
+
317
+ if __name__ == "__main__":
318
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
319
+
320
+ #----------------------------------------------------------------------------
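
Once converted, the pickle round-trips through the same loader used by the rest of this repo. A minimal sketch, assuming a local network.pkl and a CUDA device (StyleNeRF generators may additionally expect camera keyword arguments):

import torch
import dnnlib
import legacy

with dnnlib.util.open_url('network.pkl') as f:            # path is an assumption
    G = legacy.load_network_pkl(f)['G_ema'].to('cuda')    # exponential-moving-average generator
z = torch.randn([1, G.z_dim], device='cuda')
c = torch.zeros([1, G.c_dim], device='cuda')               # unconditional label
img = G(z=z, c=c)                                          # image tensor (or a dict with an 'img' entry for StyleNeRF models)
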
metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def compute_fid(opts, max_real, num_gen):
21
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
25
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
26
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
27
+
28
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
29
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
30
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
31
+
32
+ if opts.rank != 0:
33
+ return float('nan')
34
+
35
+ m = np.square(mu_gen - mu_real).sum()
36
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
37
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
38
+ return float(fid)
39
+
40
+ #----------------------------------------------------------------------------
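
The last three lines of compute_fid are the closed-form Fréchet distance between two Gaussians, d^2 = ||mu_g - mu_r||^2 + Tr(Sigma_g + Sigma_r - 2 (Sigma_g Sigma_r)^(1/2)). A standalone sketch with made-up statistics (a small feature dimension is used only to keep it fast; the real detector features are 2048-d):

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
dim = 64
mu_real, mu_gen = rng.normal(size=dim), rng.normal(size=dim)
sigma_real, sigma_gen = np.eye(dim), np.eye(dim) * 1.5

m = np.square(mu_gen - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
print(fid)
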
metrics/inception_score.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
+
23
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ capture_all=True, max_items=num_gen).get_all()
26
+
27
+ if opts.rank != 0:
28
+ return float('nan'), float('nan')
29
+
30
+ scores = []
31
+ for i in range(num_splits):
32
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
+ kl = np.mean(np.sum(kl, axis=1))
35
+ scores.append(np.exp(kl))
36
+ return float(np.mean(scores)), float(np.std(scores))
37
+
38
+ #----------------------------------------------------------------------------
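
compute_is reduces to exp(E[KL(p(y|x) || p(y))]) averaged per split. A small self-contained check with random softmax outputs standing in for the detector's predictions:

import numpy as np

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(1000), size=5000)    # stand-in for detector softmax outputs
num_splits = 10

scores = []
for i in range(num_splits):
    part = probs[i * len(probs) // num_splits : (i + 1) * len(probs) // num_splits]
    kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
    scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
print(float(np.mean(scores)), float(np.std(scores)))
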
metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
+
23
+ real_features = metric_utils.compute_feature_stats_for_dataset(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
+
27
+ gen_features = metric_utils.compute_feature_stats_for_generator(
28
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
+
31
+ if opts.rank != 0:
32
+ return float('nan')
33
+
34
+ n = real_features.shape[1]
35
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
+ t = 0
37
+ for _subset_idx in range(num_subsets):
38
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
+ b = (x @ y.T / n + 1) ** 3
42
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
+ kid = t / num_subsets / m
44
+ return float(kid)
45
+
46
+ #----------------------------------------------------------------------------
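
The inner loop of compute_kid is an unbiased MMD estimate under a degree-3 polynomial kernel. A one-subset sketch on random features (the feature dimension matches the Inception features; the subset size is arbitrary):

import numpy as np

rng = np.random.default_rng(0)
n, m = 2048, 100
x = rng.normal(size=(m, n))     # stand-in generated features
y = rng.normal(size=(m, n))     # stand-in real features

a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
b = (x @ y.T / n + 1) ** 3
t = (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
kid = t / m                     # averaging over multiple subsets omitted here
print(kid)
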
metrics/metric_main.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import json
12
+ import torch
13
+ import dnnlib
14
+
15
+ from . import metric_utils
16
+ from . import frechet_inception_distance
17
+ from . import kernel_inception_distance
18
+ from . import precision_recall
19
+ from . import perceptual_path_length
20
+ from . import inception_score
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ _metric_dict = dict() # name => fn
25
+
26
+ def register_metric(fn):
27
+ assert callable(fn)
28
+ _metric_dict[fn.__name__] = fn
29
+ return fn
30
+
31
+ def is_valid_metric(metric):
32
+ return metric in _metric_dict
33
+
34
+ def list_valid_metrics():
35
+ return list(_metric_dict.keys())
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40
+ assert is_valid_metric(metric)
41
+ opts = metric_utils.MetricOptions(**kwargs)
42
+
43
+ # Calculate.
44
+ start_time = time.time()
45
+ results = _metric_dict[metric](opts)
46
+ total_time = time.time() - start_time
47
+
48
+ # Broadcast results.
49
+ for key, value in list(results.items()):
50
+ if opts.num_gpus > 1:
51
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52
+ torch.distributed.broadcast(tensor=value, src=0)
53
+ value = float(value.cpu())
54
+ results[key] = value
55
+
56
+ # Decorate with metadata.
57
+ return dnnlib.EasyDict(
58
+ results = dnnlib.EasyDict(results),
59
+ metric = metric,
60
+ total_time = total_time,
61
+ total_time_str = dnnlib.util.format_time(total_time),
62
+ num_gpus = opts.num_gpus,
63
+ )
64
+
65
+ #----------------------------------------------------------------------------
66
+
67
+ def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68
+ metric = result_dict['metric']
69
+ assert is_valid_metric(metric)
70
+ if run_dir is not None and snapshot_pkl is not None:
71
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72
+
73
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74
+ print(jsonl_line)
75
+ if run_dir is not None and os.path.isdir(run_dir):
76
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77
+ f.write(jsonl_line + '\n')
78
+
79
+ #----------------------------------------------------------------------------
80
+ # Primary metrics.
81
+
82
+ @register_metric
83
+ def fid50k_full(opts):
84
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
85
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86
+ return dict(fid50k_full=fid)
87
+
88
+ @register_metric
89
+ def kid50k_full(opts):
90
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
91
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92
+ return dict(kid50k_full=kid)
93
+
94
+ @register_metric
95
+ def pr50k3_full(opts):
96
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
97
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99
+
100
+ @register_metric
101
+ def ppl2_wend(opts):
102
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103
+ return dict(ppl2_wend=ppl)
104
+
105
+ @register_metric
106
+ def is50k(opts):
107
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
108
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109
+ return dict(is50k_mean=mean, is50k_std=std)
110
+
111
+ #----------------------------------------------------------------------------
112
+ # Legacy metrics.
113
+
114
+ @register_metric
115
+ def fid50k(opts):
116
+ opts.dataset_kwargs.update(max_size=None)
117
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118
+ return dict(fid50k=fid)
119
+
120
+ @register_metric
121
+ def kid50k(opts):
122
+ opts.dataset_kwargs.update(max_size=None)
123
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124
+ return dict(kid50k=kid)
125
+
126
+ @register_metric
127
+ def pr50k3(opts):
128
+ opts.dataset_kwargs.update(max_size=None)
129
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131
+
132
+ @register_metric
133
+ def ppl_zfull(opts):
134
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135
+ return dict(ppl_zfull=ppl)
136
+
137
+ @register_metric
138
+ def ppl_wfull(opts):
139
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140
+ return dict(ppl_wfull=ppl)
141
+
142
+ @register_metric
143
+ def ppl_zend(opts):
144
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145
+ return dict(ppl_zend=ppl)
146
+
147
+ @register_metric
148
+ def ppl_wend(opts):
149
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150
+ return dict(ppl_wend=ppl)
151
+
152
+ #----------------------------------------------------------------------------
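
New metrics plug in through register_metric and are invoked by name via calc_metric. A hedged sketch of a hypothetical, cheaper FID variant; the generator, dataset kwargs and device passed to calc_metric come from the training loop and are assumptions here:

from metrics import metric_main, frechet_inception_distance

@metric_main.register_metric
def fid10k(opts):                                   # hypothetical 10k-sample variant
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=10000, num_gen=10000)
    return dict(fid10k=fid)

# result = metric_main.calc_metric('fid10k', G=G_ema, dataset_kwargs=training_set_kwargs,
#                                  num_gpus=1, rank=0, device=torch.device('cuda'))
# metric_main.report_metric(result, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
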
metrics/metric_utils.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ import os
12
+ import time
13
+ import hashlib
14
+ import pickle
15
+ import copy
16
+ import uuid
17
+ import numpy as np
18
+ import torch
19
+ import dnnlib
20
+ import glob
21
+ #----------------------------------------------------------------------------
22
+
23
+ class MetricOptions:
24
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
25
+ assert 0 <= rank < num_gpus
26
+ self.G = G
27
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
28
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
29
+ self.num_gpus = num_gpus
30
+ self.rank = rank
31
+ self.device = device if device is not None else torch.device('cuda', rank)
32
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
33
+ self.cache = cache
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ _feature_detector_cache = dict()
38
+
39
+ def get_feature_detector_name(url):
40
+ return os.path.splitext(url.split('/')[-1])[0]
41
+
42
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
43
+ assert 0 <= rank < num_gpus
44
+ key = (url, device)
45
+ if key not in _feature_detector_cache:
46
+ is_leader = (rank == 0)
47
+ if not is_leader and num_gpus > 1:
48
+ torch.distributed.barrier() # leader goes first
49
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
50
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
51
+ if is_leader and num_gpus > 1:
52
+ torch.distributed.barrier() # others follow
53
+ return _feature_detector_cache[key]
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ class FeatureStats:
58
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
59
+ self.capture_all = capture_all
60
+ self.capture_mean_cov = capture_mean_cov
61
+ self.max_items = max_items
62
+ self.num_items = 0
63
+ self.num_features = None
64
+ self.all_features = None
65
+ self.raw_mean = None
66
+ self.raw_cov = None
67
+
68
+ def set_num_features(self, num_features):
69
+ if self.num_features is not None:
70
+ assert num_features == self.num_features
71
+ else:
72
+ self.num_features = num_features
73
+ self.all_features = []
74
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
75
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
76
+
77
+ def is_full(self):
78
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
79
+
80
+ def append(self, x):
81
+ x = np.asarray(x, dtype=np.float32)
82
+ assert x.ndim == 2
83
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
84
+ if self.num_items >= self.max_items:
85
+ return
86
+ x = x[:self.max_items - self.num_items]
87
+
88
+ self.set_num_features(x.shape[1])
89
+ self.num_items += x.shape[0]
90
+ if self.capture_all:
91
+ self.all_features.append(x)
92
+ if self.capture_mean_cov:
93
+ x64 = x.astype(np.float64)
94
+ self.raw_mean += x64.sum(axis=0)
95
+ self.raw_cov += x64.T @ x64
96
+
97
+ def append_torch(self, x, num_gpus=1, rank=0):
98
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
99
+ assert 0 <= rank < num_gpus
100
+ if num_gpus > 1:
101
+ ys = []
102
+ for src in range(num_gpus):
103
+ y = x.clone()
104
+ torch.distributed.broadcast(y, src=src)
105
+ ys.append(y)
106
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
107
+ self.append(x.cpu().numpy())
108
+
109
+ def get_all(self):
110
+ assert self.capture_all
111
+ return np.concatenate(self.all_features, axis=0)
112
+
113
+ def get_all_torch(self):
114
+ return torch.from_numpy(self.get_all())
115
+
116
+ def get_mean_cov(self):
117
+ assert self.capture_mean_cov
118
+ mean = self.raw_mean / self.num_items
119
+ cov = self.raw_cov / self.num_items
120
+ cov = cov - np.outer(mean, mean)
121
+ return mean, cov
122
+
123
+ def save(self, pkl_file):
124
+ with open(pkl_file, 'wb') as f:
125
+ pickle.dump(self.__dict__, f)
126
+
127
+ @staticmethod
128
+ def load(pkl_file):
129
+ with open(pkl_file, 'rb') as f:
130
+ s = dnnlib.EasyDict(pickle.load(f))
131
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
132
+ obj.__dict__.update(s)
133
+ return obj
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ class ProgressMonitor:
138
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
139
+ self.tag = tag
140
+ self.num_items = num_items
141
+ self.verbose = verbose
142
+ self.flush_interval = flush_interval
143
+ self.progress_fn = progress_fn
144
+ self.pfn_lo = pfn_lo
145
+ self.pfn_hi = pfn_hi
146
+ self.pfn_total = pfn_total
147
+ self.start_time = time.time()
148
+ self.batch_time = self.start_time
149
+ self.batch_items = 0
150
+ if self.progress_fn is not None:
151
+ self.progress_fn(self.pfn_lo, self.pfn_total)
152
+
153
+ def update(self, cur_items):
154
+ assert (self.num_items is None) or (cur_items <= self.num_items)
155
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
156
+ return
157
+ cur_time = time.time()
158
+ total_time = cur_time - self.start_time
159
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
160
+ if (self.verbose) and (self.tag is not None):
161
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
162
+ self.batch_time = cur_time
163
+ self.batch_items = cur_items
164
+
165
+ if (self.progress_fn is not None) and (self.num_items is not None):
166
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
167
+
168
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
169
+ return ProgressMonitor(
170
+ tag = tag,
171
+ num_items = num_items,
172
+ flush_interval = flush_interval,
173
+ verbose = self.verbose,
174
+ progress_fn = self.progress_fn,
175
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
176
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
177
+ pfn_total = self.pfn_total,
178
+ )
179
+
180
+ #----------------------------------------------------------------------------
181
+
182
+ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
183
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
184
+ if data_loader_kwargs is None:
185
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
186
+
187
+ # Try to lookup from cache.
188
+ cache_file = None
189
+ if opts.cache:
190
+ # Choose cache file name.
191
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
192
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
193
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
194
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
195
+
196
+ # Check if the file exists (all processes must agree).
197
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
198
+ if opts.num_gpus > 1:
199
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
200
+ torch.distributed.broadcast(tensor=flag, src=0)
201
+ flag = (float(flag.cpu()) != 0)
202
+
203
+ # Load.
204
+ if flag:
205
+ return FeatureStats.load(cache_file)
206
+
207
+ # Initialize.
208
+ num_items = len(dataset)
209
+ if max_items is not None:
210
+ num_items = min(num_items, max_items)
211
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
212
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
213
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
214
+
215
+ # Main loop.
216
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
217
+ for images, _labels, _indices in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
218
+ if images.shape[1] == 1:
219
+ images = images.repeat([1, 3, 1, 1])
220
+ features = detector(images.to(opts.device), **detector_kwargs)
221
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
222
+ progress.update(stats.num_items)
223
+
224
+ # Save to cache.
225
+ if cache_file is not None and opts.rank == 0:
226
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
227
+ temp_file = cache_file + '.' + uuid.uuid4().hex
228
+ stats.save(temp_file)
229
+ os.replace(temp_file, cache_file) # atomic
230
+ return stats
231
+
232
+ #----------------------------------------------------------------------------
233
+
234
+ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
235
+ if batch_gen is None:
236
+ batch_gen = min(batch_size, 4)
237
+ assert batch_size % batch_gen == 0
238
+
239
+ # Setup generator and load labels.
240
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
241
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
242
+
243
+ # HACK:
244
+ # other_data = "/checkpoint/jgu/space/gan/ffhq/giraffe_results/gen_images"
245
+ # other_data = "/checkpoint/jgu/space/gan/cars/gen_images_380000"
246
+ # other_data = "/private/home/jgu/work/pi-GAN/Baselines/FFHQEvalOutput2"
247
+ # other_data = "/private/home/jgu/work/pi-GAN/Baselines/AFHQEvalOutput"
248
+ # other_data = sorted(glob.glob(f'{other_data}/*.jpg'))
249
+ # other_data = '/private/home/jgu/work/giraffe/out/afhq256/fid_images.npy'
250
+ # other_images = np.load(other_data)
251
+ # from fairseq import pdb;pdb.set_trace()
252
+ # print(f'other data size = {len(other_data)}')
253
+ other_data = None
254
+
255
+ # Image generation func.
256
+ def run_generator(z, c):
257
+ # from fairseq import pdb;pdb.set_trace()
258
+ if hasattr(G, 'get_final_output'):
259
+ img = G.get_final_output(z=z, c=c, **opts.G_kwargs)
260
+ else:
261
+ img = G(z=z, c=c, **opts.G_kwargs)
262
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
263
+ return img
264
+
265
+ # JIT.
266
+ if jit:
267
+ z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
268
+ c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
269
+ run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
270
+
271
+ # Initialize.
272
+ stats = FeatureStats(**stats_kwargs)
273
+ assert stats.max_items is not None
274
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
275
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
276
+
277
+ # Main loop.
278
+ till_now = 0
279
+ while not stats.is_full():
280
+ images = []
281
+ if other_data is None:
282
+ for _i in range(batch_size // batch_gen):
283
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
284
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
285
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
286
+ img = run_generator(z, c)
287
+ images.append(img)
288
+ images = torch.cat(images)
289
+ else:
290
+ batch_idxs = [((till_now + i) * opts.num_gpus + opts.rank) % len(other_images) for i in range(batch_size)]
291
+ import imageio
292
+ till_now += batch_size
293
+ images = other_images[batch_idxs]
294
+ images = torch.from_numpy(images).to(opts.device)
295
+ # images = np.stack([imageio.imread(other_data[i % len(other_data)]) for i in batch_idxs], axis=0)
296
+ # images = torch.from_numpy(images).to(opts.device).permute(0,3,1,2)
297
+
298
+ if images.shape[1] == 1:
299
+ images = images.repeat([1, 3, 1, 1])
300
+ features = detector(images, **detector_kwargs)
301
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
302
+ progress.update(stats.num_items)
303
+ return stats
304
+
305
+ #----------------------------------------------------------------------------
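
FeatureStats is a simple streaming accumulator and can be exercised on its own, without any detector or dataset:

import numpy as np
from metrics.metric_utils import FeatureStats

stats = FeatureStats(capture_mean_cov=True, max_items=1000)
while not stats.is_full():
    stats.append(np.random.randn(64, 2048))    # one batch of fake 2048-d features
mean, cov = stats.get_mean_cov()               # running mean / covariance over at most 1000 items
print(mean.shape, cov.shape)                   # (2048,), (2048, 2048)
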
metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
+ Architecture for Generative Adversarial Networks". Matches the original
11
+ implementation by Karras et al. at
12
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
+
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from . import metric_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ # Spherical interpolation of a batch of vectors.
23
+ def slerp(a, b, t):
24
+ a = a / a.norm(dim=-1, keepdim=True)
25
+ b = b / b.norm(dim=-1, keepdim=True)
26
+ d = (a * b).sum(dim=-1, keepdim=True)
27
+ p = t * torch.acos(d)
28
+ c = b - d * a
29
+ c = c / c.norm(dim=-1, keepdim=True)
30
+ d = a * torch.cos(p) + c * torch.sin(p)
31
+ d = d / d.norm(dim=-1, keepdim=True)
32
+ return d
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ class PPLSampler(torch.nn.Module):
37
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38
+ assert space in ['z', 'w']
39
+ assert sampling in ['full', 'end']
40
+ super().__init__()
41
+ self.G = copy.deepcopy(G)
42
+ self.G_kwargs = G_kwargs
43
+ self.epsilon = epsilon
44
+ self.space = space
45
+ self.sampling = sampling
46
+ self.crop = crop
47
+ self.vgg16 = copy.deepcopy(vgg16)
48
+
49
+ def forward(self, c):
50
+ # Generate random latents and interpolation t-values.
51
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53
+
54
+ # Interpolate in W or Z.
55
+ if self.space == 'w':
56
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59
+ else: # space == 'z'
60
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
61
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63
+
64
+ # Randomize noise buffers.
65
+ for name, buf in self.G.named_buffers():
66
+ if name.endswith('.noise_const'):
67
+ buf.copy_(torch.randn_like(buf))
68
+
69
+ # Generate images.
70
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71
+
72
+ # Center crop.
73
+ if self.crop:
74
+ assert img.shape[2] == img.shape[3]
75
+ c = img.shape[2] // 8
76
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
77
+
78
+ # Downsample to 256x256.
79
+ factor = self.G.img_resolution // 256
80
+ if factor > 1:
81
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82
+
83
+ # Scale dynamic range from [-1,1] to [0,255].
84
+ img = (img + 1) * (255 / 2)
85
+ if self.G.img_channels == 1:
86
+ img = img.repeat([1, 3, 1, 1])
87
+
88
+ # Evaluate differential LPIPS.
89
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91
+ return dist
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97
+ vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99
+
100
+ # Setup sampler.
101
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102
+ sampler.eval().requires_grad_(False).to(opts.device)
103
+ if jit:
104
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105
+ sampler = torch.jit.trace(sampler, [c], check_trace=False)
106
+
107
+ # Sampling loop.
108
+ dist = []
109
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111
+ progress.update(batch_start)
112
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114
+ x = sampler(c)
115
+ for src in range(opts.num_gpus):
116
+ y = x.clone()
117
+ if opts.num_gpus > 1:
118
+ torch.distributed.broadcast(y, src=src)
119
+ dist.append(y)
120
+ progress.update(num_samples)
121
+
122
+ # Compute PPL.
123
+ if opts.rank != 0:
124
+ return float('nan')
125
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
126
+ lo = np.percentile(dist, 1, interpolation='lower')
127
+ hi = np.percentile(dist, 99, interpolation='higher')
128
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129
+ return float(ppl)
130
+
131
+ #----------------------------------------------------------------------------
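
The PPL sampler perturbs each latent by a small epsilon along a spherical interpolation between two endpoints. A minimal check of slerp on its own (shapes and epsilon are illustrative):

import torch
from metrics.perceptual_path_length import slerp

a = torch.randn(4, 512)
b = torch.randn(4, 512)
t = torch.rand(4, 1)
epsilon = 1e-4

z_t0 = slerp(a, b, t)              # unit vectors on the arc between a and b
z_t1 = slerp(a, b, t + epsilon)    # a small step further along the same arc
print(z_t0.norm(dim=-1))           # ~1.0 for every row
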
metrics/precision_recall.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
+ Metric for Assessing Generative Models". Matches the original implementation
11
+ by Kynkaanniemi et al. at
12
+ https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
+
14
+ import torch
15
+ from . import metric_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20
+ assert 0 <= rank < num_gpus
21
+ num_cols = col_features.shape[0]
22
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24
+ dist_batches = []
25
+ for col_batch in col_batches[rank :: num_gpus]:
26
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27
+ for src in range(num_gpus):
28
+ dist_broadcast = dist_batch.clone()
29
+ if num_gpus > 1:
30
+ torch.distributed.broadcast(dist_broadcast, src=src)
31
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38
+ detector_kwargs = dict(return_features=True)
39
+
40
+ real_features = metric_utils.compute_feature_stats_for_dataset(
41
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43
+
44
+ gen_features = metric_utils.compute_feature_stats_for_generator(
45
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47
+
48
+ results = dict()
49
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50
+ kth = []
51
+ for manifold_batch in manifold.split(row_batch_size):
52
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54
+ kth = torch.cat(kth) if opts.rank == 0 else None
55
+ pred = []
56
+ for probes_batch in probes.split(row_batch_size):
57
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60
+ return results['precision'], results['recall']
61
+
62
+ #----------------------------------------------------------------------------
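
The precision/recall decision is a k-nearest-neighbour membership test: a probe counts as covered if it falls within the k-th-neighbour radius of some manifold point. A single-GPU sketch of that test on random features (sizes are arbitrary):

import torch

real = torch.randn(500, 64)     # stand-in real features (the manifold)
gen = torch.randn(500, 64)      # stand-in generated features (the probes)
nhood_size = 3

d_real = torch.cdist(real, real)                        # pairwise distances inside the manifold
kth = d_real.kthvalue(nhood_size + 1, dim=1).values     # radius to the 3rd neighbour (skipping self)
d_probe = torch.cdist(gen, real)                        # probe-to-manifold distances
precision = (d_probe <= kth.unsqueeze(0)).any(dim=1).float().mean()
print(float(precision))
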
renderer.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+ """Wrap the generator to render a sequence of images"""
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from torch import random
9
+ import tqdm
10
+ import copy
11
+ import trimesh
12
+
13
+
14
+ class Renderer(object):
15
+
16
+ def __init__(self, generator, discriminator=None, program=None):
17
+ self.generator = generator
18
+ self.discriminator = discriminator
19
+ self.sample_tmp = 0.65
20
+ self.program = program
21
+ self.seed = 0
22
+
23
+ if (program is not None) and (len(program.split(':')) == 2):
24
+ from training.dataset import ImageFolderDataset
25
+ self.image_data = ImageFolderDataset(program.split(':')[1])
26
+ self.program = program.split(':')[0]
27
+ else:
28
+ self.image_data = None
29
+
30
+ def set_random_seed(self, seed):
31
+ self.seed = seed
32
+ torch.manual_seed(seed)
33
+ np.random.seed(seed)
34
+
35
+ def __call__(self, *args, **kwargs):
36
+ self.generator.eval() # eval mode...
37
+
38
+ if self.program is None:
39
+ if hasattr(self.generator, 'get_final_output'):
40
+ return self.generator.get_final_output(*args, **kwargs)
41
+ return self.generator(*args, **kwargs)
42
+
43
+ if self.image_data is not None:
44
+ batch_size = 1
45
+ indices = (np.random.rand(batch_size) * len(self.image_data)).tolist()
46
+ rimages = np.stack([self.image_data._load_raw_image(int(i)) for i in indices], 0)
47
+ rimages = torch.from_numpy(rimages).float().to(kwargs['z'].device) / 127.5 - 1
48
+ kwargs['img'] = rimages
49
+
50
+ outputs = getattr(self, f"render_{self.program}")(*args, **kwargs)
51
+
52
+ if self.image_data is not None:
53
+ imgs = outputs if not isinstance(outputs, tuple) else outputs[0]
54
+ size = imgs[0].size(-1)
55
+ rimg = F.interpolate(rimages, (size, size), mode='bicubic', align_corners=False)
56
+ imgs = [torch.cat([img, rimg], 0) for img in imgs]
57
+ outputs = imgs if not isinstance(outputs, tuple) else (imgs, outputs[1])
58
+ return outputs
59
+
60
+ def get_additional_params(self, ws, t=0):
61
+ gen = self.generator.synthesis
62
+ batch_size = ws.size(0)
63
+
64
+ kwargs = {}
65
+ if not hasattr(gen, 'get_latent_codes'):
66
+ return kwargs
67
+
68
+ s_val, t_val, r_val = [[0, 0, 0]], [[0.5, 0.5, 0.5]], [0.]
69
+ # kwargs["transformations"] = gen.get_transformations(batch_size=batch_size, mode=[s_val, t_val, r_val], device=ws.device)
70
+ # kwargs["bg_rotation"] = gen.get_bg_rotation(batch_size, device=ws.device)
71
+ # kwargs["light_dir"] = gen.get_light_dir(batch_size, device=ws.device)
72
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
73
+ kwargs["camera_matrices"] = self.get_camera_traj(t, ws.size(0), device=ws.device)
74
+ return kwargs
75
+
76
+ def get_camera_traj(self, t, batch_size=1, traj_type='pigan', device='cpu'):
77
+ gen = self.generator.synthesis
78
+ if traj_type == 'pigan':
79
+ range_u, range_v = gen.C.range_u, gen.C.range_v
80
+ pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2
81
+ yaw = 0.4 * np.sin(t * 2 * np.pi)
82
+ u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
83
+ v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
84
+ cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device)
85
+ else:
86
+ raise NotImplementedError
87
+ return cam
88
+
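The `pigan` trajectory above sweeps pitch and yaw sinusoidally over one period of `t` and then normalizes them into the camera sampler's `(u, v)` ranges. A small standalone sketch of that mapping (the range values are illustrative; the real ones come from `gen.C.range_u` / `gen.C.range_v`):

```python
# Sketch of the pi-GAN style camera sweep used in get_camera_traj above.
# range_u / range_v are illustrative; the real values come from gen.C.
import numpy as np

range_u = (-0.6, 0.6)                          # yaw range (radians), hypothetical
range_v = (np.pi / 2 - 0.3, np.pi / 2 + 0.3)   # pitch range (radians), hypothetical

for t in np.linspace(0.0, 1.0, 5):
    pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi / 2   # oscillates around the equator
    yaw = 0.4 * np.sin(t * 2 * np.pi)                 # swings left and right
    u = (yaw - range_u[0]) / (range_u[1] - range_u[0])      # normalized to [0, 1]
    v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    print(f"t={t:.2f}  yaw={yaw:+.3f}  pitch={pitch:.3f}  u={u:.3f}  v={v:.3f}")
```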
89
+ def render_rotation_camera(self, *args, **kwargs):
90
+ batch_size, n_steps = 2, kwargs["n_steps"]
91
+ gen = self.generator.synthesis
92
+
93
+ if 'img' not in kwargs:
94
+ ws = self.generator.mapping(*args, **kwargs)
95
+ else:
96
+ ws, _ = self.generator.encoder(kwargs['img'])
97
+ # ws = ws.repeat(batch_size, 1, 1)
98
+
99
+ # kwargs["not_render_background"] = True
100
+ if hasattr(gen, 'get_latent_codes'):
101
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
102
+ kwargs.pop('img', None)
103
+
104
+ out = []
105
+ cameras = []
106
+ relative_range_u = kwargs['relative_range_u']
107
+ u_samples = np.linspace(relative_range_u[0], relative_range_u[1], n_steps)
108
+ for step in tqdm.tqdm(range(n_steps)):
109
+ # Set Camera
110
+ u = u_samples[step]
111
+ kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device)
112
+ cameras.append(gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device))
113
+ with torch.no_grad():
114
+ out_i = gen(ws, **kwargs)
115
+ if isinstance(out_i, dict):
116
+ out_i = out_i['img']
117
+ out.append(out_i)
118
+
119
+ if 'return_cameras' in kwargs and kwargs["return_cameras"]:
120
+ return out, cameras
121
+ else:
122
+ return out
123
+
124
+ def render_rotation_camera3(self, styles=None, *args, **kwargs):
125
+ gen = self.generator.synthesis
126
+ n_steps = 36 # 120
127
+
128
+ if styles is None:
129
+ batch_size = 2
130
+ if 'img' not in kwargs:
131
+ ws = self.generator.mapping(*args, **kwargs)
132
+ else:
133
+ ws = self.generator.encoder(kwargs['img'])['ws']
134
+ # ws = ws.repeat(batch_size, 1, 1)
135
+ else:
136
+ ws = styles
137
+ batch_size = ws.size(0)
138
+
139
+ # kwargs["not_render_background"] = True
140
+ # Get Random codes and bg rotation
141
+ self.sample_tmp = 0.72
142
+ if hasattr(gen, 'get_latent_codes'):
143
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
144
+ kwargs.pop('img', None)
145
+
146
+ # if getattr(gen, "use_noise", False):
147
+ # from dnnlib.geometry import extract_geometry
148
+ # kwargs['meshes'] = {}
149
+ # low_res, high_res = gen.resolution_vol, gen.img_resolution
150
+ # res = low_res * 2
151
+ # while res <= high_res:
152
+ # kwargs['meshes'][res] = [trimesh.Trimesh(*extract_geometry(gen, ws, resolution=res, threshold=30.))]
153
+ # kwargs['meshes'][res] += [
154
+ # torch.randn(len(kwargs['meshes'][res][0].vertices),
155
+ # 2, device=ws.device)[kwargs['meshes'][res][0].faces]]
156
+ # res = res * 2
157
+ # if getattr(gen, "use_noise", False):
158
+ # kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=2048, return_noise=True, sphere_noise=True)
159
+ # if getattr(gen, "use_voxel_noise", False):
160
+ # kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True)
161
+ kwargs['noise_mode'] = 'const'
162
+
163
+ out = []
164
+ tspace = np.linspace(0, 1, n_steps)
165
+ range_u, range_v = gen.C.range_u, gen.C.range_v
166
+
167
+ for step in tqdm.tqdm(range(n_steps)):
168
+ t = tspace[step]
169
+ pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2
170
+ yaw = 0.4 * np.sin(t * 2 * np.pi)
171
+ u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
172
+ v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
173
+
174
+ kwargs["camera_matrices"] = gen.get_camera(
175
+ batch_size=batch_size, mode=[u, v, t], device=ws.device)
176
+
177
+ with torch.no_grad():
178
+ out_i = gen(ws, **kwargs)
179
+ if isinstance(out_i, dict):
180
+ out_i = out_i['img']
181
+ out.append(out_i)
182
+ return out
183
+
184
+ def render_rotation_both(self, *args, **kwargs):
185
+ gen = self.generator.synthesis
186
+ batch_size, n_steps = 1, 36
187
+ if 'img' not in kwargs:
188
+ ws = self.generator.mapping(*args, **kwargs)
189
+ else:
190
+ ws, _ = self.generator.encoder(kwargs['img'])
191
+ ws = ws.repeat(batch_size, 1, 1)
192
+
193
+ # kwargs["not_render_background"] = True
194
+ # Get Random codes and bg rotation
195
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
196
+ kwargs.pop('img', None)
197
+
198
+ out = []
199
+ tspace = np.linspace(0, 1, n_steps)
200
+ range_u, range_v = gen.C.range_u, gen.C.range_v
201
+
202
+ for step in tqdm.tqdm(range(n_steps)):
203
+ t = tspace[step]
204
+ pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi/2
205
+ yaw = 0.4 * np.sin(t * 2 * np.pi)
206
+ u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
207
+ v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
208
+
209
+ kwargs["camera_matrices"] = gen.get_camera(
210
+ batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)
211
+
212
+ with torch.no_grad():
213
+ out_i = gen(ws, **kwargs)
214
+ if isinstance(out_i, dict):
215
+ out_i = out_i['img']
216
+
217
+ kwargs_n = copy.deepcopy(kwargs)
218
+ kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'})
219
+ out_n = gen(ws, **kwargs_n)
220
+ out_n = F.interpolate(out_n,
221
+ size=(out_i.size(-1), out_i.size(-1)),
222
+ mode='bicubic', align_corners=True)
223
+ out_i = torch.cat([out_i, out_n], 0)
224
+ out.append(out_i)
225
+ return out
226
+
227
+ def render_rotation_grid(self, styles=None, return_cameras=False, *args, **kwargs):
228
+ gen = self.generator.synthesis
229
+ if styles is None:
230
+ batch_size = 1
231
+ ws = self.generator.mapping(*args, **kwargs)
232
+ ws = ws.repeat(batch_size, 1, 1)
233
+ else:
234
+ ws = styles
235
+ batch_size = ws.size(0)
236
+
237
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
238
+ kwargs.pop('img', None)
239
+
240
+ if getattr(gen, "use_voxel_noise", False):
241
+ kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True)
242
+
243
+ out = []
244
+ cameras = []
245
+ range_u, range_v = gen.C.range_u, gen.C.range_v
246
+
247
+ a_steps, b_steps = 6, 3
248
+ aspace = np.linspace(-0.4, 0.4, a_steps)
249
+ bspace = np.linspace(-0.2, 0.2, b_steps) * -1
250
+ for b in tqdm.tqdm(range(b_steps)):
251
+ for a in range(a_steps):
252
+ t_a = aspace[a]
253
+ t_b = bspace[b]
254
+ camera_mat = gen.camera_matrix.repeat(batch_size, 1, 1).to(ws.device)
255
+ loc_x = np.cos(t_b) * np.cos(t_a)
256
+ loc_y = np.cos(t_b) * np.sin(t_a)
257
+ loc_z = np.sin(t_b)
258
+ loc = torch.tensor([[loc_x, loc_y, loc_z]], dtype=torch.float32).to(ws.device)
259
+ from dnnlib.camera import look_at
260
+ R = look_at(loc)
261
+ RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
262
+ RT[:, :3, :3] = R
263
+ RT[:, :3, -1] = loc
264
+
265
+ world_mat = RT.to(ws.device)
266
+ #kwargs["camera_matrices"] = gen.get_camera(
267
+ # batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)
268
+ kwargs["camera_matrices"] = (camera_mat, world_mat, "random", None)
269
+
270
+ with torch.no_grad():
271
+ out_i = gen(ws, **kwargs)
272
+ if isinstance(out_i, dict):
273
+ out_i = out_i['img']
274
+
275
+ # kwargs_n = copy.deepcopy(kwargs)
276
+ # kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'})
277
+ # out_n = gen(ws, **kwargs_n)
278
+ # out_n = F.interpolate(out_n,
279
+ # size=(out_i.size(-1), out_i.size(-1)),
280
+ # mode='bicubic', align_corners=True)
281
+ # out_i = torch.cat([out_i, out_n], 0)
282
+ out.append(out_i)
283
+
284
+ if return_cameras:
285
+ return out, cameras
286
+ else:
287
+ return out
288
+
289
+ def render_rotation_camera_grid(self, *args, **kwargs):
290
+ batch_size, n_steps = 1, 60
291
+ gen = self.generator.synthesis
292
+ bbox_generator = self.generator.synthesis.boundingbox_generator
293
+
294
+ ws = self.generator.mapping(*args, **kwargs)
295
+ ws = ws.repeat(batch_size, 1, 1)
296
+
297
+ # Get Random codes and bg rotation
298
+ kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
299
+ del kwargs['render_option']
300
+
301
+ out = []
302
+ for v in [0.15, 0.5, 1.05]:
303
+ for step in tqdm.tqdm(range(n_steps)):
304
+ # Set Camera
305
+ u = step * 1.0 / (n_steps - 1) - 1.0
306
+ kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)
307
+ with torch.no_grad():
308
+ out_i = gen(ws, render_option=None, **kwargs)
309
+ if isinstance(out_i, dict):
310
+ out_i = out_i['img']
311
+ # option_n = 'early,no_background,up64,depth,direct_depth'
312
+ # option_n = 'early,up128,no_background,depth,normal'
313
+ # out_n = gen(ws, render_option=option_n, **kwargs)
314
+ # out_n = F.interpolate(out_n,
315
+ # size=(out_i.size(-1), out_i.size(-1)),
316
+ # mode='bicubic', align_corners=True)
317
+ # out_i = torch.cat([out_i, out_n], 0)
318
+
319
+ out.append(out_i)
320
+
321
+ # out += out[::-1]
322
+ return out
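`Renderer` wraps a trained generator so that a rendering program can be picked by name and dispatched to one of the `render_*` methods above. A loose sketch of driving it, mirroring how `generate.py` in this repo uses it; the checkpoint path is a placeholder and the exact keyword arguments accepted depend on the generator:

```python
# Hedged sketch of driving the Renderer above; the pickle path is a placeholder.
import torch
import dnnlib
import legacy
from renderer import Renderer

device = torch.device('cuda')
with dnnlib.util.open_url('pretrained/ffhq_512.pkl') as f:   # placeholder path
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

R = Renderer(G, discriminator=None, program='rotation_camera')
R.set_random_seed(42)

z = torch.randn(2, G.z_dim, device=device)
c = torch.zeros(2, G.c_dim, device=device)
frames = R(z=z, c=c, n_steps=16, relative_range_u=[0.3, 0.7])  # list of image batches
print(len(frames), frames[0].shape)
```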
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ albumentations==1.0.3
2
+ click==8.0.3
3
+ clip-by-openai==1.1
4
+ einops==0.3.0
5
+ glfw==2.5.0
6
+ gradio==2.8.13
7
+ imageio==2.9.0
8
+ imgui==1.4.1
9
+ kornia==0.5.10
10
+ lmdb==0.98
11
+ lpips==0.1.4
12
+ matplotlib==3.4.3
13
+ numpy==1.21.2
14
+ hydra-core==1.1
15
+ opencv_python_headless==4.5.1.48
16
+ Pillow==9.0.1
17
+ psutil==5.8.0
18
+ PyMCubes==0.1.2
19
+ PyOpenGL==3.1.6
20
+ pyspng==0.1.0
21
+ requests==2.26.0
22
+ scipy==1.7.1
23
+ submitit==1.1.5
24
+ tensorboardX==2.5
25
+ torch==1.7.1
26
+ torchvision==0.8.2
27
+ tqdm==4.62.2
28
+ trimesh==3.9.8
29
+ imageio-ffmpeg==0.4.5
run_train.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+ from math import dist
5
+ import sys
6
+ import os
7
+ import click
8
+ import re
9
+ import json
10
+ import glob
11
+ import tempfile
12
+ import torch
13
+ import dnnlib
14
+ import hydra
15
+
16
+ from datetime import date
17
+ from training import training_loop
18
+ from metrics import metric_main
19
+ from torch_utils import training_stats, custom_ops, distributed_utils
20
+ from torch_utils.distributed_utils import get_init_file, get_shared_folder
21
+ from omegaconf import DictConfig, OmegaConf
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ class UserError(Exception):
26
+ pass
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ def setup_training_loop_kwargs(cfg):
31
+ args = OmegaConf.create({})
32
+
33
+ # ------------------------------------------
34
+ # General options: gpus, snap, metrics, seed
35
+ # ------------------------------------------
36
+ args.rank = 0
37
+ args.gpu = 0
38
+ args.num_gpus = torch.cuda.device_count() if cfg.gpus is None else cfg.gpus
39
+ args.nodes = cfg.nodes if cfg.nodes is not None else 1
40
+ args.world_size = 1
41
+
42
+ args.dist_url = 'env://'
43
+ args.launcher = cfg.launcher
44
+ args.partition = cfg.partition
45
+ args.comment = cfg.comment
46
+ args.timeout = 4320 if cfg.timeout is None else cfg.timeout
47
+ args.job_dir = ''
48
+
49
+ if cfg.snap is None:
50
+ cfg.snap = 50
51
+ assert isinstance(cfg.snap, int)
52
+ if cfg.snap < 1:
53
+ raise UserError('snap must be at least 1')
54
+ args.image_snapshot_ticks = cfg.imgsnap
55
+ args.network_snapshot_ticks = cfg.snap
56
+ if hasattr(cfg, 'ucp'):
57
+ args.update_cam_prior_ticks = cfg.ucp
58
+
59
+ if cfg.metrics is None:
60
+ cfg.metrics = ['fid50k_full']
61
+ cfg.metrics = list(cfg.metrics)
62
+ if not all(metric_main.is_valid_metric(metric) for metric in cfg.metrics):
63
+ raise UserError('\n'.join(['metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
64
+ args.metrics = cfg.metrics
65
+
66
+ if cfg.seed is None:
67
+ cfg.seed = 0
68
+ assert isinstance(cfg.seed, int)
69
+ args.random_seed = cfg.seed
70
+
71
+ # -----------------------------------
72
+ # Dataset: data, cond, subset, mirror
73
+ # -----------------------------------
74
+
75
+ assert cfg.data is not None
76
+ assert isinstance(cfg.data, str)
77
+ args.update({"training_set_kwargs": dict(class_name='training.dataset.ImageFolderDataset', path=cfg.data, resolution=cfg.resolution, use_labels=True, max_size=None, xflip=False)})
78
+ args.update({"data_loader_kwargs": dict(pin_memory=True, num_workers=3, prefetch_factor=2)})
79
+ args.generation_with_image = getattr(cfg, 'generate_with_image', False)
80
+ try:
81
+ training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
82
+ args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
83
+ args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
84
+ args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
85
+ desc = training_set.name
86
+ del training_set # conserve memory
87
+ except IOError as err:
88
+ raise UserError(f'data: {err}')
89
+
90
+ if cfg.cond is None:
91
+ cfg.cond = False
92
+ assert isinstance(cfg.cond, bool)
93
+ if cfg.cond:
94
+ if not args.training_set_kwargs.use_labels:
95
+ raise UserError('cond=True requires labels specified in dataset.json')
96
+ desc += '-cond'
97
+ else:
98
+ args.training_set_kwargs.use_labels = False
99
+
100
+ if cfg.subset is not None:
101
+ assert isinstance(cfg.subset, int)
102
+ if not 1 <= cfg.subset <= args.training_set_kwargs.max_size:
103
+ raise UserError(f'subset must be between 1 and {args.training_set_kwargs.max_size}')
104
+ desc += f'-subset{cfg.subset}'
105
+ if cfg.subset < args.training_set_kwargs.max_size:
106
+ args.training_set_kwargs.max_size = cfg.subset
107
+ args.training_set_kwargs.random_seed = args.random_seed
108
+
109
+ if cfg.mirror is None:
110
+ cfg.mirror = False
111
+ assert isinstance(cfg.mirror, bool)
112
+ if cfg.mirror:
113
+ desc += '-mirror'
114
+ args.training_set_kwargs.xflip = True
115
+
116
+ # ------------------------------------
117
+ # Base config: cfg, model, gamma, kimg, batch
118
+ # ------------------------------------
119
+ if cfg.auto:
120
+ cfg.spec.name = 'auto'
121
+ desc += f'-{cfg.spec.name}'
122
+ desc += f'-{cfg.model.name}'
123
+ if cfg.spec.name == 'auto':
124
+ res = args.training_set_kwargs.resolution
125
+ cfg.spec.fmaps = 1 if res >= 512 else 0.5
126
+ cfg.spec.lrate = 0.002 if res >= 1024 else 0.0025
127
+ cfg.spec.gamma = 0.0002 * (res ** 2) / cfg.spec.mb # heuristic formula
128
+ cfg.spec.ema = cfg.spec.mb * 10 / 32
129
+
130
+ if getattr(cfg.spec, 'lrate_disc', None) is None:
131
+ cfg.spec.lrate_disc = cfg.spec.lrate # use the same learning rate for discriminator
132
+
133
+ # model (generator, discriminator)
134
+ args.update({"G_kwargs": dict(**cfg.model.G_kwargs)})
135
+ args.update({"D_kwargs": dict(**cfg.model.D_kwargs)})
136
+ args.update({"G_opt_kwargs": dict(class_name='torch.optim.Adam', lr=cfg.spec.lrate, betas=[0,0.99], eps=1e-8)})
137
+ args.update({"D_opt_kwargs": dict(class_name='torch.optim.Adam', lr=cfg.spec.lrate_disc, betas=[0,0.99], eps=1e-8)})
138
+ args.update({"loss_kwargs": dict(class_name='training.loss.StyleGAN2Loss', r1_gamma=cfg.spec.gamma, **cfg.model.loss_kwargs)})
139
+
140
+ if cfg.spec.name == 'cifar':
141
+ args.loss_kwargs.pl_weight = 0 # disable path length regularization
142
+ args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
143
+ args.D_kwargs.architecture = 'orig' # disable residual skip connections
144
+
145
+ # kimg data config
146
+ args.spec = cfg.spec # just keep the dict.
147
+ args.total_kimg = cfg.spec.kimg
148
+ args.batch_size = cfg.spec.mb
149
+ args.batch_gpu = cfg.spec.mbstd
150
+ args.ema_kimg = cfg.spec.ema
151
+ args.ema_rampup = cfg.spec.ramp
152
+
153
+ # ---------------------------------------------------
154
+ # Discriminator augmentation: aug, p, target, augpipe
155
+ # ---------------------------------------------------
156
+ if cfg.aug is None:
157
+ cfg.aug = 'ada'
158
+ else:
159
+ assert isinstance(cfg.aug, str)
160
+ desc += f'-{cfg.aug}'
161
+
162
+ if cfg.aug == 'ada':
163
+ args.ada_target = 0.6
164
+ elif cfg.aug == 'noaug':
165
+ pass
166
+ elif cfg.aug == 'fixed':
167
+ if cfg.p is None:
168
+ raise UserError(f'--aug={cfg.aug} requires specifying --p')
169
+ else:
170
+ raise UserError(f'--aug={cfg.aug} not supported')
171
+
172
+ if cfg.p is not None:
173
+ assert isinstance(cfg.p, float)
174
+ if cfg.aug != 'fixed':
175
+ raise UserError('--p can only be specified with --aug=fixed')
176
+ if not 0 <= cfg.p <= 1:
177
+ raise UserError('--p must be between 0 and 1')
178
+ desc += f'-p{cfg.p:g}'
179
+ args.augment_p = cfg.p
180
+
181
+ if cfg.target is not None:
182
+ assert isinstance(cfg.target, float)
183
+ if cfg.aug != 'ada':
184
+ raise UserError('--target can only be specified with --aug=ada')
185
+ if not 0 <= cfg.target <= 1:
186
+ raise UserError('--target must be between 0 and 1')
187
+ desc += f'-target{cfg.target:g}'
188
+ args.ada_target = cfg.target
189
+
190
+ assert cfg.augpipe is None or isinstance(cfg.augpipe, str)
191
+ if cfg.augpipe is None:
192
+ cfg.augpipe = 'bgc'
193
+ else:
194
+ if cfg.aug == 'noaug':
195
+ raise UserError('--augpipe cannot be specified with --aug=noaug')
196
+ desc += f'-{cfg.augpipe}'
197
+
198
+ augpipe_specs = {
199
+ 'blit': dict(xflip=1, rotate90=1, xint=1),
200
+ 'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
201
+ 'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
202
+ 'filter': dict(imgfilter=1),
203
+ 'noise': dict(noise=1),
204
+ 'cutout': dict(cutout=1),
205
+ 'bgc0': dict(xint=1, scale=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
206
+ 'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
207
+ 'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
208
+ 'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
209
+ 'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
210
+ 'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
211
+ }
212
+ assert cfg.augpipe in augpipe_specs
213
+ if cfg.aug != 'noaug':
214
+ args.update({"augment_kwargs": dict(class_name='training.augment.AugmentPipe', **augpipe_specs[cfg.augpipe])})
215
+
216
+ # ----------------------------------
217
+ # Transfer learning: resume, freezed
218
+ # ----------------------------------
219
+
220
+ resume_specs = {
221
+ 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
222
+ 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
223
+ 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
224
+ 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
225
+ 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
226
+ }
227
+
228
+ assert cfg.resume is None or isinstance(cfg.resume, str)
229
+ if cfg.resume is None:
230
+ cfg.resume = 'noresume'
231
+ elif cfg.resume == 'noresume':
232
+ desc += '-noresume'
233
+ elif cfg.resume in resume_specs:
234
+ desc += f'-resume{cfg.resume}'
235
+ args.resume_pkl = resume_specs[cfg.resume] # predefined url
236
+ else:
237
+ desc += '-resumecustom'
238
+ args.resume_pkl = cfg.resume # custom path or url
239
+
240
+ if cfg.resume != 'noresume':
241
+ args.ada_kimg = 100 # make ADA react faster at the beginning
242
+ args.ema_rampup = None # disable EMA rampup
243
+
244
+ if cfg.freezed is not None:
245
+ assert isinstance(cfg.freezed, int)
246
+ if not cfg.freezed >= 0:
247
+ raise UserError('--freezed must be non-negative')
248
+ desc += f'-freezed{cfg.freezed:d}'
249
+ args.D_kwargs.block_kwargs.freeze_layers = cfg.freezed
250
+
251
+ # -------------------------------------------------
252
+ # Performance options: fp32, nhwc, nobench, workers
253
+ # -------------------------------------------------
254
+ args.num_fp16_res = cfg.num_fp16_res
255
+ if cfg.fp32 is None:
256
+ cfg.fp32 = False
257
+ assert isinstance(cfg.fp32, bool)
258
+ if cfg.fp32:
259
+ args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
260
+ args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
261
+
262
+ if cfg.nhwc is None:
263
+ cfg.nhwc = False
264
+ assert isinstance(cfg.nhwc, bool)
265
+ if cfg.nhwc:
266
+ args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
267
+
268
+ if cfg.nobench is None:
269
+ cfg.nobench = False
270
+ assert isinstance(cfg.nobench, bool)
271
+ if cfg.nobench:
272
+ args.cudnn_benchmark = False
273
+
274
+ if cfg.allow_tf32 is None:
275
+ cfg.allow_tf32 = False
276
+ assert isinstance(cfg.allow_tf32, bool)
277
+ args.allow_tf32 = cfg.allow_tf32
278
+
279
+ if cfg.workers is not None:
280
+ assert isinstance(cfg.workers, int)
281
+ if not cfg.workers >= 1:
282
+ raise UserError('--workers must be at least 1')
283
+ args.data_loader_kwargs.num_workers = cfg.workers
284
+
285
+ args.debug = cfg.debug
286
+ if getattr(cfg, "prefix", None) is not None:
287
+ desc = cfg.prefix + '-' + desc
288
+ return desc, args
289
+
290
+ #----------------------------------------------------------------------------
291
+
292
+ def subprocess_fn(rank, args):
293
+ if not args.debug:
294
+ dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)
295
+
296
+ # Init torch.distributed.
297
+ distributed_utils.init_distributed_mode(rank, args)
298
+ if args.rank != 0:
299
+ custom_ops.verbosity = 'none'
300
+
301
+ # Execute training loop.
302
+ training_loop.training_loop(**args)
303
+
304
+ #----------------------------------------------------------------------------
305
+
306
+ class CommaSeparatedList(click.ParamType):
307
+ name = 'list'
308
+
309
+ def convert(self, value, param, ctx):
310
+ _ = param, ctx
311
+ if value is None or value.lower() == 'none' or value == '':
312
+ return []
313
+ return value.split(',')
314
+
315
+
316
+ @hydra.main(config_path="conf", config_name="config")
317
+ def main(cfg: DictConfig):
318
+
319
+ outdir = cfg.outdir
320
+
321
+ # Setup training options
322
+ run_desc, args = setup_training_loop_kwargs(cfg)
323
+
324
+ # Pick output directory.
325
+ prev_run_dirs = []
326
+ if os.path.isdir(outdir):
327
+ prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
328
+
329
+ if cfg.resume_run is None:
330
+ prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
331
+ prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
332
+ cur_run_id = max(prev_run_ids, default=-1) + 1
333
+ else:
334
+ cur_run_id = cfg.resume_run
335
+
336
+ args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
337
+ print(outdir, args.run_dir)
338
+
339
+ if cfg.resume_run is not None:
340
+ pkls = sorted(glob.glob(args.run_dir + '/network*.pkl'))
341
+ if len(pkls) > 0:
342
+ args.resume_pkl = pkls[-1]
343
+ args.resume_start = int(args.resume_pkl.split('-')[-1][:-4]) * 1000
344
+ else:
345
+ args.resume_start = 0
346
+
347
+ # Print options.
348
+ print()
349
+ print('Training options:')
350
+ print(OmegaConf.to_yaml(args))
351
+ print()
352
+ print(f'Output directory: {args.run_dir}')
353
+ print(f'Training data: {args.training_set_kwargs.path}')
354
+ print(f'Training duration: {args.total_kimg} kimg')
355
+ print(f'Number of images: {args.training_set_kwargs.max_size}')
356
+ print(f'Image resolution: {args.training_set_kwargs.resolution}')
357
+ print(f'Conditional model: {args.training_set_kwargs.use_labels}')
358
+ print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
359
+ print()
360
+
361
+ # Dry run?
362
+ if cfg.dry_run:
363
+ print('Dry run; exiting.')
364
+ return
365
+
366
+ # Create output directory.
367
+ print('Creating output directory...')
368
+ if not os.path.exists(args.run_dir):
369
+ os.makedirs(args.run_dir)
370
+ with open(os.path.join(args.run_dir, 'training_options.yaml'), 'wt') as fp:
371
+ OmegaConf.save(config=args, f=fp.name)
372
+
373
+ # Launch processes.
374
+ print('Launching processes...')
375
+ if (args.launcher == 'spawn') and (args.num_gpus > 1):
376
+ args.dist_url = distributed_utils.get_init_file().as_uri()
377
+ torch.multiprocessing.set_start_method('spawn')
378
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args,), nprocs=args.num_gpus)
379
+ else:
380
+ subprocess_fn(rank=0, args=args)
381
+
382
+ #----------------------------------------------------------------------------
383
+
384
+ if __name__ == "__main__":
385
+ if os.getenv('SLURM_ARGS') is not None:
386
+ # deprecated launcher for slurm jobs.
387
+ slurm_arg = eval(os.getenv('SLURM_ARGS'))
388
+ all_args = sys.argv[1:]
389
+ print(slurm_arg)
390
+ print(all_args)
391
+
392
+ from launcher import launch
393
+ launch(slurm_arg, all_args)
394
+
395
+ else:
396
+ main() # pylint: disable=no-value-for-parameter
397
+
398
+ #----------------------------------------------------------------------------
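When `spec.name == 'auto'`, the setup code above derives the feature-map multiplier, learning rate, R1 gamma, and EMA half-life from the dataset resolution and minibatch size. A worked example of those heuristics for a hypothetical 512x512 run with minibatch 64:

```python
# Worked example of the 'auto' spec heuristics in setup_training_loop_kwargs above.
# res and mb are hypothetical (512px dataset, minibatch of 64).
res, mb = 512, 64

fmaps = 1 if res >= 512 else 0.5           # feature-map multiplier
lrate = 0.002 if res >= 1024 else 0.0025   # learning rate
gamma = 0.0002 * (res ** 2) / mb           # R1 regularization weight
ema = mb * 10 / 32                         # EMA half-life in kimg

print(fmaps, lrate, gamma, ema)            # 1 0.0025 0.8192 20.0
```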
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
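`get_plugin` keys its cached build directory on a combined MD5 digest of all source files plus a mangled GPU name, so an unchanged extension reuses the previous ninja build. A reduced sketch of computing such a cache key (the file list and GPU name are placeholders):

```python
# Sketch of the cached-build key used by get_plugin above.
# The source list and GPU name are placeholders for real .cpp/.cu sources.
import hashlib

def build_cache_key(source_files, gpu_name='nvidia-a100'):
    digest = hashlib.md5()
    for path in sorted(source_files):
        with open(path, 'rb') as f:
            digest.update(f.read())            # hash every source file's bytes
    return f'{digest.hexdigest()}-{gpu_name}'  # becomes the cached build subdirectory name

if __name__ == '__main__':
    print(build_cache_key([__file__]))         # stand-in for real extension sources
```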
torch_utils/distributed_utils.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ import logging
4
+ import os
5
+ import pickle
6
+ import random
7
+ import socket
8
+ import struct
9
+ import subprocess
10
+ import warnings
11
+ import tempfile
12
+ import uuid
13
+
14
+
15
+ from datetime import date
16
+ from pathlib import Path
17
+ from collections import OrderedDict
18
+ from typing import Any, Dict, Mapping
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def is_master(args):
28
+ return args.distributed_rank == 0
29
+
30
+
31
+ def init_distributed_mode(rank, args):
32
+ if "WORLD_SIZE" in os.environ:
33
+ args.world_size = int(os.environ["WORLD_SIZE"])
34
+
35
+ if args.launcher == 'spawn': # single node with multiprocessing.spawn
36
+ args.world_size = args.num_gpus
37
+ args.rank = rank
38
+ args.gpu = rank
39
+
40
+ elif 'RANK' in os.environ:
41
+ args.rank = int(os.environ["RANK"])
42
+ args.gpu = int(os.environ['LOCAL_RANK'])
43
+
44
+ elif 'SLURM_PROCID' in os.environ:
45
+ args.rank = int(os.environ['SLURM_PROCID'])
46
+ args.gpu = args.rank % torch.cuda.device_count()
47
+
48
+ if args.world_size == 1:
49
+ return
50
+
51
+ if 'MASTER_ADDR' in os.environ:
52
+ args.dist_url = 'tcp://{}:{}'.format(os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
53
+
54
+ print(f'gpu={args.gpu}, rank={args.rank}, world_size={args.world_size}')
55
+ args.distributed = True
56
+ torch.cuda.set_device(args.gpu)
57
+ args.dist_backend = 'nccl'
58
+ print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
59
+
60
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
61
+ world_size=args.world_size, rank=args.rank)
62
+ torch.distributed.barrier()
63
+
64
+
65
+ def gather_list_and_concat(tensor):
66
+ gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
67
+ dist.all_gather(gather_t, tensor)
68
+ return torch.cat(gather_t)
69
+
70
+
71
+ def get_rank():
72
+ return dist.get_rank()
73
+
74
+
75
+ def get_world_size():
76
+ return dist.get_world_size()
77
+
78
+
79
+ def get_default_group():
80
+ return dist.group.WORLD
81
+
82
+
83
+ def all_gather_list(data, group=None, max_size=16384):
84
+ """Gathers arbitrary data from all nodes into a list.
85
+
86
+ Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
87
+ data. Note that *data* must be picklable.
88
+
89
+ Args:
90
+ data (Any): data from the local worker to be gathered on other workers
91
+ group (optional): group of the collective
92
+ max_size (int, optional): maximum size of the data to be gathered
93
+ across workers
94
+ """
95
+ rank = get_rank()
96
+ world_size = get_world_size()
97
+
98
+ buffer_size = max_size * world_size
99
+ if not hasattr(all_gather_list, '_buffer') or \
100
+ all_gather_list._buffer.numel() < buffer_size:
101
+ all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
102
+ all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
103
+ buffer = all_gather_list._buffer
104
+ buffer.zero_()
105
+ cpu_buffer = all_gather_list._cpu_buffer
106
+
107
+ data = data.cpu()
108
+ enc = pickle.dumps(data)
109
+ enc_size = len(enc)
110
+ header_size = 4 # size of header that contains the length of the encoded data
111
+ size = header_size + enc_size
112
+ if size > max_size:
113
+ raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size))
114
+
115
+ header = struct.pack(">I", enc_size)
116
+ cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
117
+ start = rank * max_size
118
+ buffer[start:start + size].copy_(cpu_buffer[:size])
119
+
120
+ dist.all_reduce(buffer, group=group)
121
+
122
+ buffer = buffer.cpu()
123
+ try:
124
+ result = []
125
+ for i in range(world_size):
126
+ out_buffer = buffer[i * max_size:(i + 1) * max_size]
127
+ enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
128
+ if enc_size > 0:
129
+ result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
130
+ return result
131
+ except pickle.UnpicklingError:
132
+ raise Exception(
133
+ 'Unable to unpickle data from other workers. all_gather_list requires all '
134
+ 'workers to enter the function together, so this error usually indicates '
135
+ 'that the workers have fallen out of sync somehow. Workers can fall out of '
136
+ 'sync if one of them runs out of memory, or if there are other conditions '
137
+ 'in your training script that can cause one worker to finish an epoch '
138
+ 'while other workers are still iterating over their portions of the data. '
139
+ 'Try rerunning with --ddp-backend=no_c10d and see if that helps.'
140
+ )
141
+
142
+
143
+ def all_reduce_dict(
144
+ data: Mapping[str, Any],
145
+ device,
146
+ group=None,
147
+ ) -> Dict[str, Any]:
148
+ """
149
+ AllReduce a dictionary of values across workers. We separately
150
+ reduce items that are already on the device and items on CPU for
151
+ better performance.
152
+
153
+ Args:
154
+ data (Mapping[str, Any]): dictionary of data to all-reduce, but
155
+ cannot be a nested dictionary
156
+ device (torch.device): device for the reduction
157
+ group (optional): group of the collective
158
+ """
159
+ data_keys = list(data.keys())
160
+
161
+ # We want to separately reduce items that are already on the
162
+ # device and items on CPU for performance reasons.
163
+ cpu_data = OrderedDict()
164
+ device_data = OrderedDict()
165
+ for k in data_keys:
166
+ t = data[k]
167
+ if not torch.is_tensor(t):
168
+ cpu_data[k] = torch.tensor(t, dtype=torch.double)
169
+ elif t.device.type != device.type:
170
+ cpu_data[k] = t.to(dtype=torch.double)
171
+ else:
172
+ device_data[k] = t.to(dtype=torch.double)
173
+
174
+ def _all_reduce_dict(data: OrderedDict):
175
+ if len(data) == 0:
176
+ return data
177
+ buf = torch.stack(list(data.values())).to(device=device)
178
+ dist.all_reduce(buf, group=group)
179
+ return {k: buf[i] for i, k in enumerate(data)}
180
+
181
+ cpu_data = _all_reduce_dict(cpu_data)
182
+ device_data = _all_reduce_dict(device_data)
183
+
184
+ def get_from_stack(key):
185
+ if key in cpu_data:
186
+ return cpu_data[key]
187
+ elif key in device_data:
188
+ return device_data[key]
189
+ raise KeyError
190
+
191
+ return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
192
+
193
+
194
+ def get_shared_folder() -> Path:
195
+ user = os.getenv("USER")
196
+ if Path("/checkpoint/").is_dir():
197
+ p = Path(f"/checkpoint/{user}/experiments")
198
+ p.mkdir(exist_ok=True)
199
+ return p
200
+ else:
201
+ p = Path(f"/tmp/experiments")
202
+ p.mkdir(exist_ok=True)
203
+ return p
204
+
205
+
206
+ def get_init_file():
207
+ # Init file must not exist, but its parent dir must exist.
208
+ os.makedirs(str(get_shared_folder()), exist_ok=True)
209
+ init_file = Path(str(get_shared_folder()) + f"/{uuid.uuid4().hex}_init")
210
+ if init_file.exists():
211
+ os.remove(str(init_file))
212
+ return init_file
213
+
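`all_gather_list` frames arbitrary picklable data as a 4-byte big-endian length header followed by the pickle payload, written into a fixed-size per-rank slice of a shared buffer. A standalone sketch of just that framing step, no process group required:

```python
# Sketch of the length-prefixed pickle framing used by all_gather_list above.
import pickle
import struct
import torch

HEADER_SIZE = 4  # bytes reserved for the payload length

def encode(data, max_size=16384):
    enc = pickle.dumps(data)
    if HEADER_SIZE + len(enc) > max_size:
        raise ValueError('encoded data exceeds max_size')
    buf = torch.zeros(max_size, dtype=torch.uint8)
    packed = struct.pack('>I', len(enc)) + enc            # big-endian length + payload
    buf[:len(packed)] = torch.tensor(list(packed), dtype=torch.uint8)
    return buf

def decode(buf):
    enc_size, = struct.unpack('>I', bytes(buf[:HEADER_SIZE].tolist()))
    return pickle.loads(bytes(buf[HEADER_SIZE:HEADER_SIZE + enc_size].tolist()))

print(decode(encode({'loss': 0.123, 'step': 7})))         # {'loss': 0.123, 'step': 7}
```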
torch_utils/misc.py ADDED
@@ -0,0 +1,303 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ import re
12
+ import contextlib
13
+ import numpy as np
14
+ import torch
15
+ import warnings
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
20
+ # same constant is used multiple times.
21
+
22
+ _constant_cache = dict()
23
+
24
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
25
+ value = np.asarray(value)
26
+ if shape is not None:
27
+ shape = tuple(shape)
28
+ if dtype is None:
29
+ dtype = torch.get_default_dtype()
30
+ if device is None:
31
+ device = torch.device('cpu')
32
+ if memory_format is None:
33
+ memory_format = torch.contiguous_format
34
+
35
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
36
+ tensor = _constant_cache.get(key, None)
37
+ if tensor is None:
38
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
39
+ if shape is not None:
40
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
41
+ tensor = tensor.contiguous(memory_format=memory_format)
42
+ _constant_cache[key] = tensor
43
+ return tensor
44
+
45
+ #----------------------------------------------------------------------------
46
+ # Replace NaN/Inf with specified numerical values.
47
+
48
+ try:
49
+ nan_to_num = torch.nan_to_num # 1.8.0a0
50
+ except AttributeError:
51
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
52
+ assert isinstance(input, torch.Tensor)
53
+ if posinf is None:
54
+ posinf = torch.finfo(input.dtype).max
55
+ if neginf is None:
56
+ neginf = torch.finfo(input.dtype).min
57
+ assert nan == 0
58
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59
+
60
+ #----------------------------------------------------------------------------
61
+ # Symbolic assert.
62
+
63
+ try:
64
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65
+ except AttributeError:
66
+ symbolic_assert = torch.Assert # 1.7.0
67
+
68
+ #----------------------------------------------------------------------------
69
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
70
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
71
+
72
+ @contextlib.contextmanager
73
+ def suppress_tracer_warnings():
74
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
75
+ warnings.filters.insert(0, flt)
76
+ yield
77
+ warnings.filters.remove(flt)
78
+
79
+ #----------------------------------------------------------------------------
80
+ # Assert that the shape of a tensor matches the given list of integers.
81
+ # None indicates that the size of a dimension is allowed to vary.
82
+ # Performs symbolic assertion when used in torch.jit.trace().
83
+
84
+ def assert_shape(tensor, ref_shape):
85
+ if tensor.ndim != len(ref_shape):
86
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
87
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
88
+ if ref_size is None:
89
+ pass
90
+ elif isinstance(ref_size, torch.Tensor):
91
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
92
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
93
+ elif isinstance(size, torch.Tensor):
94
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
95
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
96
+ elif size != ref_size:
97
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
98
+
99
+ #----------------------------------------------------------------------------
100
+ # Function decorator that calls torch.autograd.profiler.record_function().
101
+
102
+ def profiled_function(fn):
103
+ def decorator(*args, **kwargs):
104
+ with torch.autograd.profiler.record_function(fn.__name__):
105
+ return fn(*args, **kwargs)
106
+ decorator.__name__ = fn.__name__
107
+ return decorator
108
+
109
+ #----------------------------------------------------------------------------
110
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
111
+ # indefinitely, shuffling items as it goes.
112
+
113
+ class InfiniteSampler(torch.utils.data.Sampler):
114
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
115
+ assert len(dataset) > 0
116
+ assert num_replicas > 0
117
+ assert 0 <= rank < num_replicas
118
+ assert 0 <= window_size <= 1
119
+ super().__init__(dataset)
120
+ self.dataset = dataset
121
+ self.rank = rank
122
+ self.num_replicas = num_replicas
123
+ self.shuffle = shuffle
124
+ self.seed = seed
125
+ self.window_size = window_size
126
+
127
+ def __iter__(self):
128
+ order = np.arange(len(self.dataset))
129
+ rnd = None
130
+ window = 0
131
+ if self.shuffle:
132
+ rnd = np.random.RandomState(self.seed)
133
+ rnd.shuffle(order)
134
+ window = int(np.rint(order.size * self.window_size))
135
+
136
+ idx = 0
137
+ while True:
138
+ i = idx % order.size
139
+ if idx % self.num_replicas == self.rank:
140
+ yield order[i]
141
+ if window >= 2:
142
+ j = (i - rnd.randint(window)) % order.size
143
+ order[i], order[j] = order[j], order[i]
144
+ idx += 1
145
+
146
+ #----------------------------------------------------------------------------
147
+ # Utilities for operating with torch.nn.Module parameters and buffers.
148
+
149
+ def params_and_buffers(module):
150
+ assert isinstance(module, torch.nn.Module)
151
+ return list(module.parameters()) + list(module.buffers())
152
+
153
+ def named_params_and_buffers(module):
154
+ assert isinstance(module, torch.nn.Module)
155
+ return list(module.named_parameters()) + list(module.named_buffers())
156
+
157
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
158
+ assert isinstance(src_module, torch.nn.Module)
159
+ assert isinstance(dst_module, torch.nn.Module)
160
+ src_tensors = dict(named_params_and_buffers(src_module))
161
+ for name, tensor in named_params_and_buffers(dst_module):
162
+ assert (name in src_tensors) or (not require_all)
163
+ if name in src_tensors:
164
+ try:
165
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
166
+ except Exception as e:
167
+ print(f'Error loading: {name} {src_tensors[name].shape} {tensor.shape}')
168
+ raise e
169
+ #----------------------------------------------------------------------------
170
+ # Context manager for easily enabling/disabling DistributedDataParallel
171
+ # synchronization.
172
+
173
+ @contextlib.contextmanager
174
+ def ddp_sync(module, sync):
175
+ assert isinstance(module, torch.nn.Module)
176
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
177
+ yield
178
+ else:
179
+ with module.no_sync():
180
+ yield
181
+
182
+ #----------------------------------------------------------------------------
183
+ # Check DistributedDataParallel consistency across processes.
184
+
185
+ def check_ddp_consistency(module, ignore_regex=None):
186
+ assert isinstance(module, torch.nn.Module)
187
+ for name, tensor in named_params_and_buffers(module):
188
+ fullname = type(module).__name__ + '.' + name
189
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
190
+ continue
191
+ tensor = tensor.detach()
192
+ if tensor.is_floating_point():
193
+ tensor = nan_to_num(tensor)
194
+ other = tensor.clone()
195
+ torch.distributed.broadcast(tensor=other, src=0)
196
+ assert (tensor == other).all(), fullname
197
+
198
+ #----------------------------------------------------------------------------
199
+ # Print summary table of module hierarchy.
200
+
201
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
202
+ assert isinstance(module, torch.nn.Module)
203
+ assert not isinstance(module, torch.jit.ScriptModule)
204
+ assert isinstance(inputs, (tuple, list))
205
+
206
+ # Register hooks.
207
+ entries = []
208
+ nesting = [0]
209
+ def pre_hook(_mod, _inputs):
210
+ nesting[0] += 1
211
+ def post_hook(mod, _inputs, outputs):
212
+ nesting[0] -= 1
213
+ if nesting[0] <= max_nesting:
214
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
215
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
216
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
217
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
218
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
219
+
220
+ # Run module.
221
+ outputs = module(*inputs)
222
+ for hook in hooks:
223
+ hook.remove()
224
+
225
+ # Identify unique outputs, parameters, and buffers.
226
+ tensors_seen = set()
227
+ for e in entries:
228
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
229
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
230
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
231
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
232
+
233
+ # Filter out redundant entries.
234
+ if skip_redundant:
235
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
236
+
237
+ # Construct table.
238
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
239
+ rows += [['---'] * len(rows[0])]
240
+ param_total = 0
241
+ buffer_total = 0
242
+ submodule_names = {mod: name for name, mod in module.named_modules()}
243
+ for e in entries:
244
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
245
+ param_size = sum(t.numel() for t in e.unique_params)
246
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
247
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
248
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
249
+ rows += [[
250
+ name + (':0' if len(e.outputs) >= 2 else ''),
251
+ str(param_size) if param_size else '-',
252
+ str(buffer_size) if buffer_size else '-',
253
+ (output_shapes + ['-'])[0],
254
+ (output_dtypes + ['-'])[0],
255
+ ]]
256
+ for idx in range(1, len(e.outputs)):
257
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
258
+ param_total += param_size
259
+ buffer_total += buffer_size
260
+ rows += [['---'] * len(rows[0])]
261
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
262
+
263
+ # Print table.
264
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
265
+ print()
266
+ for row in rows:
267
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
268
+ print()
269
+ return outputs
270
+
271
+ #----------------------------------------------------------------------------
272
+
273
+ def get_ddp_func(m, func_name):
274
+ if hasattr(m, func_name):
275
+ return getattr(m, func_name)
276
+ if hasattr(m.module, func_name):
277
+ return getattr(m.module, func_name)
278
+ return None
279
+
280
+
281
+ #----------------------------------------------------------------------------
282
+
283
+ @contextlib.contextmanager
284
+ def cuda_time(prefix=""):
285
+ start = torch.cuda.Event(enable_timing=True)
286
+ end = torch.cuda.Event(enable_timing=True)
287
+ start.record()
288
+ try:
289
+ yield
290
+ finally:
291
+ end.record()
292
+ torch.cuda.synchronize()
293
+ print(f'{prefix}: {start.elapsed_time(end)} ms')
294
+
295
+ # ---------------------------------------------------------------------------
296
+
297
+ def get_func(m, f):
298
+ if hasattr(m, f):
299
+ return getattr(m, f)
300
+ elif hasattr(m.module, f):
301
+ return getattr(m.module, f)
302
+ else:
303
+ raise NotImplementedError
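`InfiniteSampler` above yields dataset indices forever, reshuffling within a sliding window so each replica sees its own endless stream. A small sketch of plugging it into a DataLoader (the toy dataset is illustrative):

```python
# Sketch of using InfiniteSampler above with a toy dataset.
import torch
from torch_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10).float())
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))

for _ in range(3):              # the loader never raises StopIteration
    (batch,) = next(loader)
    print(batch.tolist())
```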