|
import os |
|
import torch |
|
import PIL.Image |
|
import numpy as np |
|
from torch import nn |
|
import torch.distributed as dist |
|
import timm.models.hub as timm_hub |
|
|
|
"""Modified from https://github.com/CompVis/taming-transformers.git""" |
|
|
|
import hashlib

from urllib.parse import urlparse
|
import requests |
|
from tqdm import tqdm |
|
try:

    import piq

except ImportError:  # piq is optional; skip silently when it is not installed

    piq = None
|
|
|
_CONTEXT_PARALLEL_GROUP = None |
|
_CONTEXT_PARALLEL_SIZE = None |
|
|
|
|
|
def is_dist_avail_and_initialized(): |
|
if not dist.is_available(): |
|
return False |
|
if not dist.is_initialized(): |
|
return False |
|
return True |
|
|
|
|
|
def get_world_size(): |
|
if not is_dist_avail_and_initialized(): |
|
return 1 |
|
return dist.get_world_size() |
|
|
|
|
|
def get_rank(): |
|
if not is_dist_avail_and_initialized(): |
|
return 0 |
|
return dist.get_rank() |
|
|
|
|
|
def is_main_process(): |
|
return get_rank() == 0 |
|
|
|
|
|
def is_context_parallel_initialized(): |
|
if _CONTEXT_PARALLEL_GROUP is None: |
|
return False |
|
else: |
|
return True |
|
|
|
|
|
def set_context_parallel_group(size, group): |
|
global _CONTEXT_PARALLEL_GROUP |
|
global _CONTEXT_PARALLEL_SIZE |
|
_CONTEXT_PARALLEL_GROUP = group |
|
_CONTEXT_PARALLEL_SIZE = size |
|
|
|
|
|
def initialize_context_parallel(context_parallel_size): |
|
global _CONTEXT_PARALLEL_GROUP |
|
global _CONTEXT_PARALLEL_SIZE |
|
|
|
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" |
|
_CONTEXT_PARALLEL_SIZE = context_parallel_size |
|
|
|
rank = torch.distributed.get_rank() |
|
world_size = torch.distributed.get_world_size() |
|
|
|
for i in range(0, world_size, context_parallel_size): |
|
ranks = range(i, i + context_parallel_size) |
|
group = torch.distributed.new_group(ranks) |
|
if rank in ranks: |
|
_CONTEXT_PARALLEL_GROUP = group |
|
break |
|
|
|
|
|
def get_context_parallel_group(): |
|
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" |
|
|
|
return _CONTEXT_PARALLEL_GROUP |
|
|
|
|
|
def get_context_parallel_world_size(): |
|
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" |
|
|
|
return _CONTEXT_PARALLEL_SIZE |
|
|
|
|
|
def get_context_parallel_rank(): |
|
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" |
|
|
|
rank = get_rank() |
|
cp_rank = rank % _CONTEXT_PARALLEL_SIZE |
|
return cp_rank |
|
|
|
|
|
def get_context_parallel_group_rank(): |
|
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" |
|
|
|
rank = get_rank() |
|
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE |
|
|
|
return cp_group_rank |
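
# Worked example (illustrative only): with world_size == 8 and
# context_parallel_size == 2, initialize_context_parallel(2) builds the groups
# {0, 1}, {2, 3}, {4, 5}, {6, 7}; global rank 5 then has
# get_context_parallel_rank() == 5 % 2 == 1 and
# get_context_parallel_group_rank() == 5 // 2 == 2.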
|
|
|
|
|
def download_cached_file(url, check_hash=True, progress=False): |
|
""" |
|
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. |
|
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. |
|
""" |
|
|
|
def get_cached_file_path(): |
|
|
|
        parts = urlparse(url)
|
filename = os.path.basename(parts.path) |
|
cached_file = os.path.join(timm_hub.get_cache_dir(), filename) |
|
|
|
return cached_file |
|
|
|
if is_main_process(): |
|
timm_hub.download_cached_file(url, check_hash, progress) |
|
|
|
if is_dist_avail_and_initialized(): |
|
dist.barrier() |
|
|
|
return get_cached_file_path() |
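
# Usage sketch (hypothetical URL, for illustration only): every rank calls the
# helper, rank 0 downloads, and the other ranks wait at the barrier before
# reading the shared cache path.
#
#     path = download_cached_file(
#         "https://example.com/weights.pth", check_hash=False, progress=True
#     )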
|
|
|
|
|
def convert_weights_to_fp16(model: nn.Module): |
|
"""Convert applicable model parameters to fp16""" |
|
|
|
    def _convert_weights_to_fp16(module):

        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):

            module.weight.data = module.weight.data.to(torch.float16)

            if module.bias is not None:

                module.bias.data = module.bias.data.to(torch.float16)
|
|
|
model.apply(_convert_weights_to_fp16) |
|
|
|
|
|
def convert_weights_to_bf16(model: nn.Module): |
|
"""Convert applicable model parameters to fp16""" |
|
|
|
    def _convert_weights_to_bf16(module):

        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):

            module.weight.data = module.weight.data.to(torch.bfloat16)

            if module.bias is not None:

                module.bias.data = module.bias.data.to(torch.bfloat16)
|
|
|
model.apply(_convert_weights_to_bf16) |
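
# Minimal sketch: only Conv*/Linear weights are cast, so e.g. LayerNorm stays
# in fp32, a common mixed-precision recipe.
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
#     convert_weights_to_bf16(model)
#     assert model[0].weight.dtype == torch.bfloat16
#     assert model[1].weight.dtype == torch.float32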
|
|
|
|
|
def save_result(result, result_dir, filename, remove_duplicate="", save_format='json'): |
|
    import json

    print("Dumping results to %s" % result_dir)
|
|
|
|
|
if not os.path.exists(result_dir): |
|
if is_main_process(): |
|
os.makedirs(result_dir) |
|
if is_dist_avail_and_initialized(): |
|
torch.distributed.barrier() |
|
|
|
result_file = os.path.join( |
|
result_dir, "%s_rank%d.json" % (filename, get_rank()) |
|
) |
|
|
|
final_result_file = os.path.join(result_dir, f"{filename}.{save_format}") |
|
|
|
    with open(result_file, "w") as f:

        json.dump(result, f)
|
|
|
if is_dist_avail_and_initialized(): |
|
torch.distributed.barrier() |
|
|
|
if is_main_process(): |
|
|
|
|
|
result = [] |
|
|
|
for rank in range(get_world_size()): |
|
            result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))

            with open(result_file, "r") as f:

                result += json.load(f)
|
|
|
|
|
if remove_duplicate: |
|
result_new = [] |
|
id_set = set() |
|
for res in result: |
|
if res[remove_duplicate] not in id_set: |
|
id_set.add(res[remove_duplicate]) |
|
result_new.append(res) |
|
result = result_new |
|
|
|
        if save_format == 'json':

            with open(final_result_file, "w") as f:

                json.dump(result, f)
|
        else:

            assert save_format == 'jsonl', "Only support json and jsonl format"

            import jsonlines

            with jsonlines.open(final_result_file, "w") as writer:

                writer.write_all(result)
|
|
|
|
|
|
|
return final_result_file |
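
# Usage sketch (field names are illustrative): each rank passes its own shard;
# the main process merges all rank files and deduplicates on "question_id".
#
#     shard = [{"question_id": 7, "answer": "yes"}]
#     save_result(shard, "output", "vqa_val", remove_duplicate="question_id")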
|
|
|
|
|
|
|
|
|
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): |
|
h, w = input.shape[-2:] |
|
factors = (h / size[0], w / size[1]) |
|
|
|
|
|
|
|
    # When downsampling (factor > 1.0), blur first so high frequencies do not
    # alias; the sigma is derived from the scale factor and clamped to a small
    # positive value.
    sigmas = (

        max((factors[0] - 1.0) / 2.0, 0.001),

        max((factors[1] - 1.0) / 2.0, 0.001),

    )
|
|
|
|
|
|
|
|
|
    # The kernel spans roughly +/- 2 sigma, with a minimum of 3 taps.
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
|
|
|
|
|
    # Kernel sizes must be odd so the Gaussian is centered on a pixel.
    if (ks[0] % 2) == 0:

        ks = ks[0] + 1, ks[1]

    if (ks[1] % 2) == 0:

        ks = ks[0], ks[1] + 1
|
|
|
input = _gaussian_blur2d(input, ks, sigmas) |
|
|
|
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) |
|
return output |
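
# Minimal sketch: anti-aliased downscaling of a random image batch.
#
#     x = torch.rand(1, 3, 512, 512)
#     y = _resize_with_antialiasing(x, (224, 224))
#     assert y.shape == (1, 3, 224, 224)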
|
|
|
|
|
def _compute_padding(kernel_size): |
|
"""Compute padding tuple.""" |
|
|
|
|
|
if len(kernel_size) < 2: |
|
raise AssertionError(kernel_size) |
|
computed = [k - 1 for k in kernel_size] |
|
|
|
|
|
out_padding = 2 * len(kernel_size) * [0] |
|
|
|
for i in range(len(kernel_size)): |
|
computed_tmp = computed[-(i + 1)] |
|
|
|
pad_front = computed_tmp // 2 |
|
pad_rear = computed_tmp - pad_front |
|
|
|
out_padding[2 * i + 0] = pad_front |
|
out_padding[2 * i + 1] = pad_rear |
|
|
|
return out_padding |
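
# Worked example: _compute_padding([5, 3]) returns [1, 1, 2, 2], i.e.
# (left, right, top, bottom) in the last-dimension-first order that
# torch.nn.functional.pad expects.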
|
|
|
|
|
def _filter2d(input, kernel): |
|
|
|
b, c, h, w = input.shape |
|
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) |
|
|
|
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) |
|
|
|
height, width = tmp_kernel.shape[-2:] |
|
|
|
padding_shape: list[int] = _compute_padding([height, width]) |
|
input = torch.nn.functional.pad(input, padding_shape, mode="reflect") |
|
|
|
|
|
    # Collapse batch and channels and convolve depthwise (one group per kernel)
    # so every channel is filtered independently.
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)

    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
|
|
|
|
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) |
|
|
|
out = output.view(b, c, h, w) |
|
return out |
|
|
|
|
|
def _gaussian(window_size: int, sigma): |
|
if isinstance(sigma, float): |
|
sigma = torch.tensor([[sigma]]) |
|
|
|
batch_size = sigma.shape[0] |
|
|
|
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) |
|
|
|
    # For even window sizes, shift the samples by half a step to keep the
    # kernel centered.
    if window_size % 2 == 0:

        x = x + 0.5
|
|
|
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) |
|
|
|
return gauss / gauss.sum(-1, keepdim=True) |
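
# Worked example: _gaussian(5, 1.0) returns a (1, 5) kernel that is symmetric
# about the center tap and sums to 1 along the last dimension.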
|
|
|
|
|
def _gaussian_blur2d(input, kernel_size, sigma): |
|
if isinstance(sigma, tuple): |
|
sigma = torch.tensor([sigma], dtype=input.dtype) |
|
else: |
|
sigma = sigma.to(dtype=input.dtype) |
|
|
|
ky, kx = int(kernel_size[0]), int(kernel_size[1]) |
|
bs = sigma.shape[0] |
|
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) |
|
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) |
|
    # Separable blur: one 1-D pass along x, then one along y.
    out_x = _filter2d(input, kernel_x[..., None, :])

    out = _filter2d(out_x, kernel_y[..., None])
|
|
|
return out |
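
# Minimal sketch: a separable Gaussian blur pass over a random batch.
#
#     x = torch.rand(2, 3, 64, 64)
#     y = _gaussian_blur2d(x, (3, 3), (0.8, 0.8))
#     assert y.shape == x.shape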
|
|
|
|
|
URL_MAP = { |
|
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" |
|
} |
|
|
|
CKPT_MAP = { |
|
"vgg_lpips": "vgg.pth" |
|
} |
|
|
|
MD5_MAP = { |
|
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a" |
|
} |
|
|
|
|
|
def download(url, local_path, chunk_size=1024): |
|
os.makedirs(os.path.split(local_path)[0], exist_ok=True) |
|
with requests.get(url, stream=True) as r: |
|
total_size = int(r.headers.get("content-length", 0)) |
|
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: |
|
with open(local_path, "wb") as f: |
|
                for data in r.iter_content(chunk_size=chunk_size):

                    if data:

                        f.write(data)

                        pbar.update(len(data))
|
|
|
|
|
def md5_hash(path): |
|
with open(path, "rb") as f: |
|
content = f.read() |
|
return hashlib.md5(content).hexdigest() |
|
|
|
|
|
def get_ckpt_path(name, root, check=False): |
|
assert name in URL_MAP |
|
path = os.path.join(root, CKPT_MAP[name]) |
|
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): |
|
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) |
|
download(URL_MAP[name], path) |
|
md5 = md5_hash(path) |
|
assert md5 == MD5_MAP[name], md5 |
|
return path |
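
# Usage sketch: fetch (and, with check=True, verify) the LPIPS VGG weights
# into a local cache directory of your choosing.
#
#     ckpt = get_ckpt_path("vgg_lpips", "cache/lpips", check=True)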
|
|
|
|
|
class KeyNotFoundError(Exception): |
|
def __init__(self, cause, keys=None, visited=None): |
|
self.cause = cause |
|
self.keys = keys |
|
self.visited = visited |
|
messages = list() |
|
if keys is not None: |
|
messages.append("Key not found: {}".format(keys)) |
|
if visited is not None: |
|
messages.append("Visited: {}".format(visited)) |
|
messages.append("Cause:\n{}".format(cause)) |
|
message = "\n".join(messages) |
|
super().__init__(message) |
|
|
|
|
|
def retrieve( |
|
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False |
|
): |
|
"""Given a nested list or dict return the desired value at key expanding |
|
callable nodes if necessary and :attr:`expand` is ``True``. The expansion |
|
is done in-place. |
|
|
|
Parameters |
|
---------- |
|
list_or_dict : list or dict |
|
Possibly nested list or dictionary. |
|
key : str |
|
key/to/value, path like string describing all keys necessary to |
|
consider to get to the desired value. List indices can also be |
|
passed here. |
|
splitval : str |
|
String that defines the delimiter between keys of the |
|
different depth levels in `key`. |
|
default : obj |
|
Value returned if :attr:`key` is not found. |
|
    expand : bool

        Whether to expand callable nodes on the path or not.

    pass_success : bool

        Whether to additionally return a bool indicating whether the lookup

        succeeded.
|
|
|
Returns |
|
------- |
|
The desired value or if :attr:`default` is not ``None`` and the |
|
:attr:`key` is not found returns ``default``. |
|
|
|
Raises |
|
------ |
|
    KeyNotFoundError if ``key`` is not in ``list_or_dict`` and :attr:`default`

    is ``None``.
|
""" |
|
|
|
keys = key.split(splitval) |
|
|
|
success = True |
|
try: |
|
visited = [] |
|
parent = None |
|
last_key = None |
|
for key in keys: |
|
if callable(list_or_dict): |
|
if not expand: |
|
raise KeyNotFoundError( |
|
ValueError( |
|
"Trying to get past callable node with expand=False." |
|
), |
|
keys=keys, |
|
visited=visited, |
|
) |
|
list_or_dict = list_or_dict() |
|
parent[last_key] = list_or_dict |
|
|
|
last_key = key |
|
parent = list_or_dict |
|
|
|
try: |
|
if isinstance(list_or_dict, dict): |
|
list_or_dict = list_or_dict[key] |
|
else: |
|
list_or_dict = list_or_dict[int(key)] |
|
except (KeyError, IndexError, ValueError) as e: |
|
raise KeyNotFoundError(e, keys=keys, visited=visited) |
|
|
|
visited += [key] |
|
|
|
if expand and callable(list_or_dict): |
|
list_or_dict = list_or_dict() |
|
parent[last_key] = list_or_dict |
|
except KeyNotFoundError as e: |
|
if default is None: |
|
raise e |
|
else: |
|
list_or_dict = default |
|
success = False |
|
|
|
if not pass_success: |
|
return list_or_dict |
|
else: |
|
return list_or_dict, success |
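
# Usage sketch: path-style lookup into a nested config, with a default.
#
#     cfg = {"model": {"params": {"dim": 256}}}
#     retrieve(cfg, "model/params/dim")                      # -> 256
#     retrieve(cfg, "model/params/depth", default=12)        # -> 12
#     retrieve(cfg, "model/params/dim", pass_success=True)   # -> (256, True)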