import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer
import pickle

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("DistillMDPI1/DistillMDPI1/saved_tokenizer")

# Step 1: Ensure the tokenizer has the [MULT] token
tokenizer.add_special_tokens({'additional_special_tokens': ['[MULT]']})
mult_token_id = tokenizer.convert_tokens_to_ids('[MULT]')
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id

## Voice Part functions
maxlen = 255        # maximum sequence length
batch_size = 32
max_pred = 5        # max number of masked tokens predicted per sequence
n_layers = 6        # number of encoder layers
n_heads = 12        # number of heads in multi-head attention
d_model = 768       # embedding size
d_ff = 768 * 4      # feed-forward dimension (4 * d_model)
d_k = d_v = 64      # dimension of K (= Q) and V per head
n_segments = 2
vocab_size = tokenizer.vocab_size + 1


def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # positions equal to the PAD token id are masked out
    pad_attn_mask = seq_k.data.eq(pad_token_id).unsqueeze(1)  # batch_size x 1 x len_k, True means masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)     # batch_size x len_q x len_k


class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x)
        embedding += self.pos_embed(pos)
        embedding += self.seg_embed(seg)
        return self.norm(embedding)


class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size x n_heads x len_q x len_k]
        scores.masked_fill_(attn_mask, -1e9)  # fill masked (PAD) positions with a large negative value
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return scores, context, attn


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        # Q: [batch_size x len_q x d_model], K: [batch_size x len_k x d_model], V: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        device = Q.device
        Q, K, V = Q.to(device), K.to(device), V.to(device)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]
        scores, context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # context: [batch_size x len_q x n_heads * d_v]
        output = self.fc(context)
        return self.norm(output + residual), attn  # output: [batch_size x len_q x d_model]


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.gelu = nn.GELU()

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(self.gelu(self.fc1(x)))


class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # enc_inputs is used as Q, K and V for self-attention
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs,
                                               enc_self_attn_mask.to(enc_inputs.device))
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn


class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = nn.GELU()
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with the token embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.mclassifier = nn.Linear(d_model, 17)

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids).to(output.device)
        for layer in self.layers:
            # output : [batch_size, len, d_model], attn : [batch_size, n_heads, len, len]
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        # sentence-level head: decided by the first token ([CLS])
        h_pooled = self.activ1(self.fc(output[:, 0]))  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)        # [batch_size, 2]

        # masked LM head: gather the final hidden states at the masked positions
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))  # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos)  # [batch_size, max_pred, d_model]
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias  # [batch_size, max_pred, n_vocab]

        # multi-class head on the first [MULT] token (expected at position 1, right after [CLS])
        h_mult_sent1 = self.activ1(self.fc(output[:, 1]))
        logits_mclsf1 = self.mclassifier(h_mult_sent1)

        # Find the positions of the [MULT] tokens (mult_token_id is defined globally)
        mult2_positions = (input_ids == mult_token_id).nonzero(as_tuple=False)
        # Ensure there are exactly two [MULT] tokens in each input sequence
        assert mult2_positions.size(0) == 2 * input_ids.size(0)
        # Keep only the second [MULT] occurrence per sequence; column 1 holds its position in the sequence
        mult2_positions = mult2_positions[1::2][:, 1]

        # Gather the hidden states corresponding to the second [MULT] token
        h_mult_sent2 = output[torch.arange(output.size(0)), mult2_positions]
        logits_mclsf2 = self.mclassifier(h_mult_sent2)

        return logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2
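

# --- Optional smoke test: a minimal sketch, not part of the training code above. ---
# It assumes the tokenizer loaded at the top of this file is available, and it fabricates
# a toy batch whose layout ([CLS] [MULT] ... [SEP] [MULT] ... [SEP] [PAD] ...) satisfies the
# "exactly two [MULT] tokens per sequence" assertion in BERT.forward. The filler token and
# the __main__ guard are illustrative choices, not part of the original code.
if __name__ == "__main__":
    model = BERT()

    seq_len = 16
    toy_batch = 2
    filler_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello"))[0]
    seq = [cls_token_id, mult_token_id] + [filler_id] * 3 + [sep_token_id] \
          + [mult_token_id] + [filler_id] * 3 + [sep_token_id]
    seq = seq + [pad_token_id] * (seq_len - len(seq))

    input_ids = torch.tensor([seq] * toy_batch)                              # [batch, seq_len]
    segment_ids = torch.tensor([[0] * 6 + [1] * (seq_len - 6)] * toy_batch)  # [batch, seq_len]
    masked_pos = torch.zeros(toy_batch, max_pred, dtype=torch.long)          # dummy masked positions

    logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2 = model(input_ids, segment_ids, masked_pos)
    print(logits_lm.shape)      # [batch, max_pred, vocab_size]
    print(logits_clsf.shape)    # [batch, 2]
    print(logits_mclsf1.shape)  # [batch, 17]
    print(logits_mclsf2.shape)  # [batch, 17]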