Spaces:
Sleeping
Sleeping
import glob | |
import logging | |
import re | |
import time | |
from collections import defaultdict | |
import os | |
import sys | |
import shutil | |
import types | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.distributed as dist | |
from torch import nn | |
def tensors_to_scalars(metrics):
    """Recursively replace tensor values in a metrics mapping with Python scalars.

    Args:
        metrics (dict): mapping whose values may be ``torch.Tensor`` objects,
            nested dicts, or any other value.

    Returns:
        dict: a new plain ``dict`` with the same keys. Tensor values are
        replaced by ``tensor.item()``, dict values (including dict subclasses
        such as ``defaultdict``) are converted recursively, and every other
        value is passed through unchanged.
    """
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        elif isinstance(v, dict):
            # isinstance (not an exact type check) so that dict subclasses
            # holding tensors are converted as well.
            v = tensors_to_scalars(v)
        new_metrics[k] = v
    return new_metrics
class AvgrageMeter(object):
    """Running weighted average of a stream of values.

    The class name keeps its original spelling so existing callers still work.

    Attributes are plain instance fields:
      * ``sum`` - accumulated ``val * n`` over all updates
      * ``cnt`` - accumulated weight ``n`` over all updates
      * ``avg`` - ``sum / cnt`` (0 while no weight has been accumulated)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: value to accumulate (typically the mean over ``n`` items).
            n: weight of ``val``; defaults to 1.
        """
        self.sum += val * n
        self.cnt += n
        # A zero total weight (e.g. the first update has n == 0) used to raise
        # ZeroDivisionError; keep the average at 0 until real weight arrives.
        self.avg = self.sum / self.cnt if self.cnt != 0 else 0
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Stack a list of 1d tensors into one padded 2d tensor.

    Args:
        values: list of 1d tensors; the first one provides dtype/device of
            the result.
        pad_idx: fill value for positions not covered by a sequence.
        left_pad: place each sequence at the end of its row instead of the
            start.
        shift_right: write each sequence shifted by one position, dropping its
            last element and putting ``shift_id`` in the first slot.
        max_len: row width; defaults to the longest sequence in ``values``.
        shift_id: value written to the first slot when ``shift_right`` is set.

    Returns:
        Tensor with ``len(values)`` rows and ``max_len`` (or longest-sequence)
        columns.
    """
    if max_len is None:
        width = max(item.size(0) for item in values)
    else:
        width = max_len
    padded = values[0].new(len(values), width).fill_(pad_idx)

    for row, item in enumerate(values):
        length = len(item)
        # Slice of the output row that this sequence occupies.
        target = padded[row][width - length:] if left_pad else padded[row][:length]
        assert target.numel() == item.numel()
        if not shift_right:
            target.copy_(item)
            continue
        target[1:] = item[:-1]
        target[0] = shift_id
    return padded
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Stack a list of 2d tensors into one padded 3d tensor.

    Args:
        values: list of 2d tensors padded along their first dimension; the
            second dimension is taken from ``values[0]``.
        pad_idx: fill value for positions not covered by a sequence.
        left_pad: place each sequence at the end of its slot instead of the
            start.
        shift_right: write each sequence shifted by one step, dropping its
            last step; the first step keeps the fill value.
        max_len: padded length; defaults to the longest sequence in ``values``.

    Returns:
        Tensor of shape ``(len(values), padded_length, values[0].shape[1])``.
    """
    if max_len is None:
        width = max(item.size(0) for item in values)
    else:
        width = max_len
    first = values[0]
    padded = first.new(len(values), width, first.shape[1]).fill_(pad_idx)

    for row, item in enumerate(values):
        length = len(item)
        # Slice of the output slot that this sequence occupies.
        target = padded[row][width - length:] if left_pad else padded[row][:length]
        assert target.numel() == item.numel()
        if shift_right:
            target[1:] = item[:-1]
        else:
            target.copy_(item)
    return padded
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): | |
if len(batch) == 0: | |
return 0 | |
if len(batch) == max_sentences: | |
return 1 | |
if num_tokens > max_tokens: | |
return 1 | |
return 0 | |
def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Group dataset indices into mini-batches bounded by size. Batches may
    contain sequences of different lengths; the token cost of a batch is
    counted as (number of items) * (longest item in the batch).

    Args:
        indices (List[int]): ordered list of dataset indices (a generator is
            also accepted and is materialized first)
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, meaning no limit).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None, meaning no limit).
        required_batch_size_multiple (int, optional): when a batch is closed,
            its size is cut down to a multiple of N where possible; the
            leftover items start the next batch (default: 1).
        distributed (bool, optional): not used by this function.

    Returns:
        List[List[int]]: the batches, in order. The final batch holds whatever
        items remain and is not trimmed to the size multiple.
    """
    if max_tokens is None:
        max_tokens = sys.maxsize
    if max_sentences is None:
        max_sentences = sys.maxsize
    multiple = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    longest = 0          # longest item among the pending ones (incl. current)
    pending_lens = []    # lengths matching `pending`, plus the current item
    pending = []         # indices collected for the batch being built
    finished = []

    for position in range(len(indices)):
        idx = indices[position]
        item_len = num_tokens_fn(idx)
        pending_lens.append(item_len)
        longest = max(longest, item_len)
        assert longest <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, longest, max_tokens)
        )
        # Padded token cost if the current item joined the pending batch.
        projected_tokens = (len(pending) + 1) * longest

        if _is_batch_full(pending, projected_tokens, max_tokens, max_sentences):
            # Emit the largest prefix whose size is a multiple of `multiple`;
            # if the batch is smaller than one multiple, emit all of it.
            whole_groups, remainder = divmod(len(pending), multiple)
            cut = max(multiple * whole_groups, remainder)
            finished.append(pending[:cut])
            pending = pending[cut:]
            pending_lens = pending_lens[cut:]
            longest = max(pending_lens, default=0)

        pending.append(idx)

    if len(pending) > 0:
        finished.append(pending)
    return finished
def make_positions(tensor, padding_idx):
    """Map every non-padding symbol to its position number.

    Positions are counted along dim 1 and start at ``padding_idx + 1``;
    padding symbols come out as ``padding_idx``.
    """
    # The casts below are deliberately kept as they are so the function works
    # with both ONNX export and XLA: XLA prefers ints, cumsum outputs longs by
    # default, and ONNX cannot handle the dtype kwarg of cumsum.
    non_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(non_pad, dim=1).type_as(non_pad)
    positions = (running_count * non_pad).long()
    return positions + padding_idx
def softmax(x, dim):
    """Softmax of ``x`` along ``dim``, computed and returned in float32."""
    fp32_probs = F.softmax(x, dim, dtype=torch.float32)
    return fp32_probs
def unpack_dict_to_list(samples):
    """Split a batched dict into a list of per-item dicts.

    The batch size is ``samples['outputs'].size(0)``. For each batch position
    ``i`` a dict is built holding ``v[i]`` for every entry ``k: v`` of
    ``samples``; entries that cannot be indexed at ``i`` are left out of that
    item's dict.

    Args:
        samples (dict): batched values; must contain an ``'outputs'`` entry
            exposing ``.size(0)``.

    Returns:
        list[dict]: one dict per batch position.
    """
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except (TypeError, IndexError, KeyError):
                # Best effort: skip values that are not indexable (scalars,
                # None), too short, or keyed differently (nested dicts).
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                continue
        samples_.append(res)
    return samples_
def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True):
    """Load weights into ``cur_model`` from a checkpoint file or directory.

    Args:
        cur_model: object exposing ``state_dict()`` and ``load_state_dict()``.
        ckpt_base_dir (str): path of a checkpoint file, or of a directory
            containing ``model_ckpt_steps_<N>.ckpt`` files, in which case the
            file with the largest step number ``N`` is used.
        prefix_in_ckpt (str): only entries of the checkpoint's ``"state_dict"``
            whose key starts with ``"<prefix_in_ckpt>."`` are loaded, with that
            prefix stripped from the key.
        force (bool): if True, a missing checkpoint raises ``AssertionError``;
            otherwise the message is only printed.
        strict (bool): forwarded to ``load_state_dict``. When False, entries
            whose shape differs from the model's parameter of the same name
            are dropped before loading.

    Raises:
        AssertionError: no checkpoint was found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        checkpoint_path = ckpt_base_dir
    else:
        base_dir = ckpt_base_dir
        # Parse the step number from the file name only, so that directory
        # names containing regex metacharacters or backslashes cannot break
        # the match. Files whose suffix is not purely numeric are ignored.
        candidates = []
        for path in glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'):
            match = re.fullmatch(r'model_ckpt_steps_(\d+)\.ckpt', os.path.basename(path))
            if match is not None:
                candidates.append((int(match.group(1)), path))
        checkpoint_path = max(candidates, key=lambda c: c[0])[1] if candidates else None

    if checkpoint_path is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise: a bare ``assert`` would be stripped under -O.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
                  if k.startswith(f'{prefix_in_ckpt}.')}
    if not strict:
        cur_model_state_dict = cur_model.state_dict()
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print("| Unmatched keys: ", key, new_param.shape, param.shape)
        # Deleted in a second pass so the dict is not mutated while iterated.
        for key in unmatched_keys:
            del state_dict[key]
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.")
def remove_padding(x, padding_idx=0):
    """Drop padding entries along the first axis of a 1d or 2d array.

    For 1d input, elements equal to ``padding_idx`` are removed. For 2d input,
    rows whose sum of absolute values equals ``padding_idx`` are removed.
    ``None`` is passed through unchanged.
    """
    if x is None:
        return None
    rank = len(x.shape)
    assert rank in [1, 2]
    if rank == 1:  # [T]
        keep = x != padding_idx
    else:  # [T, H]
        keep = np.abs(x).sum(-1) != padding_idx
    return x[keep]
class Timer:
    """Context manager that accumulates elapsed time per timer name.

    Usage::

        with Timer('forward'):
            ...

    Every ``with`` block adds its duration (in seconds) to
    ``Timer.timer_map[name]``.
    """

    # Shared by all instances: timer name -> total seconds accumulated so far.
    timer_map = {}

    def __init__(self, name, print_time=False):
        """
        Args:
            name: key under which the elapsed time is accumulated.
            print_time (bool): if True, print the name and its accumulated
                total each time the block exits.
        """
        if name not in Timer.timer_map:
            Timer.timer_map[name] = 0
        self.name = name
        self.print_time = print_time

    def __enter__(self):
        # perf_counter is monotonic, so a system clock adjustment during the
        # block cannot produce a negative or inflated duration.
        self.t = time.perf_counter()
        # Return the instance so ``with Timer(...) as t`` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        Timer.timer_map[self.name] += time.perf_counter() - self.t
        if self.print_time:
            print(self.name, Timer.timer_map[self.name])
def print_arch(model, model_name='model'):
    """Report the size of ``model`` by delegating to ``num_params``.

    Printing the full module structure is currently disabled; only the
    trainable-parameter summary is emitted.
    """
    num_params(model, print_out=True, model_name=model_name)
def num_params(model, print_out=True, model_name="model"):
    """Count the trainable parameters of ``model``, in millions.

    Args:
        model: object exposing ``parameters()``; only parameters with
            ``requires_grad`` set are counted.
        print_out (bool): if True, print a one-line summary.
        model_name (str): label used in the printed summary.

    Returns:
        float: number of trainable parameters divided by 1,000,000.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    parameters = trainable / 1_000_000
    if print_out:
        # Single f-string: mixing an f-string with %-formatting broke whenever
        # model_name itself contained a '%' character.
        print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
    return parameters