from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import bias_gelu_impl
from mamba_config import MambaConfig

class MLP(nn.Module):
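    """Position-wise feed-forward block: projects hidden_size up to
    ffn_hidden_size, applies the configured activation (optionally as a
    gated linear unit), and projects back down to hidden_size."""
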
    def __init__(
        self, config: MambaConfig, is_expert: bool = False, layer_idx=None
    ):
        super().__init__()
        
        self.config: MambaConfig = config
        self.layer = layer_idx
        ffn_hidden_size_1 = self.config.ffn_hidden_size
        ffn_hidden_size_2 = self.config.ffn_hidden_size
        
        # For a gated linear unit, double the first projection's width so it can be
        # split into value and gate halves, see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.gated_linear_unit:
            ffn_hidden_size_1 *= 2
            
        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear, device=self.config.device)
        self.linear_fc1.is_expert = is_expert

        if self.config.gated_linear_unit:

            def glu(x):
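                # Split the doubled fc1 output into a value half and a gate half;
                # the activation is applied to the first and multiplied by the second.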
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear, device=self.config.device)

    def forward(self, hidden_states, inference_params=None):
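        # fc1 projection, (optionally gated) activation, then fc2 projection back
        # to hidden_size; inference_params is accepted but unused by this block.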
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        output = self.linear_fc2(intermediate)
        return output
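

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # it assumes MambaConfig can be constructed with just the fields MLP reads
    # (hidden_size, ffn_hidden_size, gated_linear_unit, activation_func,
    # add_bias_linear, device); the real config class may require more.
    config = MambaConfig(
        hidden_size=512,
        ffn_hidden_size=2048,
        gated_linear_unit=True,
        activation_func=F.silu,
        add_bias_linear=False,
        device="cpu",
    )
    mlp = MLP(config, layer_idx=0)
    x = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden_size)
    y = mlp(x)
    print(y.shape)  # torch.Size([2, 16, 512])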