Gpt / mlp.py
Zerx966's picture
Upload 10 files
3ef28b3 verified
raw
history blame
1.62 kB
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import bias_gelu_impl
from mamba_config import MambaConfig
class MLP(nn.Module):
    """Feed-forward block: ``linear_fc1`` -> activation -> ``linear_fc2``.

    All sizes and options are read from ``config``:

    * ``hidden_size`` -- input width of ``linear_fc1`` and output width of
      ``linear_fc2``.
    * ``ffn_hidden_size`` -- input width of ``linear_fc2``.
    * ``gated_linear_unit`` -- when truthy, ``linear_fc1`` produces
      ``2 * ffn_hidden_size`` features that are split in half along the last
      dimension and combined as ``activation_func(first) * second``
      (see https://arxiv.org/pdf/2002.05202.pdf).
    * ``add_bias_linear`` -- passed as ``bias`` to both linear layers.
    * ``device`` -- passed as ``device`` to both linear layers.
    * ``activation_func`` -- callable applied to the ``linear_fc1`` output
      (or to its first half in the gated case).
    """

    def __init__(
        self, config: "MambaConfig", is_expert: bool = False, layer_idx=None
    ):
        """Build the two projections and select the activation.

        Args:
            config: Configuration object providing the attributes listed in
                the class docstring.
            is_expert: Stored as ``linear_fc1.is_expert``; this class does
                not read it back.
            layer_idx: Stored as ``self.layer``; this class does not read it
                back.
        """
        super().__init__()
        self.config: "MambaConfig" = config
        self.layer = layer_idx

        ffn_hidden_size = self.config.ffn_hidden_size
        gated = self.config.gated_linear_unit
        # A gated linear unit needs twice the width out of the first
        # projection: one half is activated, the other half is the gate
        # multiplier (https://arxiv.org/pdf/2002.05202.pdf).
        fc1_out_features = 2 * ffn_hidden_size if gated else ffn_hidden_size

        self.linear_fc1 = nn.Linear(
            self.config.hidden_size,
            fc1_out_features,
            bias=self.config.add_bias_linear,
            device=self.config.device,
        )
        self.linear_fc1.is_expert = is_expert

        if gated:
            # A bound method (rather than a closure defined here) keeps the
            # module picklable and makes copy.deepcopy rebind the activation
            # to the copy instead of the original instance.
            self.activation_func = self._glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(
            ffn_hidden_size,
            self.config.hidden_size,
            bias=self.config.add_bias_linear,
            device=self.config.device,
        )

    def _glu(self, x):
        """Split ``x`` in two along the last dim; return ``act(first) * second``."""
        activated, gate = torch.chunk(x, 2, dim=-1)
        # config.activation_func is looked up at call time, not cached.
        return self.config.activation_func(activated) * gate

    def forward(self, hidden_states, inference_params=None):
        """Apply ``linear_fc1``, the activation, then ``linear_fc2``.

        Args:
            hidden_states: Tensor whose last dimension is
                ``config.hidden_size``.
            inference_params: Accepted but not used by this module.

        Returns:
            The output of ``linear_fc2``.
        """
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        return self.linear_fc2(intermediate)