import math

import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F

# Use apex's fused LayerNorm kernel when available; otherwise fall back to
# the standard PyTorch implementation under the same name.
try:
    from apex.normalization import FusedLayerNorm
    print("Using apex FusedLayerNorm")
except ImportError:
    from torch.nn import LayerNorm as FusedLayerNorm


class LayerNorm(FusedLayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(normalized_shape)
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Normalize in fp32 for numerical stability, then cast back to the
        # input dtype. Very large inputs bypass the fused kernel and use
        # F.layer_norm directly.
        if input.numel() > self.max_numel:
            return F.layer_norm(input.float(), self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
        else:
            return super(LayerNorm, self).forward(input.float()).type_as(input)


def gelu(x):
    return 0.5 * x * (1 + t.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * t.pow(x, 3))))


def swish(x):
    return x * t.sigmoid(x)
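
# Note: `gelu` above is the tanh approximation of GELU used by GPT-style
# models, and `swish` is x * sigmoid(x), also known as SiLU. A quick sanity
# check (sketch, assuming a PyTorch version where F.gelu supports
# approximate='tanh' and F.silu exists):
#
#     x = t.randn(8)
#     assert t.allclose(gelu(x), F.gelu(x, approximate='tanh'), atol=1e-6)
#     assert t.allclose(swish(x), F.silu(x), atol=1e-6)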


@t.jit.script
def quick_gelu(x):
    return x * t.sigmoid(1.702 * x)


@t.jit.script
def quick_gelu_bwd(x, grad_output):
    sig = t.sigmoid(1.702 * x)
    return grad_output * sig * (1.702 * x * (1 - sig) + 1.)
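
# Derivation of the hand-written backward above: with s = sigmoid(1.702 * x),
#     d/dx [x * s] = s + x * 1.702 * s * (1 - s) = s * (1 + 1.702 * x * (1 - s)),
# which is exactly what quick_gelu_bwd returns, scaled by grad_output.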


class QuickGelu(t.autograd.Function):
    # Custom autograd Function: only the input x is saved for backward, and
    # both passes call the TorchScript-compiled kernels above.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return quick_gelu(x)

    @staticmethod
    def backward(ctx, grad_output):
        return quick_gelu_bwd(ctx.saved_tensors[0], grad_output)


def memory_efficient_quick_gelu(x):
    return QuickGelu.apply(x)
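
# A quick way to check the custom backward against autograd (sketch;
# gradcheck expects double-precision inputs):
#
#     x = t.randn(16, dtype=t.float64, requires_grad=True)
#     t.autograd.gradcheck(QuickGelu.apply, (x,))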


ACT_FNS = {
    'relu': t.nn.functional.relu,
    'swish': swish,
    'gelu': gelu,
    'quick_gelu': memory_efficient_quick_gelu
}
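
# Typical lookup when building a block (sketch; the `config.activation`
# attribute is hypothetical):
#
#     act = ACT_FNS[config.activation]   # e.g. 'quick_gelu'
#     h = act(proj(h))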


def _move_to_gpu_and_convert_conv_weights_to_fp16(l):
    l.cuda()
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.half()


def _convert_conv_weights_to_fp32(l):
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.float()


def _convert_conv_weights_to_fp16(l):
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.half()


def _convert_embedding_weights_to_fp16(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.half()


def _convert_embedding_weights_to_fp32(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.float()
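
# These helpers take a single submodule, so they are meant to be passed to
# nn.Module.apply, which visits every submodule. A minimal sketch, assuming
# `model` contains Conv1D and Embedding layers:
#
#     model.apply(_convert_conv_weights_to_fp16)
#     model.apply(_convert_embedding_weights_to_fp16)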


class Conv1D(nn.Module):
    # A 1x1 "convolution": a dense projection over the last dimension, with
    # the weight stored as (n_in, n_out) and applied via addmm.
    def __init__(self, n_in, n_out, zero_out=False, init_scale=1.0):
        super(Conv1D, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        if zero_out:
            w = t.zeros(n_in, n_out)
        else:
            w = t.empty(n_in, n_out)
            nn.init.normal_(w, std=0.02 * init_scale)
        b = t.zeros(n_out)
        self.w = nn.Parameter(w)
        self.b = nn.Parameter(b)

    def forward(self, x):
        size_out = (*x.size()[:-1], self.n_out)
        x = t.addmm(self.b.type_as(x), x.view(-1, x.size(-1)), self.w.type_as(x))
        x = x.view(*size_out)
        return x
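
# Usage sketch: Conv1D behaves like nn.Linear with a transposed weight layout.
#
#     proj = Conv1D(512, 2048)
#     y = proj(t.randn(2, 16, 512))   # y.shape == (2, 16, 2048)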


class Mask(nn.Module):
    def __init__(self, n_ctx):
        super().__init__()
        # Lower-triangular causal mask of shape (1, 1, n_ctx, n_ctx).
        self.register_buffer('b', t.tril(t.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))

    def forward(self, w):
        # Keep attention scores at or before the current position; push
        # future positions to a large negative value before the softmax.
        w = w * self.b + -1e9 * (1 - self.b)
        return w
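
# A minimal sketch of the effect on a 3-step context:
#
#     mask = Mask(n_ctx=3)
#     w = mask(t.zeros(1, 1, 3, 3))
#     # row i keeps columns 0..i at 0 and sets columns > i to -1e9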


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits of shape (..., vocabulary size), filtered along the last dim.
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the smallest set of top tokens whose cumulative probability
            exceeds top_p (nucleus filtering).
    """
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # safety check
    assert (top_k == 0) or (top_p == 0.0)
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit.
        indices_to_remove = logits < t.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = t.sort(logits, descending=True, dim=-1)
        cumulative_probs = t.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens once the cumulative probability exceeds the threshold ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... but shift the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the boolean mask back into the original (unsorted) order.
        indices_to_remove = t.zeros_like(logits, dtype=t.bool).scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
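
# Sampling usage sketch (hypothetical values; `logits` is (batch, vocab) and
# `temperature` is a scalar):
#
#     filtered = filter_logits(logits / temperature, top_p=0.95)
#     probs = F.softmax(filtered, dim=-1)
#     next_token = t.multinomial(probs, num_samples=1)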