import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


# Default mapping from class index to emotion label.
mp = {0: 'sad', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}

class SentimentConfig(PretrainedConfig):
    model_type = "SententenceTransformerSentimentClassifier"

    def __init__(
        self,
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        class_map: dict = None,
        h1: int = 44,
        h2: int = 46,
        **kwargs,
    ):
        # Sentence-transformer backbone that produces the input embeddings,
        # plus the two hidden-layer widths of the classifier head.
        self.embedding_model = embedding_model
        # Avoid a shared mutable default argument; fall back to the
        # module-level label map when none is given.
        self.class_map = class_map if class_map is not None else mp
        self.h1 = h1
        self.h2 = h2

        super().__init__(**kwargs)

class SententenceTransformerSentimentModel(PreTrainedModel):
    config_class = SentimentConfig

    def __init__(self, config):
        super().__init__(config)

        # Classifier head on top of the 384-dim all-MiniLM-L6-v2 embeddings;
        # the output layer has one logit per class in the label map.
        self.fc1 = nn.Linear(384, config.h1)
        self.fc2 = nn.Linear(config.h1, config.h2)
        self.out = nn.Linear(config.h2, len(config.class_map))

    def forward(self, x):
        # x: precomputed sentence embeddings of shape (batch, 384).
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        # Returns class probabilities; if training with CrossEntropyLoss,
        # return the raw logits instead.
        return F.softmax(x, dim=-1)
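

# A minimal usage sketch, not part of the original file: it assumes the
# `sentence-transformers` package is installed and that the head is untrained,
# so the predicted label below is illustrative only. Registering the classes
# with the Auto* API is likewise an assumption about intended use, enabling
# save_pretrained/from_pretrained round trips for this custom model.
if __name__ == "__main__":
    import torch
    from sentence_transformers import SentenceTransformer
    from transformers import AutoConfig, AutoModel

    # Make the custom config/model discoverable through the Auto* API.
    AutoConfig.register("SententenceTransformerSentimentClassifier", SentimentConfig)
    AutoModel.register(SentimentConfig, SententenceTransformerSentimentModel)

    config = SentimentConfig()
    model = SententenceTransformerSentimentModel(config)

    # Encode a sentence into a 384-dim embedding with the configured backbone.
    encoder = SentenceTransformer(config.embedding_model)
    embeddings = torch.tensor(encoder.encode(["i am over the moon today"]))

    probs = model(embeddings)      # shape: (1, 6)
    pred = probs.argmax(dim=-1).item()
    print(config.class_map[pred])  # e.g. 'joy' once the head is trained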