MrVicente committed on
Commit 6cf191b
1 Parent(s): be59fc7

added demo base code

.gitattributes CHANGED
@@ -29,3 +29,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ kgs_binding/conceptnet/conceptnet_english_noun_2_noun_relations.json filter=lfs diff=lfs merge=lfs -text
+ kgs_binding/conceptnet/conceptnet_english_nouns.json filter=lfs diff=lfs merge=lfs -text
+ kgs_binding/conceptnet/conceptnet_english_nouns_simple.json filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ from inference import RelationsInference
+ from utils import KGType, Model_Type
+
+ #############################
+ # Constants
+ #############################
+
+ examples = [["What's the meaning of life?", "eli5", "constraint"],
+             ["boat, water, bird", "commongen", "constraint"],
+             ["What flows under a bridge?", "commonsense_qa", "constraint"]]
+
+ bart = RelationsInference(
+     model_path='MrVicente/commonsense_bart_commongen',
+     kg_type=KGType.CONCEPTNET,
+     model_type=Model_Type.RELATIONS,
+     max_length=32
+ )
+
+ #############################
+ # Helper
+ #############################
+
+ def infer_bart(context, task_type, decoding_type_str):
+     # task_type and decoding_type_str are not used yet; generation currently runs without the KG (use_kg=False)
+     response, encoder_attentions, model_input = bart.generate_based_on_context(context, use_kg=False)
+     return response[0]
+
+
+ def plot_attention(layer, head):
+     # Placeholder plot: the selected layer/head are not used yet
+     fig = plt.figure()
+     plt.plot([1, 2, 3], [2, 4, 6])
+     plt.title("Things")
+     plt.ylabel("Cases")
+     plt.xlabel("Days since Day 0")
+     return fig
+
+
+ #############################
+ # Interface
+ #############################
+
+ app = gr.Blocks()
+ with app:
+     gr.Markdown(
+         """
+         # Demo
+         ### Test the Commonsense Relation-Aware BART (BART-RA) model
+
+         Tutorial: <br>
+         1) Select the model variation and task;<br>
+         2) Change the input and click the button to produce results;<br>
+         3) See attention visualisations by choosing a specific layer and head;<br>
+         """)
+     with gr.Row():
+         context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
+         model_result_output = gr.Textbox(lines=2, label='Model result:')
+     with gr.Column():
+         task_type_choice = gr.Radio(
+             ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
+         )
+         decoding_type_choice = gr.Radio(
+             ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
+         )
+     with gr.Row():
+         model_btn = gr.Button(value="See Model Results")
+     gr.Markdown(
+         """
+         ---
+         Observe Attention
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             layer = gr.Slider(0, 11, 0, step=1, label="Layer")
+             head = gr.Slider(0, 15, 0, step=1, label="Head")
+         with gr.Column():
+             plot_output = gr.Plot()
+     with gr.Row():
+         vis_btn = gr.Button(value="See Attention Scores")
+     model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
+                     outputs=[model_result_output])
+     vis_btn.click(fn=plot_attention, inputs=[layer, head], outputs=[plot_output])
+
+ if __name__ == '__main__':
+     app.launch()
custom_bart/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .bart_attention import BartCustomAttention
+ from .bart_mask_attention import BartCustomMaskAttention
+ from .bart_for_conditional_generation import BartCustomForConditionalGeneration
+ from .bart_model import BartCustomModel
+ from .config import BartCustomConfig
+ from .custom_constants import BartConstants
+ from .decoder import *
+ from .decoder_layer import *
+ from .encoder import *
+ from .encoder_layer import *
+ from .bart_generation_mixin import *
+ from . import *
custom_bart/attention_utils.py ADDED
@@ -0,0 +1,132 @@
+ #############################
+ # Imports
+ #############################
+
+ # Python modules
+
+ # Remote modules
+ import torch
+
+ # Local modules
+
+ #############################
+ # Constants
+ #############################
+
+ #############################
+ # Stuff
+ #############################
+
+ def find_head_to_mask(heads_mask) -> int:
+     head_idx = torch.argmax(heads_mask)
+     head_idx_simple = head_idx.item()
+     return head_idx_simple
+
+ def commonsense_attention_mask_update(bsz, n_tokens, commonsense_matrix, attn_weights,
+                                       num_heads=16, specific_head=0):
+     commonsense_mask = torch.zeros(
+         ((bsz, num_heads, n_tokens, n_tokens))
+     )
+     attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
+     zeros = torch.zeros(
+         ((bsz, n_tokens, n_tokens))
+     )
+     head_previous_attention_weights = attn_weights_helper[specific_head]
+     attn_weights_helper[specific_head] = zeros
+     attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
+     if commonsense_matrix is None:
+         # commonsense_matrix was not passed: default to ones (neutral, since multiplication is used)
+         commonsense_matrix = torch.ones(
+             ((bsz, n_tokens, n_tokens))
+         )
+     commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
+     commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
+     # TODO Stupid conversion
+     commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
+     return attn_weights_helper + commonsense_mask
+
+ def convert_relations_to_binary_mask(input_relations, should_clone=True):
+     relations_binary_mask = input_relations
+     if should_clone:
+         relations_binary_mask = input_relations.clone()
+     relations_binary_mask[relations_binary_mask > 1] = 1
+     return relations_binary_mask
+
+ def relation_binary_2d_to_1d(relations_binary_mask):
+     relations_binary_mask = relations_binary_mask.sum(dim=1)
+     relations_binary_mask[relations_binary_mask > 1] = 1
+     return relations_binary_mask
+
+ def create_layer_with_commonsense_on_specific_head(relation_binary_mask, bsz, num_heads, specific_head=0):
+     n_tokens = relation_binary_mask.size()[-1]
+     relations_mask = torch.zeros(
+         (bsz, num_heads, n_tokens, n_tokens)
+     )
+     layer = relations_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
+     layer[specific_head] = relation_binary_mask
+     layer = layer.reshape((bsz, num_heads, n_tokens, n_tokens))
+     return layer
+
+ def update_weights_regarding_relations_on_specific_head(layer_head_mask, attn_weights, relation_inputs, bsz, num_heads, tgt_len, src_len, verbose=True):
+     #layer_head_mask = layer_head_mask.to(attn_weights.device)
+     inverse_layer_head_mask = (layer_head_mask.view(num_heads, 1, 1) - 1) * -1
+     #inverse_layer_head_mask = inverse_layer_head_mask.to(attn_weights.device)
+     #print('layer_head_mask:', layer_head_mask)
+     if verbose:
+         print("==============================")
+         print('layer_head_mask.shape:', layer_head_mask.shape)
+         print('inverse_layer_head_mask.shape:', inverse_layer_head_mask.shape)
+         print('attn_weights.shape:', attn_weights.shape)
+         print('relation_inputs.shape', relation_inputs.shape)
+         print("==============================")
+     #print('layer_head_mask.device:', layer_head_mask.device)
+     #print('inverse_layer_head_mask.device:', inverse_layer_head_mask.device)
+     #print('relation_inputs.device:', relation_inputs.device)
+     intermediate_weights = inverse_layer_head_mask * attn_weights.view(bsz, num_heads, tgt_len, src_len)
+     relation_inputs = convert_relations_to_binary_mask(relation_inputs, should_clone=False)
+     relation_weights = layer_head_mask.view(num_heads, 1, 1) * relation_inputs.view(bsz, 1, tgt_len, src_len) * attn_weights.view(bsz, num_heads,
+                                                                                                                                   tgt_len, src_len)
+     attn_weights = intermediate_weights + relation_weights
+     # [batch, n_heads, seq_length, seq_length]
+     if verbose:
+         print('attn_weights_int.shape', attn_weights.shape)
+     return attn_weights
+
+ """
+ def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
+     commonsense_mask = torch.zeros(
+         ((bsz, num_heads, n_tokens, n_tokens))
+     )
+     if commonsense_matrix is None:
+         commonsense_matrix = torch.zeros(
+             ((bsz, n_tokens, n_tokens))
+         )
+     commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
+     commonsense_mask[specific_head] = commonsense_matrix
+     commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
+     return commonsense_mask
+
+ def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights,
+                                       specific_head=0):
+     num_heads = self.num_heads
+     commonsense_mask = torch.zeros(
+         ((bsz, num_heads, n_tokens, n_tokens))
+     )
+     attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
+     zeros = torch.zeros(
+         ((bsz, n_tokens, n_tokens))
+     )
+     head_previous_attention_weights = attn_weights_helper[specific_head]
+     attn_weights_helper[specific_head] = zeros
+     attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
+     if commonsense_matrix is None:
+         # commonsense_matrix was not passed: default to ones (neutral, since multiplication is used)
+         commonsense_matrix = torch.ones(
+             ((bsz, n_tokens, n_tokens))
+         )
+     commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
+     commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
+     # TODO Stupid conversion
+     commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
+     return attn_weights_helper + commonsense_mask
+ """
custom_bart/bart_attention.py ADDED
@@ -0,0 +1,313 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import Optional, Tuple
7
+ # Remote modules
8
+ import torch
9
+ from torch import nn
10
+
11
+ # Local modules
12
+ from .attention_utils import (
13
+ create_layer_with_commonsense_on_specific_head,
14
+ find_head_to_mask,
15
+ convert_relations_to_binary_mask,
16
+ update_weights_regarding_relations_on_specific_head
17
+ )
18
+
19
+
20
+ class BartCustomAttention(nn.Module):
21
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
22
+
23
+ def __init__(
24
+ self,
25
+ embed_dim: int,
26
+ num_heads: int,
27
+ dropout: float = 0.0,
28
+ is_decoder: bool = False,
29
+ bias: bool = True,
30
+ num_relation_kinds: int = 0,
31
+ use_same_relation_kv_emb: bool = True,
32
+ heads_mask: Optional[torch.Tensor] = None,
33
+ ):
34
+ super().__init__()
35
+ self.embed_dim = embed_dim
36
+ self.num_heads = num_heads
37
+ self.dropout = dropout
38
+ self.head_dim = embed_dim // num_heads
39
+
40
+ if (self.head_dim * num_heads) != self.embed_dim:
41
+ raise ValueError(
42
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
43
+ f" and `num_heads`: {num_heads})."
44
+ )
45
+ if heads_mask.size() != (self.num_heads,):
46
+ raise ValueError(
47
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {heads_mask.size()}"
48
+ )
49
+ self.heads_mask = heads_mask
50
+
51
+ self.scaling = self.head_dim**-0.5
52
+ self.is_decoder = is_decoder
53
+
54
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
55
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
56
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
57
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
58
+
59
+ self.num_relation_kinds = num_relation_kinds
60
+ self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
61
+ if use_same_relation_kv_emb:
62
+ self.relation_v_emb = self.relation_k_emb
63
+ else:
64
+ self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
65
+
66
+ self.k_rel_scale = 0.0
67
+ self.v_rel_scale = 1.0
68
+
69
+
70
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
71
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
72
+
73
+ def forward(
74
+ self,
75
+ hidden_states: torch.Tensor,
76
+ key_value_states: Optional[torch.Tensor] = None,
77
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
78
+ attention_mask: Optional[torch.Tensor] = None,
79
+ layer_head_mask: Optional[torch.Tensor] = None,
80
+ output_attentions: bool = False,
81
+ relation_inputs: Optional[torch.Tensor] = None,
82
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
83
+ """Input shape: Batch x Time x Channel"""
84
+
85
+ #print('device:', hidden_states.device)
86
+ # if key_value_states are provided this layer is used as a cross-attention layer
87
+ # for the decoder
88
+ is_cross_attention = key_value_states is not None
89
+
90
+ bsz, tgt_len, embed_dim = hidden_states.size()
91
+
92
+ #print(relation_inputs.shape, 'VS ', (bsz, tgt_len, tgt_len))
93
+ if relation_inputs is None:
94
+ # TODO
95
+ print('oh no')
96
+ relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
97
+ print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
98
+ assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
99
+
100
+ # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)
101
+ relation_k_embeds = self.relation_k_emb(relation_inputs)
102
+ relation_v_embeds = self.relation_v_emb(relation_inputs)
103
+
104
+ # get query proj
105
+ query_states = self.q_proj(hidden_states) * self.scaling
106
+ # get key, value proj
107
+ if is_cross_attention and past_key_value is not None:
108
+ # reuse k,v, cross_attentions
109
+ key_states = past_key_value[0]
110
+ value_states = past_key_value[1]
111
+ elif is_cross_attention:
112
+ # cross_attentions
113
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
114
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
115
+ elif past_key_value is not None:
116
+ # reuse k, v, self_attention
117
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
118
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
119
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
120
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
121
+ else:
122
+ # self_attention
123
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
124
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
125
+
126
+ if self.is_decoder:
127
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
128
+ # Further calls to cross_attention layer can then reuse all cross-attention
129
+ # key/value_states (first "if" case)
130
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
131
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
132
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
133
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
134
+ past_key_value = (key_states, value_states)
135
+
136
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
137
+ query_states = self._shape(query_states, tgt_len, bsz)
138
+ src_len = key_states.size(2)
139
+
140
+ # compute scores
141
+ attn_weights = torch.matmul(
142
+ query_states, key_states.transpose(3, 2)
143
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
144
+
145
+ # q_t is [batch, seq_length, n_heads, dim_per_head]
146
+ q_t = query_states.permute(0, 2, 1, 3)
147
+ #print('qt.shape: ', q_t.shape)
148
+ # r_t is [batch, seq_length, dim_per_head, seq_length]
149
+ r_t = relation_k_embeds.transpose(-2, -1)
150
+ #print('rt.shape: ', r_t.shape)
151
+
152
+ q_tr_t_matmul = torch.matmul(q_t, r_t) # [batch, seq_length, n_heads, seq_length]
153
+ q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3) # [batch, n_heads, seq_length, seq_length]
154
+
155
+ # Make sure the relation-aware impact is applied only to specific heads (k-part)
156
+
157
+ #print("==========")
158
+ #print('first K: ', q_tr_tmatmul_t.sum())
159
+ """
160
+ q_tr_tmatmul_t = self.layer_heads_relation_attention_update(
161
+ self.heads_mask,
162
+ q_tr_tmatmul_t,
163
+ )
164
+ """
165
+ #print('second K: ', q_tr_tmatmul_t.sum())
166
+ #print("==========")
167
+
168
+ # give weight to influence
169
+ #q_tr_tmatmul_t = 100.0 * q_tr_tmatmul_t
170
+
171
+ # Add to scores
172
+ #print('attn_weights k [before]', attn_weights)
173
+ #print('attn_weights sum k [before]', attn_weights.sum())
174
+ attn_weights += self.k_rel_scale * q_tr_tmatmul_t
175
+ #attn_weights += 100.0 * q_tr_tmatmul_t
176
+ #print('attn_weights k [after]: ', attn_weights)
177
+ #print('attn_weights sum k [after]', attn_weights.sum())
178
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
179
+
180
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
181
+ raise ValueError(
182
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
183
+ )
184
+
185
+ if attention_mask is not None:
186
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
187
+ raise ValueError(
188
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
189
+ )
190
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
191
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
192
+
193
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
194
+
195
+ # Wrong place... gonna comment
196
+ """
197
+ attn_weights = self.layer_heads_relation_attention_update(layer_head_mask,
198
+ relation_inputs,
199
+ attn_weights,
200
+ bsz,
201
+ tgt_len,
202
+ src_len)
203
+ """
204
+ if layer_head_mask is not None:
205
+ if layer_head_mask.size() != (self.num_heads,):
206
+ raise ValueError(
207
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
208
+ )
209
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
210
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
211
+
212
+
213
+ if output_attentions:
214
+ # this operation is a bit awkward, but it's required to
215
+ # make sure that attn_weights keeps its gradient.
216
+ # In order to do so, attn_weights have to be reshaped
217
+ # twice and have to be reused in the following
218
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
219
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
220
+ else:
221
+ attn_weights_reshaped = None
222
+
223
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
224
+
225
+ attn_output = torch.bmm(attn_probs, value_states.view(*proj_shape))
226
+
227
+ #print('attn_probs.shape', attn_probs.shape)
228
+ # w_t is [batch, seq_length, n_heads, seq_length]
229
+ w_t = attn_probs.view(bsz, self.num_heads, tgt_len, src_len).permute(0, 2, 1, 3)
230
+ #print('w_t.shape 1:', w_t.shape)
231
+ #print('relation_v_embeds.shape', relation_v_embeds.shape)
232
+ # [batch, seq_length, n_heads, seq_length]
233
+ w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
234
+ #print('w_tr_matmul.shape 1:', w_tr_matmul.shape)
235
+ #print('w_tr_matmul.shape 2:', w_tr_matmul.shape)
236
+ # Make sure the relation-aware impact is applied only to specific heads (v-part)
237
+
238
+ #print("==========")
239
+ #print('first V sum: ', w_tr_matmul.sum())
240
+ #print('first V: ', w_tr_matmul[0])
241
+ """
242
+ w_tr_matmul = self.layer_heads_relation_attention_v_update(
243
+ self.heads_mask,
244
+ w_tr_matmul,
245
+ bsz,
246
+ tgt_len,
247
+ )
248
+ """
249
+ w_tr_matmul = self.v_rel_scale * w_tr_matmul
250
+ #print('second V sum: ', w_tr_matmul.sum())
251
+ #print('second V: ', w_tr_matmul[0])
252
+ #print("==========")
253
+
254
+ w_tr_matmul = w_tr_matmul.permute(0, 2, 1, 3)
255
+ w_tr_matmul = w_tr_matmul.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
256
+
257
+ #print('attn_output v [before]', attn_output)
258
+ #print('attn_output sum v [before]', attn_output.sum())
259
+ attn_output += w_tr_matmul
260
+ #attn_output += 100.0 * w_tr_matmul
261
+ #print('attn_output v [after]', attn_output)
262
+ #print('attn_output sum v [after]', attn_output.sum())
263
+ #raise Exception()
264
+
265
+
266
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
267
+ raise ValueError(
268
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
269
+ )
270
+
271
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
272
+ attn_output = attn_output.transpose(1, 2)
273
+
274
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
275
+ # partitioned across GPUs when using tensor-parallelism.
276
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
277
+
278
+ attn_output = self.out_proj(attn_output)
279
+
280
+ return attn_output, attn_weights_reshaped, past_key_value
281
+
282
+ def layer_heads_relation_attention_update(self,
283
+ layer_head_mask,
284
+ data,
285
+ ):
286
+ if layer_head_mask is not None:
287
+ if layer_head_mask.size() != (self.num_heads,):
288
+ raise ValueError(
289
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
290
+ )
291
+ #print('layer_head_mask:', layer_head_mask)
292
+ masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data
293
+ return masked_weights
294
+ return data
295
+
296
+ def layer_heads_relation_attention_v_update(self,
297
+ layer_head_mask,
298
+ data,
299
+ bsz,
300
+ tgt_len,
301
+ ):
302
+ if layer_head_mask is not None:
303
+ if layer_head_mask.size() != (self.num_heads,):
304
+ raise ValueError(
305
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
306
+ )
307
+ #relation_binary_mask = convert_relations_to_binary_mask(relation_inputs)
308
+ #one_dimension_mask = relation_binary_mask.sum(-1)
309
+ #relation_binary_mask = convert_relations_to_binary_mask(one_dimension_mask)
310
+ # [16, 128, 16, 64]
311
+ masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data.view(bsz, self.num_heads, tgt_len, self.head_dim)
312
+ return masked_weights.view(bsz, tgt_len, self.num_heads, self.head_dim)
313
+ return data
custom_bart/bart_for_conditional_generation.py ADDED
@@ -0,0 +1,205 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import (
7
+ Optional,
8
+ Tuple,
9
+ Union,
10
+ List,
11
+ )
12
+
13
+ # Remote modules
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import CrossEntropyLoss
17
+ from transformers import (
18
+ BartConfig,
19
+ BartPretrainedModel,
20
+ )
21
+ from transformers.modeling_outputs import Seq2SeqLMOutput
22
+ from transformers.models.bart.modeling_bart import shift_tokens_right
23
+
24
+ from transformers.utils import (
25
+ add_end_docstrings,
26
+ add_start_docstrings,
27
+ add_start_docstrings_to_model_forward,
28
+ logging,
29
+ replace_return_docstrings,
30
+ )
31
+
32
+ from .bart_model import BartCustomModel
33
+ from .config import BartCustomConfig
34
+ from .custom_constants import BartConstants
35
+ from .bart_generation_mixin import GenerationMixin
36
+ from .custom_outputs import CustomSeq2SeqLMOutput
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ @add_start_docstrings(
41
+ "The BART Model with a language modeling head. Can be used for summarization.", BartConstants.BART_START_DOCSTRING
42
+ )
43
+ class BartCustomForConditionalGeneration(BartPretrainedModel, GenerationMixin):
44
+ base_model_prefix = "model"
45
+ _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
46
+
47
+ def __init__(self, config: BartCustomConfig):
48
+ super().__init__(config)
49
+ self.model = BartCustomModel(config)
50
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
51
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
52
+
53
+ # Initialize weights and apply final processing
54
+ self.post_init()
55
+
56
+ def get_encoder(self):
57
+ return self.model.get_encoder()
58
+
59
+ def get_decoder(self):
60
+ return self.model.get_decoder()
61
+
62
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
63
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
64
+ self._resize_final_logits_bias(new_num_tokens)
65
+ return new_embeddings
66
+
67
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
68
+ old_num_tokens = self.final_logits_bias.shape[-1]
69
+ if new_num_tokens <= old_num_tokens:
70
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
71
+ else:
72
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
73
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
74
+ self.register_buffer("final_logits_bias", new_bias)
75
+
76
+ def get_output_embeddings(self):
77
+ return self.lm_head
78
+
79
+ def set_output_embeddings(self, new_embeddings):
80
+ self.lm_head = new_embeddings
81
+
82
+ @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
83
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=BartConstants.CONFIG_FOR_DOC)
84
+ @add_end_docstrings(BartConstants.BART_GENERATION_EXAMPLE)
85
+ def forward(
86
+ self,
87
+ input_ids: torch.LongTensor = None,
88
+ attention_mask: Optional[torch.Tensor] = None,
89
+ decoder_input_ids: Optional[torch.LongTensor] = None,
90
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
91
+ head_mask: Optional[torch.Tensor] = None,
92
+ decoder_head_mask: Optional[torch.Tensor] = None,
93
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
94
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
95
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
96
+ inputs_embeds: Optional[torch.FloatTensor] = None,
97
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
98
+ labels: Optional[torch.LongTensor] = None,
99
+ use_cache: Optional[bool] = None,
100
+ output_attentions: Optional[bool] = None,
101
+ output_hidden_states: Optional[bool] = None,
102
+ return_dict: Optional[bool] = None,
103
+ input_commonsense_relations: Optional[torch.Tensor] = None,
104
+ reduce_ce=True,
105
+ ) -> Union[Tuple, CustomSeq2SeqLMOutput]:
106
+ r"""
107
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
108
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
109
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
110
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
111
+
112
+ Returns:
113
+ """
114
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
115
+
116
+ if labels is not None:
117
+ if use_cache:
118
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
119
+ use_cache = False
120
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
121
+ decoder_input_ids = shift_tokens_right(
122
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
123
+ )
124
+ outputs = self.model(
125
+ input_ids,
126
+ attention_mask=attention_mask,
127
+ decoder_input_ids=decoder_input_ids,
128
+ encoder_outputs=encoder_outputs,
129
+ decoder_attention_mask=decoder_attention_mask,
130
+ head_mask=head_mask,
131
+ decoder_head_mask=decoder_head_mask,
132
+ cross_attn_head_mask=cross_attn_head_mask,
133
+ past_key_values=past_key_values,
134
+ inputs_embeds=inputs_embeds,
135
+ decoder_inputs_embeds=decoder_inputs_embeds,
136
+ use_cache=use_cache,
137
+ output_attentions=output_attentions,
138
+ output_hidden_states=output_hidden_states,
139
+ return_dict=return_dict,
140
+ relation_inputs=input_commonsense_relations
141
+ )
142
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
143
+
144
+ masked_lm_loss = None
145
+ if labels is not None:
146
+ loss_fct = CrossEntropyLoss(reduce=reduce_ce, ignore_index=self.config.pad_token_id) # added ignore_index=self.config.pad_token_id
147
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
148
+
149
+ if not return_dict:
150
+ output = (lm_logits,) + outputs[1:]
151
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
152
+
153
+ return CustomSeq2SeqLMOutput(
154
+ loss=masked_lm_loss,
155
+ logits=lm_logits,
156
+ past_key_values=outputs.past_key_values,
157
+ decoder_hidden_states=outputs.decoder_hidden_states,
158
+ decoder_attentions=outputs.decoder_attentions,
159
+ cross_attentions=outputs.cross_attentions,
160
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
161
+ encoder_hidden_states=outputs.encoder_hidden_states,
162
+ encoder_attentions=outputs.encoder_attentions,
163
+ head_mask=outputs.encoder_head_mask
164
+ )
165
+
166
+ def prepare_inputs_for_generation(
167
+ self,
168
+ decoder_input_ids,
169
+ past=None,
170
+ attention_mask=None,
171
+ head_mask=None,
172
+ decoder_head_mask=None,
173
+ cross_attn_head_mask=None,
174
+ use_cache=None,
175
+ encoder_outputs=None,
176
+ **kwargs
177
+ ):
178
+ # cut decoder_input_ids if past is used
179
+ if past is not None:
180
+ decoder_input_ids = decoder_input_ids[:, -1:]
181
+
182
+ return {
183
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
184
+ "encoder_outputs": encoder_outputs,
185
+ "past_key_values": past,
186
+ "decoder_input_ids": decoder_input_ids,
187
+ "attention_mask": attention_mask,
188
+ "head_mask": head_mask,
189
+ "decoder_head_mask": decoder_head_mask,
190
+ "cross_attn_head_mask": cross_attn_head_mask,
191
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
192
+ }
193
+
194
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
195
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
196
+
197
+ @staticmethod
198
+ def _reorder_cache(past, beam_idx):
199
+ reordered_past = ()
200
+ for layer_past in past:
201
+ # cached cross_attention states don't have to be reordered -> they are always the same
202
+ reordered_past += (
203
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
204
+ )
205
+ return reordered_past
custom_bart/bart_generation_mixin.py ADDED
The diff for this file is too large to render. See raw diff
 
custom_bart/bart_mask_attention.py ADDED
@@ -0,0 +1,238 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import Optional, Tuple
7
+
8
+ # Remote modules
9
+ import torch
10
+ from torch import nn
11
+
12
+ # Local modules
13
+ from .attention_utils import update_weights_regarding_relations_on_specific_head
14
+
15
+
16
+ class BartCustomMaskAttention(nn.Module):
17
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
18
+
19
+ def __init__(
20
+ self,
21
+ embed_dim: int,
22
+ num_heads: int,
23
+ dropout: float = 0.0,
24
+ is_decoder: bool = False,
25
+ bias: bool = True,
26
+ num_relation_kinds: int = 0,
27
+ heads_mask: Optional[torch.Tensor] = None,
28
+ ):
29
+ super().__init__()
30
+ self.embed_dim = embed_dim
31
+ self.num_heads = num_heads
32
+ self.dropout = dropout
33
+ self.head_dim = embed_dim // num_heads
34
+
35
+ if (self.head_dim * num_heads) != self.embed_dim:
36
+ raise ValueError(
37
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
38
+ f" and `num_heads`: {num_heads})."
39
+ )
40
+ if heads_mask.size() != (self.num_heads,):
41
+ raise ValueError(
42
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {heads_mask.size()}"
43
+ )
44
+ self.heads_mask = heads_mask
45
+
46
+ self.scaling = self.head_dim**-0.5
47
+ self.is_decoder = is_decoder
48
+
49
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
50
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
51
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
52
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
53
+
54
+ self.num_relation_kinds = num_relation_kinds
55
+
56
+
57
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
58
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ key_value_states: Optional[torch.Tensor] = None,
64
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ layer_head_mask: Optional[torch.Tensor] = None,
67
+ output_attentions: bool = False,
68
+ relation_inputs: Optional[torch.Tensor] = None,
69
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
70
+ """Input shape: Batch x Time x Channel"""
71
+
72
+ # if key_value_states are provided this layer is used as a cross-attention layer
73
+ # for the decoder
74
+ is_cross_attention = key_value_states is not None
75
+
76
+ bsz, tgt_len, embed_dim = hidden_states.size()
77
+
78
+ #print(relation_inputs.shape, 'VS ', (bsz, tgt_len, tgt_len))
79
+ if relation_inputs is None:
80
+ # TODO
81
+ relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
82
+ assert relation_inputs.shape == (bsz, tgt_len, tgt_len)
83
+
84
+ # get query proj
85
+ query_states = self.q_proj(hidden_states) * self.scaling
86
+ # get key, value proj
87
+ if is_cross_attention and past_key_value is not None:
88
+ # reuse k,v, cross_attentions
89
+ key_states = past_key_value[0]
90
+ value_states = past_key_value[1]
91
+ elif is_cross_attention:
92
+ # cross_attentions
93
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
94
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
95
+ elif past_key_value is not None:
96
+ # reuse k, v, self_attention
97
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
98
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
99
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
100
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
101
+ else:
102
+ # self_attention
103
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
104
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
105
+
106
+ if self.is_decoder:
107
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
108
+ # Further calls to cross_attention layer can then reuse all cross-attention
109
+ # key/value_states (first "if" case)
110
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
111
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
112
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
113
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
114
+ past_key_value = (key_states, value_states)
115
+
116
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
117
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
118
+ key_states = key_states.view(*proj_shape)
119
+ value_states = value_states.view(*proj_shape)
120
+
121
+ src_len = key_states.size(1)
122
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
123
+
124
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
125
+ raise ValueError(
126
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
127
+ )
128
+
129
+ if attention_mask is not None:
130
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
131
+ raise ValueError(
132
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
133
+ )
134
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
135
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
136
+
137
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
138
+
139
+ if self.heads_mask is not None:# and layer_head_mask is not None:
140
+ if self.heads_mask.size() != (self.num_heads,):
141
+ raise ValueError(
142
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {self.heads_mask.size()}"
143
+ )
144
+ h_mask = layer_head_mask
145
+ #print('h_mask: ', h_mask)
146
+ if layer_head_mask is None:
147
+ h_mask = self.heads_mask
148
+ #h_mask.to(attn_weights.device)
149
+ attn_weights = update_weights_regarding_relations_on_specific_head(h_mask, attn_weights,
150
+ relation_inputs, bsz, self.num_heads, tgt_len,
151
+ src_len, verbose=False)
152
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
153
+
154
+ elif layer_head_mask is not None:
155
+ if layer_head_mask.size() != (self.num_heads,):
156
+ raise ValueError(
157
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
158
+ )
159
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
160
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
161
+
162
+
163
+ if output_attentions:
164
+ # this operation is a bit awkward, but it's required to
165
+ # make sure that attn_weights keeps its gradient.
166
+ # In order to do so, attn_weights have to be reshaped
167
+ # twice and have to be reused in the following
168
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
169
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
170
+ else:
171
+ attn_weights_reshaped = None
172
+
173
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
174
+
175
+ attn_output = torch.bmm(attn_probs, value_states)
176
+
177
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
178
+ raise ValueError(
179
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
180
+ )
181
+
182
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
183
+ attn_output = attn_output.transpose(1, 2)
184
+
185
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
186
+ # partitioned across GPUs when using tensor-parallelism.
187
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
188
+
189
+ attn_output = self.out_proj(attn_output)
190
+
191
+ return attn_output, attn_weights_reshaped, past_key_value
192
+
193
+ def find_head_to_mask(self, heads_mask) -> int:
194
+ head_idx = torch.argmax(heads_mask)
195
+ head_idx_simple = head_idx.item()
196
+ return head_idx_simple
197
+
198
+ def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
199
+ commonsense_mask = torch.zeros(
200
+ ((bsz, num_heads, n_tokens, n_tokens))
201
+ )
202
+ if commonsense_matrix is None:
203
+ commonsense_matrix = torch.zeros(
204
+ ((bsz, n_tokens, n_tokens))
205
+ )
206
+ commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
207
+ commonsense_mask[specific_head] = commonsense_matrix
208
+ commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
209
+ return commonsense_mask
210
+
211
+ def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights,
212
+ specific_head=0):
213
+ num_heads = self.num_heads
214
+ commonsense_mask = torch.zeros(
215
+ ((bsz, num_heads, n_tokens, n_tokens))
216
+ )
217
+ attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
218
+ zeros = torch.zeros(
219
+ ((bsz, n_tokens, n_tokens))
220
+ )
221
+ head_previous_attention_weights = attn_weights_helper[specific_head]
222
+ attn_weights_helper[specific_head] = zeros
223
+ attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
224
+ if commonsense_matrix is None:
225
+ # ignore is not passed (ones -> neutral since multiplication is used)
226
+ commonsense_matrix = torch.ones(
227
+ ((bsz, n_tokens, n_tokens))
228
+ )
229
+ commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
230
+ commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
231
+ # TODO Stupid conversion
232
+ commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
233
+ return attn_weights_helper + commonsense_mask
234
+
235
+ def convert_relations_to_binary_mask(self, input_relations):
236
+ relations_binary_mask = input_relations.clone()
237
+ relations_binary_mask[relations_binary_mask > 1] = 1
238
+ return relations_binary_mask
custom_bart/bart_model.py ADDED
@@ -0,0 +1,169 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import (
7
+ Optional,
8
+ Tuple,
9
+ Union,
10
+ List,
11
+ )
12
+
13
+ # Remote modules
14
+ import torch
15
+ from torch import nn
16
+ from transformers import (
17
+ BartConfig,
18
+ BartPretrainedModel,
19
+ )
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutput, Seq2SeqModelOutput,
22
+ )
23
+ from transformers.models.bart.modeling_bart import shift_tokens_right
24
+
25
+ from transformers.utils import (
26
+ add_code_sample_docstrings,
27
+ add_end_docstrings,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ logging,
31
+ replace_return_docstrings,
32
+ )
33
+
34
+ # Local modules
35
+ from .config import BartCustomConfig
36
+ from .encoder import BartCustomEncoder
37
+ from .decoder import BartCustomDecoder
38
+ from .custom_constants import BartConstants
39
+ from .custom_outputs import CustomSeq2SeqModelOutput
40
+
41
+ @add_start_docstrings(
42
+ "The bare BART Model outputting raw hidden-states without any specific head on top.",
43
+ BartConstants.BART_START_DOCSTRING,
44
+ )
45
+ class BartCustomModel(BartPretrainedModel):
46
+ def __init__(self, config: BartCustomConfig):
47
+ super().__init__(config)
48
+
49
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
50
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
51
+
52
+ self.encoder = BartCustomEncoder(config, self.shared)
53
+ self.decoder = BartCustomDecoder(config, self.shared)
54
+
55
+ # Initialize weights and apply final processing
56
+ self.post_init()
57
+
58
+ def get_input_embeddings(self):
59
+ return self.shared
60
+
61
+ def set_input_embeddings(self, value):
62
+ self.shared = value
63
+ self.encoder.embed_tokens = self.shared
64
+ self.decoder.embed_tokens = self.shared
65
+
66
+ def get_encoder(self):
67
+ return self.encoder
68
+
69
+ def get_decoder(self):
70
+ return self.decoder
71
+
72
+ @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
73
+ @add_code_sample_docstrings(
74
+ processor_class= BartConstants.TOKENIZER_FOR_DOC,
75
+ checkpoint= BartConstants.CHECKPOINT_FOR_DOC,
76
+ output_type= Seq2SeqModelOutput,
77
+ config_class= BartConstants.CONFIG_FOR_DOC,
78
+ expected_output= BartConstants.EXPECTED_OUTPUT_SHAPE,
79
+ )
80
+ def forward(
81
+ self,
82
+ input_ids: torch.LongTensor = None,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ decoder_input_ids: Optional[torch.LongTensor] = None,
85
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
86
+ head_mask: Optional[torch.Tensor] = None,
87
+ decoder_head_mask: Optional[torch.Tensor] = None,
88
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
89
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
90
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
91
+ inputs_embeds: Optional[torch.FloatTensor] = None,
92
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
93
+ use_cache: Optional[bool] = None,
94
+ output_attentions: Optional[bool] = None,
95
+ output_hidden_states: Optional[bool] = None,
96
+ return_dict: Optional[bool] = None,
97
+ relation_inputs: Optional[torch.Tensor] = None,
98
+ ) -> Union[Tuple, CustomSeq2SeqModelOutput]:
99
+
100
+ # different to other models, Bart automatically creates decoder_input_ids from
101
+ # input_ids if no decoder_input_ids are provided
102
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
103
+ if input_ids is None:
104
+ raise ValueError(
105
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
106
+ "passed, `input_ids` cannot be `None`. Please pass either "
107
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
108
+ )
109
+
110
+ decoder_input_ids = shift_tokens_right(
111
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
112
+ )
113
+
114
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
115
+ output_hidden_states = (
116
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
117
+ )
118
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
119
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
120
+
121
+ if encoder_outputs is None:
122
+ encoder_outputs = self.encoder(
123
+ input_ids=input_ids,
124
+ attention_mask=attention_mask,
125
+ head_mask=head_mask,
126
+ inputs_embeds=inputs_embeds,
127
+ output_attentions=output_attentions,
128
+ output_hidden_states=output_hidden_states,
129
+ return_dict=return_dict,
130
+ relation_inputs=relation_inputs
131
+ )
132
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
133
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
134
+ encoder_outputs = BaseModelOutput(
135
+ last_hidden_state=encoder_outputs[0],
136
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
137
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
138
+ )
139
+
140
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
141
+ decoder_outputs = self.decoder(
142
+ input_ids=decoder_input_ids,
143
+ attention_mask=decoder_attention_mask,
144
+ encoder_hidden_states=encoder_outputs[0],
145
+ encoder_attention_mask=attention_mask,
146
+ head_mask=decoder_head_mask,
147
+ cross_attn_head_mask=cross_attn_head_mask,
148
+ past_key_values=past_key_values,
149
+ inputs_embeds=decoder_inputs_embeds,
150
+ use_cache=use_cache,
151
+ output_attentions=output_attentions,
152
+ output_hidden_states=output_hidden_states,
153
+ return_dict=return_dict,
154
+ )
155
+
156
+ if not return_dict:
157
+ return decoder_outputs + encoder_outputs
158
+
159
+ return CustomSeq2SeqModelOutput(
160
+ last_hidden_state=decoder_outputs.last_hidden_state,
161
+ past_key_values=decoder_outputs.past_key_values,
162
+ decoder_hidden_states=decoder_outputs.hidden_states,
163
+ decoder_attentions=decoder_outputs.attentions,
164
+ cross_attentions=decoder_outputs.cross_attentions,
165
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
166
+ encoder_hidden_states=encoder_outputs.hidden_states,
167
+ encoder_attentions=encoder_outputs.attentions,
168
+ encoder_head_mask=head_mask
169
+ )
custom_bart/bart_onnx.py ADDED
@@ -0,0 +1,240 @@
1
+
2
+ from collections import OrderedDict
3
+ from typing import Any, Mapping, Optional
4
+
5
+ import torch
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
8
+ from transformers.onnx.utils import compute_effective_axis_dimension
9
+ from transformers.utils.generic import TensorType
10
+ from transformers.utils.import_utils import is_torch_available
11
+
12
+ class BartCustumOnnxConfig(OnnxSeq2SeqConfigWithPast):
13
+ @property
14
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
15
+ if self.task in ["default", "seq2seq-lm"]:
16
+ common_inputs = OrderedDict(
17
+ [
18
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
19
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
20
+ ("input_commonsense_relations", {0: "batch", 1: "encoder_sequence", 2: "encoder_sequence"}),
21
+ ]
22
+ )
23
+
24
+ if self.use_past:
25
+ common_inputs["decoder_input_ids"] = {0: "batch"}
26
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
27
+ else:
28
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
29
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
30
+
31
+ if self.use_past:
32
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
33
+ elif self.task == "causal-lm":
34
+ # TODO: figure this case out.
35
+ common_inputs = OrderedDict(
36
+ [
37
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
38
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
39
+ ]
40
+ )
41
+ if self.use_past:
42
+ num_encoder_layers, _ = self.num_layers
43
+ for i in range(num_encoder_layers):
44
+ common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
45
+ common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
46
+ else:
47
+ common_inputs = OrderedDict(
48
+ [
49
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
50
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
51
+ ("input_commonsense_relations", {0: "batch", 2: "encoder_sequence", 3: "encoder_sequence"}),
52
+ ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
53
+ ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
54
+ ]
55
+ )
56
+
57
+ return common_inputs
58
+
59
+ @property
60
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
61
+ if self.task in ["default", "seq2seq-lm"]:
62
+ common_outputs = super().outputs
63
+ else:
64
+ common_outputs = super(OnnxConfigWithPast, self).outputs
65
+ if self.use_past:
66
+ num_encoder_layers, _ = self.num_layers
67
+ for i in range(num_encoder_layers):
68
+ common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
69
+ common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
70
+ return common_outputs
71
+
72
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
73
+ self,
74
+ tokenizer: PreTrainedTokenizer,
75
+ batch_size: int = -1,
76
+ seq_length: int = -1,
77
+ is_pair: bool = False,
78
+ framework: Optional[TensorType] = None,
79
+ ) -> Mapping[str, Any]:
80
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
81
+ tokenizer, batch_size, seq_length, is_pair, framework
82
+ )
83
+
84
+ # Generate decoder inputs
85
+ decoder_seq_length = seq_length if not self.use_past else 1
86
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
87
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
88
+ )
89
+ decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
90
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
91
+
92
+ if self.use_past:
93
+ if not is_torch_available():
94
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
95
+ else:
96
+ import torch
97
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
98
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
99
+ num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
100
+ encoder_shape = (
101
+ batch,
102
+ num_encoder_attention_heads,
103
+ encoder_seq_length,
104
+ self._config.hidden_size // num_encoder_attention_heads,
105
+ )
106
+ decoder_past_length = decoder_seq_length + 3
107
+ decoder_shape = (
108
+ batch,
109
+ num_decoder_attention_heads,
110
+ decoder_past_length,
111
+ self._config.hidden_size // num_decoder_attention_heads,
112
+ )
113
+
114
+ common_inputs["decoder_attention_mask"] = torch.cat(
115
+ [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
116
+ )
117
+
118
+ common_inputs["past_key_values"] = []
119
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
120
+ num_encoder_layers, num_decoder_layers = self.num_layers
121
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
122
+ max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
123
+ remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
124
+
125
+ for _ in range(min_num_layers):
126
+ common_inputs["past_key_values"].append(
127
+ (
128
+ torch.zeros(decoder_shape),
129
+ torch.zeros(decoder_shape),
130
+ torch.zeros(encoder_shape),
131
+ torch.zeros(encoder_shape),
132
+ )
133
+ )
134
+ # TODO: test this.
135
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
136
+ for _ in range(min_num_layers, max_num_layers):
137
+ common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
138
+ return common_inputs
139
+
140
+ def _generate_dummy_inputs_for_causal_lm(
141
+ self,
142
+ tokenizer: PreTrainedTokenizer,
143
+ batch_size: int = -1,
144
+ seq_length: int = -1,
145
+ is_pair: bool = False,
146
+ framework: Optional[TensorType] = None,
147
+ ) -> Mapping[str, Any]:
148
+ common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
149
+ tokenizer, batch_size, seq_length, is_pair, framework
150
+ )
151
+
152
+ if self.use_past:
153
+ if not is_torch_available():
154
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
155
+ else:
156
+ import torch
157
+ batch, seqlen = common_inputs["input_ids"].shape
158
+ # Not using the same length for past_key_values
159
+ past_key_values_length = seqlen + 2
160
+ num_encoder_layers, _ = self.num_layers
161
+ num_encoder_attention_heads, _ = self.num_attention_heads
162
+ past_shape = (
163
+ batch,
164
+ num_encoder_attention_heads,
165
+ past_key_values_length,
166
+ self._config.hidden_size // num_encoder_attention_heads,
167
+ )
168
+
169
+ mask_dtype = common_inputs["attention_mask"].dtype
170
+ common_inputs["attention_mask"] = torch.cat(
171
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
172
+ )
173
+ common_inputs["past_key_values"] = [
174
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
175
+ ]
176
+ return common_inputs
177
+
178
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
179
+ self,
180
+ tokenizer: PreTrainedTokenizer,
181
+ batch_size: int = -1,
182
+ seq_length: int = -1,
183
+ is_pair: bool = False,
184
+ framework: Optional[TensorType] = None,
185
+ ) -> Mapping[str, Any]:
186
+ # Copied from OnnxConfig.generate_dummy_inputs
187
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
188
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
189
+ batch_size = compute_effective_axis_dimension(
190
+ batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
191
+ )
192
+
193
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
194
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
195
+ seq_length = compute_effective_axis_dimension(
196
+ seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
197
+ )
198
+
199
+ # Generate dummy inputs according to compute batch and sequence
200
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
201
+ tmp_seq_length = seq_length + 2  # account for the special tokens the tokenizer will add
202
+ commonsense_relation = torch.IntTensor([[[0] * tmp_seq_length] * tmp_seq_length] * batch_size)  # all-zero (batch, seq, seq) relation-id matrix
203
+ common_inputs = dict(tokenizer(dummy_input,
204
+ return_tensors=framework))
205
+ common_inputs['input_commonsense_relations'] = commonsense_relation
206
+ print('here:', common_inputs)
207
+ return common_inputs
208
+
209
+ def generate_dummy_inputs(
210
+ self,
211
+ tokenizer: PreTrainedTokenizer,
212
+ batch_size: int = -1,
213
+ seq_length: int = -1,
214
+ is_pair: bool = False,
215
+ framework: Optional[TensorType] = None,
216
+ ) -> Mapping[str, Any]:
217
+ if self.task in ["default", "seq2seq-lm"]:
218
+ common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
219
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
220
+ )
221
+
222
+ elif self.task == "causal-lm":
223
+ common_inputs = self._generate_dummy_inputs_for_causal_lm(
224
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
225
+ )
226
+ else:
227
+ common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
228
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
229
+ )
230
+ if 'decoder_input_commonsense_relations' in common_inputs:
231
+ del common_inputs['decoder_input_commonsense_relations']
232
+ return common_inputs
233
+
234
+ def _flatten_past_key_values_(self, flattened_output, name, idx, t):
235
+ if self.task in ["default", "seq2seq-lm"]:
236
+ flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
237
+ else:
238
+ flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
239
+ flattened_output, name, idx, t
240
+ )
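For reference, a minimal sketch of how the dummy-input generation above could be exercised; the `BartCustomOnnxConfig` name and the surrounding setup are illustrative assumptions, not part of this commit:

```python
# Illustrative sketch only; class and variable names are assumed, not taken from this commit.
from transformers import BartTokenizer, TensorType

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# onnx_config = BartCustomOnnxConfig(model_config, task="seq2seq-lm")  # hypothetical class name
# dummy = onnx_config.generate_dummy_inputs(
#     tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
# )
# dummy["input_commonsense_relations"]  # all-zero (batch, seq+2, seq+2) relation-id matrix,
#                                       # where the +2 accounts for BART's special tokens
```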
custom_bart/config.py ADDED
@@ -0,0 +1,197 @@
1
+ from transformers import BartConfig
2
+
3
+ class BartCustomConfig(BartConfig):
4
+ def __init__(
5
+ self,
6
+ model_type='bart',
7
+ vocab_size=50265,
8
+ max_position_embeddings=1024,
9
+ encoder_layers=12,
10
+ encoder_ffn_dim=4096,
11
+ encoder_attention_heads=16,
12
+ decoder_layers=12,
13
+ decoder_ffn_dim=4096,
14
+ decoder_attention_heads=16,
15
+ encoder_layerdrop=0.0,
16
+ decoder_layerdrop=0.0,
17
+ activation_function="gelu",
18
+ d_model=1024,
19
+ dropout=0.1,
20
+ attention_dropout=0.1,
21
+ activation_dropout=0.1,
22
+ init_std=0.02,
23
+ classifier_dropout=0.0,
24
+ classif_dropout=0.1,
25
+ scale_embedding=False,
26
+ use_cache=True,
27
+ num_labels=3,
28
+ pad_token_id=1,
29
+ bos_token_id=0,
30
+ eos_token_id=2,
31
+ is_encoder_decoder=True,
32
+ decoder_start_token_id=2,
33
+ forced_eos_token_id=2,
34
+ forced_bos_token_id=0,
35
+ no_repeat_ngram_size=3, # adding
36
+ num_hidden_layers=12,
37
+ normalize_before=False,
38
+ num_beams=4,
39
+ add_bias_logits=False,
40
+ add_final_layer_norm=False,
41
+ early_stopping=True,
42
+ gradient_checkpointing=False,
43
+ num_relation_kinds = 0,
44
+ use_same_relation_kv_emb = True,
45
+ is_simple_mask_commonsense = False,
46
+ should_embed_positions = False,
47
+ heads_mask = None,
48
+ **kwargs
49
+ ):
50
+ super(BartCustomConfig, self).__init__(
51
+ model_type=model_type,
52
+ vocab_size=vocab_size,
53
+ max_position_embeddings=max_position_embeddings,
54
+ encoder_layers=encoder_layers,
55
+ encoder_ffn_dim=encoder_ffn_dim,
56
+ encoder_attention_heads=encoder_attention_heads,
57
+ decoder_layers=decoder_layers,
58
+ decoder_ffn_dim=decoder_ffn_dim,
59
+ decoder_attention_heads=decoder_attention_heads,
60
+ encoder_layerdrop=encoder_layerdrop,
61
+ decoder_layerdrop=decoder_layerdrop,
62
+ activation_function=activation_function,
63
+ d_model=d_model,
64
+ dropout=dropout,
65
+ attention_dropout=attention_dropout,
66
+ activation_dropout=activation_dropout,
67
+ init_std=init_std,
68
+ classifier_dropout=classifier_dropout,
69
+ classif_dropout=classif_dropout,
70
+ scale_embedding=scale_embedding,
71
+ use_cache=use_cache,
72
+ num_labels=num_labels,
73
+ pad_token_id = pad_token_id,
74
+ bos_token_id = bos_token_id,
75
+ eos_token_id = eos_token_id,
76
+ is_encoder_decoder = is_encoder_decoder,
77
+ decoder_start_token_id = decoder_start_token_id,
78
+ forced_eos_token_id = forced_eos_token_id,
79
+ forced_bos_token_id=forced_bos_token_id,
80
+ no_repeat_ngram_size=no_repeat_ngram_size, # Adding
81
+ normalize_before=normalize_before,
82
+ num_hidden_layers=num_hidden_layers,
83
+ num_beams=num_beams,
84
+ add_bias_logits=add_bias_logits,
85
+ add_final_layer_norm=add_final_layer_norm,
86
+ early_stopping=early_stopping,
87
+ gradient_checkpointing=gradient_checkpointing,
88
+ num_relation_kinds = num_relation_kinds,
89
+ use_same_relation_kv_emb = use_same_relation_kv_emb,
90
+ is_simple_mask_commonsense = is_simple_mask_commonsense,
91
+ heads_mask = heads_mask,
92
+ should_embed_positions=should_embed_positions,
93
+ **kwargs
94
+ )
95
+ self.num_relation_kinds = num_relation_kinds
96
+ self.use_same_relation_kv_emb = use_same_relation_kv_emb
97
+ self.is_simple_mask_commonsense = is_simple_mask_commonsense
98
+ self.heads_mask = heads_mask
99
+ self.should_embed_positions = should_embed_positions
100
+
101
+ class BartSmallCustomConfig(BartConfig):
102
+ def __init__(
103
+ self,
104
+ vocab_size=50265,
105
+ max_position_embeddings=1024,
106
+ encoder_layers=6,
107
+ encoder_ffn_dim=3072,
108
+ encoder_attention_heads=12,
109
+ decoder_layers=12,
110
+ decoder_ffn_dim=3072,
111
+ decoder_attention_heads=12,
112
+ encoder_layerdrop=0.0,
113
+ decoder_layerdrop=0.0,
114
+ activation_function="gelu",
115
+ d_model=768,
116
+ dropout=0.1,
117
+ attention_dropout=0.1,
118
+ activation_dropout=0.1,
119
+ init_std=0.02,
120
+ classifier_dropout=0.0,
121
+ classif_dropout= 0.1,
122
+ scale_embedding=False,
123
+ use_cache=True,
124
+ num_labels=3,
125
+ pad_token_id=1,
126
+ bos_token_id=0,
127
+ eos_token_id=2,
128
+ is_encoder_decoder=True,
129
+ decoder_start_token_id=2,
130
+ forced_eos_token_id=2,
131
+ forced_bos_token_id=0,
132
+ no_repeat_ngram_size=3, #adding
133
+ num_hidden_layers=6,
134
+ normalize_before=False,
135
+ num_beams=4,
136
+ add_bias_logits=False,
137
+ add_final_layer_norm=False,
138
+ _name_or_path="bart-base",
139
+ early_stopping=True,
140
+ gradient_checkpointing=False,
141
+ num_relation_kinds = 0,
142
+ use_same_relation_kv_emb = True,
143
+ is_simple_mask_commonsense = False,
144
+ should_embed_positions = True,
145
+ heads_mask = None,
146
+ **kwargs
147
+ ):
148
+ super(BartSmallCustomConfig, self).__init__(
149
+ vocab_size=vocab_size,
150
+ max_position_embeddings=max_position_embeddings,
151
+ encoder_layers=encoder_layers,
152
+ encoder_ffn_dim=encoder_ffn_dim,
153
+ encoder_attention_heads=encoder_attention_heads,
154
+ decoder_layers=decoder_layers,
155
+ decoder_ffn_dim=decoder_ffn_dim,
156
+ decoder_attention_heads=decoder_attention_heads,
157
+ encoder_layerdrop=encoder_layerdrop,
158
+ decoder_layerdrop=decoder_layerdrop,
159
+ activation_function=activation_function,
160
+ d_model=d_model,
161
+ dropout=dropout,
162
+ attention_dropout=attention_dropout,
163
+ activation_dropout=activation_dropout,
164
+ init_std=init_std,
165
+ classifier_dropout=classifier_dropout,
166
+ classif_dropout=classif_dropout,
167
+ scale_embedding=scale_embedding,
168
+ use_cache=use_cache,
169
+ num_labels=num_labels,
170
+ pad_token_id = pad_token_id,
171
+ bos_token_id = bos_token_id,
172
+ eos_token_id = eos_token_id,
173
+ is_encoder_decoder = is_encoder_decoder,
174
+ decoder_start_token_id = decoder_start_token_id,
175
+ forced_eos_token_id = forced_eos_token_id,
176
+ forced_bos_token_id=forced_bos_token_id,
177
+ no_repeat_ngram_size = no_repeat_ngram_size, #Adding
178
+ normalize_before = normalize_before,
179
+ num_hidden_layers=num_hidden_layers,
180
+ num_beams=num_beams,
181
+ add_bias_logits=add_bias_logits,
182
+ add_final_layer_norm=add_final_layer_norm,
183
+ _name_or_path=_name_or_path,
184
+ early_stopping=early_stopping,
185
+ gradient_checkpointing=gradient_checkpointing,
186
+ num_relation_kinds = num_relation_kinds,
187
+ use_same_relation_kv_emb = use_same_relation_kv_emb,
188
+ is_simple_mask_commonsense = is_simple_mask_commonsense,
189
+ heads_mask = heads_mask,
190
+ should_embed_positions=should_embed_positions,
191
+ **kwargs
192
+ )
193
+ self.num_relation_kinds = num_relation_kinds
194
+ self.use_same_relation_kv_emb = use_same_relation_kv_emb
195
+ self.is_simple_mask_commonsense = is_simple_mask_commonsense
196
+ self.heads_mask = heads_mask
197
+ self.should_embed_positions = should_embed_positions
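A minimal usage sketch for the configuration classes above; the concrete values are illustrative assumptions, not taken from this commit:

```python
from custom_bart.config import BartCustomConfig

# Assumed example values: num_relation_kinds would normally match the number of
# ConceptNet relation types handled by the KG binding, with 0 reserved for "no relation".
config = BartCustomConfig(
    num_relation_kinds=40,             # hypothetical count of relation kinds
    use_same_relation_kv_emb=True,     # share the relation embedding between keys and values
    is_simple_mask_commonsense=False,  # learn relation embeddings rather than a plain mask
    should_embed_positions=True,       # keep positional embeddings in the encoder
)
print(config.num_relation_kinds, config.is_simple_mask_commonsense)
```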
custom_bart/custom_constants.py ADDED
@@ -0,0 +1,168 @@
1
+
2
+ class BartConstants:
3
+ CHECKPOINT_FOR_DOC = "facebook/bart-base"
4
+ CONFIG_FOR_DOC = "BartConfig"
5
+ TOKENIZER_FOR_DOC = "BartTokenizer"
6
+
7
+ # Base model docstring
8
+ EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
9
+
10
+ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
11
+ "facebook/bart-large",
12
+ ]
13
+
14
+ BART_START_DOCSTRING = r"""
15
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
16
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
17
+ etc.)
18
+
19
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
20
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
21
+ and behavior.
22
+
23
+ Parameters:
24
+ config ([`BartConfig`]):
25
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
26
+ load the weights associated with the model, only the configuration. Check out the
27
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
28
+ """
29
+ BART_INPUTS_DOCSTRING = r"""
30
+ Args:
31
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
32
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
33
+ it.
34
+
35
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
36
+ [`PreTrainedTokenizer.__call__`] for details.
37
+
38
+ [What are input IDs?](../glossary#input-ids)
39
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
40
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
41
+
42
+ - 1 for tokens that are **not masked**,
43
+ - 0 for tokens that are **masked**.
44
+
45
+ [What are attention masks?](../glossary#attention-mask)
46
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
47
+ Indices of decoder input sequence tokens in the vocabulary.
48
+
49
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
50
+ [`PreTrainedTokenizer.__call__`] for details.
51
+
52
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
53
+
54
+ Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
55
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
56
+
57
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
58
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
59
+ for denoising pre-training following the paper.
60
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
61
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
62
+ be used by default.
63
+
64
+ If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and
65
+ modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information
66
+ on the default strategy.
67
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
68
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
69
+
70
+ - 1 indicates the head is **not masked**,
71
+ - 0 indicates the head is **masked**.
72
+
73
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
74
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
75
+
76
+ - 1 indicates the head is **not masked**,
77
+ - 0 indicates the head is **masked**.
78
+
79
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
80
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
81
+ 1]`:
82
+
83
+ - 1 indicates the head is **not masked**,
84
+ - 0 indicates the head is **masked**.
85
+
86
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
87
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
88
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
89
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
90
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
91
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
92
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
93
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
94
+
95
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
96
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
97
+
98
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
99
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
100
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
101
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
102
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
103
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
104
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
105
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
106
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
107
+ input (see `past_key_values`). This is useful if you want more control over how to convert
108
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
109
+
110
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
111
+ of `inputs_embeds`.
112
+ use_cache (`bool`, *optional*):
113
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
114
+ `past_key_values`).
115
+ output_attentions (`bool`, *optional*):
116
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
117
+ tensors for more detail.
118
+ output_hidden_states (`bool`, *optional*):
119
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
120
+ more detail.
121
+ return_dict (`bool`, *optional*):
122
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
123
+ """
124
+ BART_GENERATION_EXAMPLE = r"""
125
+ Summarization example:
126
+
127
+ ```python
128
+ >>> from transformers import BartTokenizer, BartForConditionalGeneration
129
+
130
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
131
+ >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
132
+
133
+ >>> ARTICLE_TO_SUMMARIZE = (
134
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
135
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
136
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
137
+ ... )
138
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
139
+
140
+ >>> # Generate Summary
141
+ >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
142
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
143
+ 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
144
+ ```
145
+
146
+ Mask filling example:
147
+
148
+ ```python
149
+ >>> from transformers import BartTokenizer, BartForConditionalGeneration
150
+
151
+ >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
152
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
153
+
154
+ >>> TXT = "My friends are <mask> but they eat too many carbs."
155
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
156
+ >>> logits = model(input_ids).logits
157
+
158
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
159
+ >>> probs = logits[0, masked_index].softmax(dim=0)
160
+ >>> values, predictions = probs.topk(5)
161
+
162
+ >>> tokenizer.decode(predictions).split()
163
+ ['not', 'good', 'healthy', 'great', 'very']
164
+ ```
165
+ """
166
+
167
+
168
+
custom_bart/custom_outputs.py ADDED
@@ -0,0 +1,142 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ # Remote modules
10
+ import torch
11
+ from transformers.modeling_outputs import ModelOutput
12
+
13
+ # Local modules
14
+
15
+ #############################
16
+ # Constants
17
+ #############################
18
+
19
+ #############################
20
+ # Custom output dataclasses
21
+ #############################
22
+
23
+ @dataclass
24
+ class CustomSeq2SeqLMOutput(ModelOutput):
25
+ """
26
+ Base class for sequence-to-sequence language models outputs.
27
+
28
+ Args:
29
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
30
+ Language modeling loss.
31
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
32
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
33
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
34
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
35
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
36
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
37
+
38
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
39
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
40
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
41
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
42
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
43
+
44
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
45
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
46
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
47
+ sequence_length)`.
48
+
49
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
50
+ self-attention heads.
51
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
52
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
53
+ sequence_length)`.
54
+
55
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
56
+ weighted average in the cross-attention heads.
57
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
58
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
59
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
60
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
61
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
62
+
63
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
64
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
65
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
66
+ sequence_length)`.
67
+
68
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
69
+ self-attention heads.
70
+ """
71
+
72
+ loss: Optional[torch.FloatTensor] = None
73
+ logits: torch.FloatTensor = None
74
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
75
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
76
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
77
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
78
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
79
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
80
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
81
+ head_mask: Optional[Tuple[torch.FloatTensor]] = None
82
+
83
+ @dataclass
84
+ class CustomSeq2SeqModelOutput(ModelOutput):
85
+ """
86
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
87
+ decoding.
88
+
89
+ Args:
90
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
91
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
92
+
93
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
94
+ hidden_size)` is output.
95
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
96
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
97
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
98
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
99
+
100
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
101
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
102
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
103
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
104
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
105
+
106
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
107
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
108
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
109
+ sequence_length)`.
110
+
111
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
112
+ self-attention heads.
113
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
114
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
115
+ sequence_length)`.
116
+
117
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
118
+ weighted average in the cross-attention heads.
119
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
120
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
121
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
122
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
123
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
124
+
125
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
126
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
127
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
128
+ sequence_length)`.
129
+
130
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
131
+ self-attention heads.
132
+ """
133
+
134
+ last_hidden_state: torch.FloatTensor = None
135
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
136
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
137
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
138
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
139
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
140
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
141
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
142
+ encoder_head_mask: Optional[Tuple[torch.FloatTensor]] = None
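These dataclasses mirror the stock `Seq2SeqLMOutput`/`Seq2SeqModelOutput` from `transformers`, adding a field that carries the encoder head mask through. A small construction sketch (values are illustrative):

```python
import torch

from custom_bart.custom_outputs import CustomSeq2SeqLMOutput

# Dummy logits for batch_size=1, target length 4 and the BART vocabulary size (50265).
out = CustomSeq2SeqLMOutput(
    logits=torch.zeros(1, 4, 50265),
    head_mask=None,  # the extra field compared to the stock Seq2SeqLMOutput
)
print(out.logits.shape, out.head_mask)
```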
custom_bart/decoder.py ADDED
@@ -0,0 +1,312 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import (
7
+ Optional,
8
+ Tuple,
9
+ Union,
10
+ List,
11
+ )
12
+ import math
13
+ import random
14
+
15
+ # Remote modules
16
+ import torch
17
+ from torch import nn
18
+ from transformers import (
19
+ BartConfig,
20
+ BartPretrainedModel,
21
+ )
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutput,
24
+ BaseModelOutputWithPastAndCrossAttentions
25
+ )
26
+ from transformers.models.bart.modeling_bart import (
27
+ BartLearnedPositionalEmbedding,
28
+ _expand_mask,
29
+ _make_causal_mask
30
+ )
31
+ from transformers.utils import (
32
+ logging,
33
+ )
34
+
35
+ # Local modules
36
+ from .config import BartCustomConfig
37
+ from .decoder_layer import BartCustomDecoderLayer
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ class BartCustomDecoder(BartPretrainedModel):
42
+ """
43
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
44
+
45
+ Args:
46
+ config: BartConfig
47
+ embed_tokens (nn.Embedding): output embedding
48
+ """
49
+
50
+ def __init__(self, config: BartCustomConfig, embed_tokens: Optional[nn.Embedding] = None):
51
+ super().__init__(config)
52
+ self.dropout = config.dropout
53
+ self.layerdrop = config.decoder_layerdrop
54
+ self.padding_idx = config.pad_token_id
55
+ self.max_target_positions = config.max_position_embeddings
56
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
57
+
58
+ if embed_tokens is not None:
59
+ self.embed_tokens = embed_tokens
60
+ else:
61
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
62
+
63
+ self.embed_positions = BartLearnedPositionalEmbedding(
64
+ config.max_position_embeddings,
65
+ config.d_model,
66
+ )
67
+ self.layers = nn.ModuleList([BartCustomDecoderLayer(config) for _ in range(config.decoder_layers)])
68
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
69
+
70
+ self.gradient_checkpointing = False
71
+ # Initialize weights and apply final processing
72
+ self.post_init()
73
+
74
+ def get_input_embeddings(self):
75
+ return self.embed_tokens
76
+
77
+ def set_input_embeddings(self, value):
78
+ self.embed_tokens = value
79
+
80
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
81
+ # create causal mask
82
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
83
+ combined_attention_mask = None
84
+ if input_shape[-1] > 1:
85
+ combined_attention_mask = _make_causal_mask(
86
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
87
+ ).to(self.device)
88
+
89
+ if attention_mask is not None:
90
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
91
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
92
+ combined_attention_mask = (
93
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
94
+ )
95
+
96
+ return combined_attention_mask
97
+
98
+ def forward(
99
+ self,
100
+ input_ids: torch.LongTensor = None,
101
+ attention_mask: Optional[torch.Tensor] = None,
102
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
103
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
104
+ head_mask: Optional[torch.Tensor] = None,
105
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
106
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
107
+ inputs_embeds: Optional[torch.FloatTensor] = None,
108
+ use_cache: Optional[bool] = None,
109
+ output_attentions: Optional[bool] = None,
110
+ output_hidden_states: Optional[bool] = None,
111
+ return_dict: Optional[bool] = None,
112
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
113
+ r"""
114
+ Args:
115
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
116
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
117
+ provide it.
118
+
119
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
120
+ [`PreTrainedTokenizer.__call__`] for details.
121
+
122
+ [What are input IDs?](../glossary#input-ids)
123
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
124
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
125
+
126
+ - 1 for tokens that are **not masked**,
127
+ - 0 for tokens that are **masked**.
128
+
129
+ [What are attention masks?](../glossary#attention-mask)
130
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
131
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
132
+ of the decoder.
133
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
134
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
135
+ selected in `[0, 1]`:
136
+
137
+ - 1 for tokens that are **not masked**,
138
+ - 0 for tokens that are **masked**.
139
+
140
+ [What are attention masks?](../glossary#attention-mask)
141
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
142
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
143
+
144
+ - 1 indicates the head is **not masked**,
145
+ - 0 indicates the head is **masked**.
146
+
147
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
148
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
149
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
150
+
151
+ - 1 indicates the head is **not masked**,
152
+ - 0 indicates the head is **masked**.
153
+
154
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
155
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
156
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
157
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
158
+
159
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
160
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
161
+
162
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
163
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
164
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
165
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
166
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
167
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
168
+ embedding lookup matrix.
169
+ output_attentions (`bool`, *optional*):
170
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
171
+ returned tensors for more detail.
172
+ output_hidden_states (`bool`, *optional*):
173
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
174
+ for more detail.
175
+ return_dict (`bool`, *optional*):
176
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
177
+ """
178
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
179
+ output_hidden_states = (
180
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
181
+ )
182
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
183
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
184
+
185
+ # retrieve input_ids and inputs_embeds
186
+ if input_ids is not None and inputs_embeds is not None:
187
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
188
+ elif input_ids is not None:
189
+ input_shape = input_ids.size()
190
+ input_ids = input_ids.view(-1, input_shape[-1])
191
+ elif inputs_embeds is not None:
192
+ input_shape = inputs_embeds.size()[:-1]
193
+ else:
194
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
195
+
196
+ # past_key_values_length
197
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
198
+
199
+ if inputs_embeds is None:
200
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
201
+
202
+ attention_mask = self._prepare_decoder_attention_mask(
203
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
204
+ )
205
+
206
+ # expand encoder attention mask
207
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
208
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
209
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
210
+
211
+ # embed positions
212
+ positions = self.embed_positions(input_shape, past_key_values_length)
213
+
214
+ hidden_states = inputs_embeds + positions
215
+ hidden_states = self.layernorm_embedding(hidden_states)
216
+
217
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
218
+
219
+ # decoder layers
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_self_attns = () if output_attentions else None
222
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
223
+ next_decoder_cache = () if use_cache else None
224
+
225
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
226
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
227
+ if attn_mask is not None:
228
+ if attn_mask.size()[0] != (len(self.layers)):
229
+ raise ValueError(
230
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
231
+ )
232
+
233
+ for idx, decoder_layer in enumerate(self.layers):
234
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
235
+ if output_hidden_states:
236
+ all_hidden_states += (hidden_states,)
237
+ dropout_probability = random.uniform(0, 1)
238
+ if self.training and (dropout_probability < self.layerdrop):
239
+ continue
240
+
241
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
242
+
243
+ if self.gradient_checkpointing and self.training:
244
+
245
+ if use_cache:
246
+ logger.warning(
247
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
248
+ )
249
+ use_cache = False
250
+
251
+ def create_custom_forward(module):
252
+ def custom_forward(*inputs):
253
+ # None for past_key_value
254
+ return module(*inputs, output_attentions, use_cache)
255
+
256
+ return custom_forward
257
+
258
+ layer_outputs = torch.utils.checkpoint.checkpoint(
259
+ create_custom_forward(decoder_layer),
260
+ hidden_states,
261
+ attention_mask,
262
+ encoder_hidden_states,
263
+ encoder_attention_mask,
264
+ head_mask[idx] if head_mask is not None else None,
265
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
266
+ None,
267
+ )
268
+ else:
269
+
270
+ layer_outputs = decoder_layer(
271
+ hidden_states,
272
+ attention_mask=attention_mask,
273
+ encoder_hidden_states=encoder_hidden_states,
274
+ encoder_attention_mask=encoder_attention_mask,
275
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
276
+ cross_attn_layer_head_mask=(
277
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
278
+ ),
279
+ past_key_value=past_key_value,
280
+ output_attentions=output_attentions,
281
+ use_cache=use_cache,
282
+ )
283
+ hidden_states = layer_outputs[0]
284
+
285
+ if use_cache:
286
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
287
+
288
+ if output_attentions:
289
+ all_self_attns += (layer_outputs[1],)
290
+
291
+ if encoder_hidden_states is not None:
292
+ all_cross_attentions += (layer_outputs[2],)
293
+
294
+ # add hidden states from the last decoder layer
295
+ if output_hidden_states:
296
+ all_hidden_states += (hidden_states,)
297
+
298
+ next_cache = next_decoder_cache if use_cache else None
299
+ if not return_dict:
300
+ return tuple(
301
+ v
302
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
303
+ if v is not None
304
+ )
305
+ return BaseModelOutputWithPastAndCrossAttentions(
306
+ last_hidden_state=hidden_states,
307
+ past_key_values=next_cache,
308
+ hidden_states=all_hidden_states,
309
+ attentions=all_self_attns,
310
+ cross_attentions=all_cross_attentions,
311
+ )
312
+
custom_bart/decoder_layer.py ADDED
@@ -0,0 +1,134 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import Optional, Tuple
7
+
8
+ # Remote modules
9
+ import torch
10
+ from torch import nn
11
+ from transformers import BartConfig
12
+ from transformers.activations import ACT2FN
13
+
14
+ # Local modules
15
+ from transformers.models.bart.modeling_bart import BartAttention
16
+
17
+ from .config import BartCustomConfig
18
+
19
+
20
+ class BartCustomDecoderLayer(nn.Module):
21
+ def __init__(self, config: BartCustomConfig):
22
+ super().__init__()
23
+ self.embed_dim = config.d_model
24
+
25
+ self.self_attn = BartAttention(
26
+ embed_dim=self.embed_dim,
27
+ num_heads=config.decoder_attention_heads,
28
+ dropout=config.attention_dropout,
29
+ is_decoder=True,
30
+ )
31
+ self.dropout = config.dropout
32
+ self.activation_fn = ACT2FN[config.activation_function]
33
+ self.activation_dropout = config.activation_dropout
34
+
35
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
36
+ self.encoder_attn = BartAttention(
37
+ self.embed_dim,
38
+ config.decoder_attention_heads,
39
+ dropout=config.attention_dropout,
40
+ is_decoder=True,
41
+ )
42
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
43
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
44
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
45
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
46
+
47
+ def forward(
48
+ self,
49
+ hidden_states: torch.Tensor,
50
+ attention_mask: Optional[torch.Tensor] = None,
51
+ encoder_hidden_states: Optional[torch.Tensor] = None,
52
+ encoder_attention_mask: Optional[torch.Tensor] = None,
53
+ layer_head_mask: Optional[torch.Tensor] = None,
54
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
55
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
56
+ output_attentions: Optional[bool] = False,
57
+ use_cache: Optional[bool] = True,
58
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
59
+ """
60
+ Args:
61
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
62
+ attention_mask (`torch.FloatTensor`): attention mask of size
63
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
64
+ encoder_hidden_states (`torch.FloatTensor`):
65
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
66
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
67
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
68
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
69
+ `(encoder_attention_heads,)`.
70
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
71
+ size `(decoder_attention_heads,)`.
72
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
73
+ output_attentions (`bool`, *optional*):
74
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
75
+ returned tensors for more detail.
76
+ """
77
+ residual = hidden_states
78
+
79
+ # Self Attention
80
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
81
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
82
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
83
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
84
+ hidden_states=hidden_states,
85
+ past_key_value=self_attn_past_key_value,
86
+ attention_mask=attention_mask,
87
+ layer_head_mask=layer_head_mask,
88
+ output_attentions=output_attentions,
89
+ )
90
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
91
+ hidden_states = residual + hidden_states
92
+ hidden_states = self.self_attn_layer_norm(hidden_states)
93
+
94
+ # Cross-Attention Block
95
+ cross_attn_present_key_value = None
96
+ cross_attn_weights = None
97
+ if encoder_hidden_states is not None:
98
+ residual = hidden_states
99
+
100
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
101
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
102
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
103
+ hidden_states=hidden_states,
104
+ key_value_states=encoder_hidden_states,
105
+ attention_mask=encoder_attention_mask,
106
+ layer_head_mask=cross_attn_layer_head_mask,
107
+ past_key_value=cross_attn_past_key_value,
108
+ output_attentions=output_attentions,
109
+ )
110
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
111
+ hidden_states = residual + hidden_states
112
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
113
+
114
+ # add cross-attn to positions 3,4 of present_key_value tuple
115
+ present_key_value = present_key_value + cross_attn_present_key_value
116
+
117
+ # Fully Connected
118
+ residual = hidden_states
119
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
120
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
121
+ hidden_states = self.fc2(hidden_states)
122
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
123
+ hidden_states = residual + hidden_states
124
+ hidden_states = self.final_layer_norm(hidden_states)
125
+
126
+ outputs = (hidden_states,)
127
+
128
+ if output_attentions:
129
+ outputs += (self_attn_weights, cross_attn_weights)
130
+
131
+ if use_cache:
132
+ outputs += (present_key_value,)
133
+
134
+ return outputs
custom_bart/encoder.py ADDED
@@ -0,0 +1,216 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import (
7
+ Optional,
8
+ Tuple,
9
+ Union,
10
+ )
11
+ import math
12
+ import random
13
+
14
+ # Remote modules
15
+ import torch
16
+ from torch import nn
17
+ from transformers import (
18
+ BartConfig,
19
+ BartPretrainedModel,
20
+ )
21
+ from transformers.modeling_outputs import BaseModelOutput
22
+ from transformers.models.bart.modeling_bart import (
23
+ BartLearnedPositionalEmbedding,
24
+ _expand_mask
25
+ )
26
+
27
+ # Local modules
28
+ from .config import BartCustomConfig
29
+ from .encoder_layer import BartCustomEncoderLayer
30
+
31
+
32
+ class BartCustomEncoder(BartPretrainedModel):
33
+ """
34
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
35
+ [`BartEncoderLayer`].
36
+
37
+ Args:
38
+ config: BartConfig
39
+ embed_tokens (nn.Embedding): output embedding
40
+ """
41
+
42
+ def __init__(self, config: BartCustomConfig, embed_tokens: Optional[nn.Embedding] = None):
43
+ super().__init__(config)
44
+
45
+ self.dropout = config.dropout
46
+ self.layerdrop = config.encoder_layerdrop
47
+
48
+ embed_dim = config.d_model
49
+ self.padding_idx = config.pad_token_id
50
+ self.max_source_positions = config.max_position_embeddings
51
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
52
+
53
+ if embed_tokens is not None:
54
+ self.embed_tokens = embed_tokens
55
+ else:
56
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
57
+
58
+ if not config.should_embed_positions:
59
+ self.embed_positions = None
60
+ else:
61
+ self.embed_positions = BartLearnedPositionalEmbedding(
62
+ config.max_position_embeddings,
63
+ embed_dim,
64
+ )
65
+ device = self.device
66
+ self.layers = nn.ModuleList([BartCustomEncoderLayer(config, heads_mask=torch.Tensor(config.heads_mask[i]).to(device))
67
+ for i in range(config.encoder_layers)])
68
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
69
+
70
+ self.gradient_checkpointing = False
71
+ # Initialize weights and apply final processing
72
+ self.post_init()
73
+ self.run_config = config
74
+
75
+ def get_input_embeddings(self):
76
+ return self.embed_tokens
77
+
78
+ def set_input_embeddings(self, value):
79
+ self.embed_tokens = value
80
+
81
+ def forward(
82
+ self,
83
+ input_ids: torch.LongTensor = None,
84
+ attention_mask: Optional[torch.Tensor] = None,
85
+ head_mask: Optional[torch.Tensor] = None,
86
+ inputs_embeds: Optional[torch.FloatTensor] = None,
87
+ output_attentions: Optional[bool] = None,
88
+ output_hidden_states: Optional[bool] = None,
89
+ return_dict: Optional[bool] = None,
90
+ relation_inputs: Optional[torch.Tensor] = None,
91
+ ) -> Union[Tuple, BaseModelOutput]:
92
+ r"""
93
+ Args:
94
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
95
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
96
+ provide it.
97
+
98
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
99
+ [`PreTrainedTokenizer.__call__`] for details.
100
+
101
+ [What are input IDs?](../glossary#input-ids)
102
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
103
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
104
+
105
+ - 1 for tokens that are **not masked**,
106
+ - 0 for tokens that are **masked**.
107
+
108
+ [What are attention masks?](../glossary#attention-mask)
109
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
110
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
111
+
112
+ - 1 indicates the head is **not masked**,
113
+ - 0 indicates the head is **masked**.
114
+
115
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
116
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
117
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
118
+ than the model's internal embedding lookup matrix.
119
+ output_attentions (`bool`, *optional*):
120
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
121
+ returned tensors for more detail.
122
+ output_hidden_states (`bool`, *optional*):
123
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
124
+ for more detail.
125
+ return_dict (`bool`, *optional*):
126
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
127
+ """
128
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
129
+ output_hidden_states = (
130
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
131
+ )
132
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
133
+
134
+ # retrieve input_ids and inputs_embeds
135
+ if input_ids is not None and inputs_embeds is not None:
136
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
137
+ elif input_ids is not None:
138
+ input_shape = input_ids.size()
139
+ input_ids = input_ids.view(-1, input_shape[-1])
140
+ elif inputs_embeds is not None:
141
+ input_shape = inputs_embeds.size()[:-1]
142
+ else:
143
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
144
+
145
+ if inputs_embeds is None:
146
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
147
+
148
+         # Important for datasets in which the order of words does not matter (e.g. commongen)
149
+ if self.run_config.should_embed_positions:
150
+ embed_pos = self.embed_positions(input_shape)
151
+ hidden_states = inputs_embeds + embed_pos
152
+ else:
153
+ hidden_states = inputs_embeds
154
+
155
+ hidden_states = self.layernorm_embedding(hidden_states)
156
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
157
+
158
+ # expand attention_mask
159
+ if attention_mask is not None:
160
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
161
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
162
+
163
+ encoder_states = () if output_hidden_states else None
164
+ all_attentions = () if output_attentions else None
165
+
166
+ # check if head_mask has a correct number of layers specified if desired
167
+ if head_mask is not None:
168
+ if head_mask.size()[0] != (len(self.layers)):
169
+ raise ValueError(
170
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
171
+ )
172
+
173
+ for idx, encoder_layer in enumerate(self.layers):
174
+ if output_hidden_states:
175
+ encoder_states = encoder_states + (hidden_states,)
176
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
177
+ dropout_probability = random.uniform(0, 1)
178
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
179
+ layer_outputs = (None, None)
180
+ else:
181
+ if self.gradient_checkpointing and self.training:
182
+
183
+ def create_custom_forward(module):
184
+ def custom_forward(*inputs):
185
+ return module(*inputs, output_attentions, relation_inputs=relation_inputs)
186
+
187
+ return custom_forward
188
+
189
+ layer_outputs = torch.utils.checkpoint.checkpoint(
190
+ create_custom_forward(encoder_layer),
191
+ hidden_states,
192
+ attention_mask,
193
+ (head_mask[idx] if head_mask is not None else None),
194
+ )
195
+ else:
196
+ layer_outputs = encoder_layer(
197
+ hidden_states,
198
+ attention_mask,
199
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
200
+ output_attentions=output_attentions,
201
+ relation_inputs=relation_inputs,
202
+ )
203
+
204
+ hidden_states = layer_outputs[0]
205
+
206
+ if output_attentions:
207
+ all_attentions = all_attentions + (layer_outputs[1],)
208
+
209
+ if output_hidden_states:
210
+ encoder_states = encoder_states + (hidden_states,)
211
+
212
+ if not return_dict:
213
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
214
+ return BaseModelOutput(
215
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
216
+ )
custom_bart/encoder_layer.py ADDED
@@ -0,0 +1,102 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import Optional, Tuple
7
+
8
+ # Remote modules
9
+ import torch
10
+ from torch import nn
11
+ from transformers import BartConfig
12
+ from transformers.activations import ACT2FN
13
+
14
+ # Local modules
15
+ from .bart_attention import BartCustomAttention
16
+ from .bart_mask_attention import BartCustomMaskAttention
17
+ from .config import BartCustomConfig
18
+
19
+
20
+ class BartCustomEncoderLayer(nn.Module):
21
+ def __init__(self, config: BartCustomConfig, heads_mask: Optional[torch.Tensor]):
22
+ super().__init__()
23
+ self.embed_dim = config.d_model
24
+ is_simple_mask_commonsense = config.is_simple_mask_commonsense
25
+ if not is_simple_mask_commonsense:
26
+ print("Selecting complex relation attention")
27
+ self.self_attn = BartCustomAttention(
28
+ embed_dim=self.embed_dim,
29
+ num_heads=config.encoder_attention_heads,
30
+ dropout=config.attention_dropout,
31
+ num_relation_kinds=config.num_relation_kinds,
32
+ use_same_relation_kv_emb=config.use_same_relation_kv_emb,
33
+ heads_mask=heads_mask,
34
+ )
35
+ else:
36
+ print("Selecting simple (MASK) relation attention")
37
+ self.self_attn = BartCustomMaskAttention(
38
+ embed_dim=self.embed_dim,
39
+ num_heads=config.encoder_attention_heads,
40
+ dropout=config.attention_dropout,
41
+ num_relation_kinds=config.num_relation_kinds,
42
+ heads_mask=heads_mask,
43
+ )
44
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
45
+ self.dropout = config.dropout
46
+ self.activation_fn = ACT2FN[config.activation_function]
47
+ self.activation_dropout = config.activation_dropout
48
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
49
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
50
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
51
+
52
+ def forward(
53
+ self,
54
+ hidden_states: torch.FloatTensor,
55
+ attention_mask: torch.FloatTensor,
56
+ layer_head_mask: torch.FloatTensor,
57
+ output_attentions: Optional[bool] = False,
58
+ relation_inputs: Optional[torch.Tensor] = None,
59
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
60
+ """
61
+ Args:
62
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
63
+ attention_mask (`torch.FloatTensor`): attention mask of size
64
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
65
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
66
+ `(encoder_attention_heads,)`.
67
+ output_attentions (`bool`, *optional*):
68
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
69
+ returned tensors for more detail.
70
+ """
71
+ residual = hidden_states
72
+ hidden_states, attn_weights, _ = self.self_attn(
73
+ hidden_states=hidden_states,
74
+ attention_mask=attention_mask,
75
+ layer_head_mask=layer_head_mask,
76
+ output_attentions=output_attentions,
77
+ relation_inputs=relation_inputs,
78
+ )
79
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
80
+ hidden_states = residual + hidden_states
81
+ hidden_states = self.self_attn_layer_norm(hidden_states)
82
+
83
+ residual = hidden_states
84
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
85
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
86
+ hidden_states = self.fc2(hidden_states)
87
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
88
+ hidden_states = residual + hidden_states
89
+ hidden_states = self.final_layer_norm(hidden_states)
90
+
91
+ if hidden_states.dtype == torch.float16 and (
92
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
93
+ ):
94
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
95
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
96
+
97
+ outputs = (hidden_states,)
98
+
99
+ if output_attentions:
100
+ outputs += (attn_weights,)
101
+
102
+ return outputs
custom_tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .bart_custom_tokenizer_fast import *
custom_tokenizer/bart_custom_tokenizer_fast.py ADDED
@@ -0,0 +1,484 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ from typing import List, Optional, Tuple, Dict
18
+ from collections import deque
19
+
20
+ import torch
21
+ import numpy as np
22
+
23
+ from tokenizers import pre_tokenizers, processors
24
+
25
+ from transformers.tokenization_utils_base import AddedToken, BatchEncoding
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+ from transformers.models.bart.tokenization_bart import BartTokenizer
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
35
+
36
+ # See all BART models at https://huggingface.co/models?filter=bart
37
+ PRETRAINED_VOCAB_FILES_MAP = {
38
+ "vocab_file": {
39
+ "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/vocab.json",
40
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/vocab.json",
41
+ "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/vocab.json",
42
+ "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json",
43
+ "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/vocab.json",
44
+ "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/vocab.json",
45
+ },
46
+ "merges_file": {
47
+ "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/merges.txt",
48
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/merges.txt",
49
+ "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/merges.txt",
50
+ "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt",
51
+ "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/merges.txt",
52
+ "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/merges.txt",
53
+ },
54
+ "tokenizer_file": {
55
+ "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json",
56
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/tokenizer.json",
57
+ "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/tokenizer.json",
58
+ "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json",
59
+ "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/tokenizer.json",
60
+ "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/tokenizer.json",
61
+ },
62
+ }
63
+
64
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "facebook/bart-base": 1024,
66
+ "facebook/bart-large": 1024,
67
+ "facebook/bart-large-mnli": 1024,
68
+ "facebook/bart-large-cnn": 1024,
69
+ "facebook/bart-large-xsum": 1024,
70
+ "yjernite/bart_eli5": 1024,
71
+ }
72
+
73
+
74
+ class BartCustomTokenizerFast(PreTrainedTokenizerFast):
75
+ r"""
76
+ Construct a "fast" BART tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
77
+ using byte-level Byte-Pair-Encoding.
78
+
79
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
80
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
81
+
82
+ ```
83
+ >>> from transformers import BartTokenizerFast
84
+ >>> tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
85
+ >>> tokenizer("Hello world")['input_ids']
86
+ [0, 31414, 232, 2]
87
+ >>> tokenizer(" Hello world")['input_ids']
88
+ [0, 20920, 232, 2]
89
+ ```
90
+
91
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
92
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
93
+
94
+ <Tip>
95
+
96
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
97
+
98
+ </Tip>
99
+
100
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
101
+ refer to this superclass for more information regarding those methods.
102
+
103
+ Args:
104
+ vocab_file (`str`):
105
+ Path to the vocabulary file.
106
+ merges_file (`str`):
107
+ Path to the merges file.
108
+ errors (`str`, *optional*, defaults to `"replace"`):
109
+ Paradigm to follow when decoding bytes to UTF-8. See
110
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
111
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
112
+             The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
113
+
114
+ <Tip>
115
+
116
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
117
+ sequence. The token used is the `cls_token`.
118
+
119
+ </Tip>
120
+
121
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
122
+ The end of sequence token.
123
+
124
+ <Tip>
125
+
126
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
127
+ The token used is the `sep_token`.
128
+
129
+ </Tip>
130
+
131
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
132
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
133
+ sequence classification or for a text and a question for question answering. It is also used as the last
134
+ token of a sequence built with special tokens.
135
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
136
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
137
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
138
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
139
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
140
+ token instead.
141
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
142
+ The token used for padding, for example when batching sequences of different lengths.
143
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
144
+ The token used for masking values. This is the token used when training this model with masked language
145
+ modeling. This is the token which the model will try to predict.
146
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
147
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
148
+             other word. (The BART tokenizer detects the beginning of words by the preceding space.)
149
+ trim_offsets (`bool`, *optional*, defaults to `True`):
150
+ Whether the post processing step should trim offsets to avoid including whitespaces.
151
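+     Relative to the stock BART tokenizer, this class also accepts an `input_commonsense_relations`
+     keyword: a mapping from character spans to related spans and relation names. The returned encoding
+     then carries an `input_commonsense_relations` matrix of relation-kind indices of shape
+     `(batch, seq_len, seq_len)`. A minimal sketch (variable names are illustrative), mirroring how
+     `inference.py` calls it:
+
+     ```
+     >>> tokenizer.set_known_relation_names(relation_names)
+     >>> enc = tokenizer("a bird flies over water", return_offsets_mapping=True,
+     ...                 return_tensors="pt", input_commonsense_relations=relations)
+     >>> enc["input_commonsense_relations"].shape  # (batch, seq_len, seq_len)
+     ```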
+ """
152
+ vocab_files_names = VOCAB_FILES_NAMES
153
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
154
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
155
+ model_input_names = ["input_ids", "attention_mask", "input_commonsense_relations", "commonsense_mask"]
156
+ slow_tokenizer_class = BartTokenizer
157
+
158
+ def __init__(
159
+ self,
160
+ vocab_file=None,
161
+ merges_file=None,
162
+ tokenizer_file=None,
163
+ errors="replace",
164
+ bos_token="<s>",
165
+ eos_token="</s>",
166
+ sep_token="</s>",
167
+ cls_token="<s>",
168
+ unk_token="<unk>",
169
+ pad_token="<pad>",
170
+ mask_token="<mask>",
171
+ add_prefix_space=False,
172
+ trim_offsets=True,
173
+ **kwargs
174
+ ):
175
+ super().__init__(
176
+ vocab_file,
177
+ merges_file,
178
+ tokenizer_file=tokenizer_file,
179
+ errors=errors,
180
+ bos_token=bos_token,
181
+ eos_token=eos_token,
182
+ sep_token=sep_token,
183
+ cls_token=cls_token,
184
+ unk_token=unk_token,
185
+ pad_token=pad_token,
186
+ mask_token=mask_token,
187
+ add_prefix_space=add_prefix_space,
188
+ trim_offsets=trim_offsets,
189
+ **kwargs,
190
+ )
191
+
192
+ self.relational_kind_to_index = None
193
+ self.there_is_difference_between_relations = True
194
+
195
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
196
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
197
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
198
+ pre_tok_state["add_prefix_space"] = add_prefix_space
199
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
200
+
201
+ self.add_prefix_space = add_prefix_space
202
+
203
+ # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
204
+ tokenizer_component = "post_processor"
205
+ tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
206
+ if tokenizer_component_instance:
207
+ state = json.loads(tokenizer_component_instance.__getstate__())
208
+
209
+             # The lists 'sep' and 'cls' must be cast as tuples for the object `post_processor_class`
210
+ if "sep" in state:
211
+ state["sep"] = tuple(state["sep"])
212
+ if "cls" in state:
213
+ state["cls"] = tuple(state["cls"])
214
+
215
+ changes_to_apply = False
216
+
217
+ if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
218
+ state["add_prefix_space"] = add_prefix_space
219
+ changes_to_apply = True
220
+
221
+ if state.get("trim_offsets", trim_offsets) != trim_offsets:
222
+ state["trim_offsets"] = trim_offsets
223
+ changes_to_apply = True
224
+
225
+ if changes_to_apply:
226
+ component_class = getattr(processors, state.pop("type"))
227
+ new_value = component_class(**state)
228
+ setattr(self.backend_tokenizer, tokenizer_component, new_value)
229
+
230
+ def __call__(self, *args, **kwargs):
231
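+         # Pull the raw relation mapping out of kwargs before the base tokenizer runs, then attach
+         # the token-level relation matrix to the returned encoding.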
+ input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
232
+ if 'input_commonsense_relations' in kwargs:
233
+ kwargs.pop('input_commonsense_relations')
234
+ out = super(BartCustomTokenizerFast, self).__call__(*args, **kwargs)
235
+ if out.get('input_commonsense_relations') is None:
236
+ out = self._post_process_tokenization(input_commonsense_relations, out)
237
+ return out
238
+
239
+ def set_known_relation_names(self, known_relations_names: List[str]):
240
+ self.relational_kind_to_index = {t: i + 1 for i, t in enumerate(known_relations_names)}
241
+
242
+ def set_operation_mode(self, there_is_difference_between_relations=True):
243
+ self.there_is_difference_between_relations = there_is_difference_between_relations
244
+
245
+ @property
246
+ def mask_token(self) -> str:
247
+ """
248
+ `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
249
+ having been set.
250
+
251
+ BART tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
252
+ comprise the space before the *<mask>*.
253
+ """
254
+ if self._mask_token is None and self.verbose:
255
+ logger.error("Using mask_token, but it is not set yet.")
256
+ return None
257
+ return str(self._mask_token)
258
+
259
+ @mask_token.setter
260
+ def mask_token(self, value):
261
+ """
262
+ Overriding the default behavior of the mask token to have it eat the space before it.
263
+
264
+ This is needed to preserve backward compatibility with all the previously used models based on Bart.
265
+ """
266
+ # Mask token behave like a normal word, i.e. include the space before it
267
+ # So we set lstrip to True
268
+ value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
269
+ self._mask_token = value
270
+
271
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
272
+ is_split_into_words = kwargs.get("is_split_into_words", False)
273
+
274
+ if is_split_into_words and not self.add_prefix_space:
275
+ raise ValueError(
276
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
277
+ "to use it with pretokenized inputs."
278
+ )
279
+ input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
280
+ if 'input_commonsense_relations' in kwargs:
281
+ kwargs.pop('input_commonsense_relations')
282
+ out = super()._batch_encode_plus(*args, **kwargs)
283
+ if out.get('input_commonsense_relations') is None:
284
+ out = self._post_process_tokenization(input_commonsense_relations, out)
285
+ return out
286
+
287
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
288
+ is_split_into_words = kwargs.get("is_split_into_words", False)
289
+
290
+ if is_split_into_words and not self.add_prefix_space:
291
+ raise ValueError(
292
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
293
+ "to use it with pretokenized inputs."
294
+ )
295
+
296
+ input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
297
+ if 'input_commonsense_relations' in kwargs:
298
+ kwargs.pop('input_commonsense_relations')
299
+ out = super()._encode_plus(*args, **kwargs)
300
+ if out.get('input_commonsense_relations') is None:
301
+ out = self._post_process_tokenization(input_commonsense_relations, out)
302
+ return out
303
+
304
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
305
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
306
+ return tuple(files)
307
+
308
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
309
+ output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
310
+ if token_ids_1 is None:
311
+ return output
312
+
313
+ return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
314
+
315
+ def create_token_type_ids_from_sequences(
316
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
317
+ ) -> List[int]:
318
+ """
319
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not
320
+ make use of token type ids, therefore a list of zeros is returned.
321
+
322
+ Args:
323
+ token_ids_0 (`List[int]`):
324
+ List of IDs.
325
+ token_ids_1 (`List[int]`, *optional*):
326
+ Optional second list of IDs for sequence pairs.
327
+
328
+ Returns:
329
+ `List[int]`: List of zeros.
330
+ """
331
+ sep = [self.sep_token_id]
332
+ cls = [self.cls_token_id]
333
+
334
+ if token_ids_1 is None:
335
+ return len(cls + token_ids_0 + sep) * [0]
336
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
337
+
338
+ def _post_process_tokenization(self, input_commonsense_relations, out: BatchEncoding) -> BatchEncoding:
339
+ new_input_relations = self.get_new_input_relation_kinds(
340
+ tokenizer_outputs=out, input_relations=input_commonsense_relations
341
+ )
342
+ #if new_input_relations is not None:
343
+ # print('sum:', new_input_relations.sum())
344
+ out['input_commonsense_relations'] = new_input_relations
345
+ return out
346
+
347
+ def find_new_tokens_span_for_multiword(self, pair, aux_dict):
348
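+     # If this word's character span is contained in a larger multi-word span that has outgoing
+     # relations, snap it to that span so its tokens inherit the multi-word relations.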
+ old_start, old_end = pair
349
+ #print('pair:', pair)
350
+ keys = list(aux_dict.keys())
351
+ #print('aux_dict:', aux_dict)
352
+ new_start, new_end = old_start, old_end
353
+ for (start, end) in keys:
354
+ #print('-----> (start, end)', (start, end))
355
+ #print('old_start, old_end:', old_start, old_end)
356
+ #print('start, end:', start, end)
357
+ if old_start >= start and old_end <= end:
358
+ new_start, new_end = start, end
359
+ break
360
+ return new_start, new_end
361
+
362
+ def find_new_tokens_incoming_span_for_multiword(self, pair, aux_dict):
363
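+     # Same idea, but matching against spans that appear as relation targets (the inner dict keys)
+     # rather than as relation sources.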
+ old_start, old_end = pair
364
+ incoming_rels = list([coord for v in aux_dict.values() for coord, relation in v.items()])
365
+ new_start, new_end = old_start, old_end
366
+ for (start, end) in incoming_rels:
367
+ #print('-----> (start, end)', (start, end))
368
+ #print('old_start, old_end:', old_start, old_end)
369
+ #print('start, end:', start, end)
370
+ if old_start >= start and old_end <= end:
371
+ new_start, new_end = start, end
372
+ break
373
+ return new_start, new_end
374
+
375
+ def get_new_input_relation_kinds(
376
+ self,
377
+ tokenizer_outputs: BatchEncoding,
378
+ input_relations: Optional[List[Dict[Tuple[int, int], Dict[Tuple[int, int], str]]]] = None
379
+ ) -> torch.Tensor:
380
+
381
+ n_examples = len(tokenizer_outputs['input_ids'])
382
+ n_tokens = len(tokenizer_outputs['input_ids'][0])
383
+ aux_input_relation_kinds = np.zeros(
384
+ (n_examples, n_tokens, n_tokens),
385
+ dtype=np.int64
386
+ )
387
+ if not input_relations and input_relations is not None:
388
+ return torch.from_numpy(aux_input_relation_kinds)
389
+ elif not input_relations:
390
+             return None  # torch.tensor([])
391
+ assert 'offset_mapping' in tokenizer_outputs, "Run tokenizer with return_offsets_mapping=True"
392
+ # print('aux_input_relation_kinds.shape', tokenizer_outputs['input_ids'].shape)
393
+ #print('input_relations:', input_relations)
394
+ if input_relations is not None:
395
+ # if input_relations is dirty, clean it
396
+ if isinstance(input_relations, dict):
397
+ input_relations = [input_relations]
398
+ mappings = tokenizer_outputs['offset_mapping']
399
+ assert len(mappings) == len(input_relations)
400
+ # print("to normal:", self.tokenizer.convert_ids_to_tokens(tokenizer_outputs['input_ids'][0]))
401
+ # print('words: ', words)
402
+ # print('x: ', mappings)
403
+ mappings = [[tuple(x) for x in mappings[idx].cpu().detach().tolist()] for idx in range(n_examples)]
404
+ # print(mappings)
405
+ examples_mappings = []
406
+ max_idx = 0
407
+ for idx, mapping in enumerate(mappings):
408
+ #print(idx, mapping)
409
+ words = tokenizer_outputs.word_ids(batch_index=idx)
410
+ tokens_to_words = deque(words)
411
+ token_idx_2_word_span = {}
412
+ for token_idx, (_char_i, _char_j) in enumerate(mapping):
413
+ word_idx_of_token = tokens_to_words.popleft()
414
+ if word_idx_of_token is None:
415
+ continue
416
+ token_span = tokenizer_outputs.word_to_chars(word_idx_of_token)
417
+                 token_idx_2_word_span[token_idx] = (token_span.start, token_span.end)  # do I need to drop the -1 here? (it was -1 before)
418
+ max_idx = max(token_idx, max_idx)
419
+ #print('token_idx_2_word_span:', token_idx_2_word_span)
420
+ ##### Multiword ######
421
+ token_idx_2_word_span_multiword = {}
422
+ d = input_relations[idx]
423
+ for k, v in token_idx_2_word_span.items():
424
+ #print('k,v', k, v)
425
+ new_start, new_end = self.find_new_tokens_span_for_multiword(v, d)
426
+ token_idx_2_word_span_multiword[k] = (new_start, new_end)
427
+ #print('tmp:', token_idx_2_word_span_multiword)
428
+ #print('[before]token_idx_2_word_span_multiword[k]:', token_idx_2_word_span_multiword[k])
429
+ if v[0]==new_start and v[1]==new_end:
430
+ new_start, new_end = self.find_new_tokens_incoming_span_for_multiword(v, d)
431
+ token_idx_2_word_span_multiword[k] = (new_start, new_end)
432
+ #print('tmp2:', token_idx_2_word_span_multiword)
433
+ #print('[after]token_idx_2_word_span_multiword[k]:', token_idx_2_word_span_multiword[k])
434
+ ##### ######
435
+ #print('token_idx_2_word_span_multiword:', token_idx_2_word_span_multiword)
436
+ examples_mappings.append(token_idx_2_word_span_multiword)
437
+ # print('len:', len(examples_mappings))
438
+ # print('max_idx: ', max_idx)
439
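+             # Fill the (token_i, token_j) matrix: whenever the word spans of two tokens are related,
+             # store the relation-kind index (or just 1 when only the existence of a relation matters).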
+ for i_example in range(n_examples):
440
+ token_idx_2_word_span = examples_mappings[i_example]
441
+ # print('token_idx_2_word_span: ', token_idx_2_word_span)
442
+ possible_relations = input_relations[i_example]
443
+ # print('possible_relations: ', possible_relations)
444
+ for token_i_idx in range(max_idx + 1):
445
+ for token_j_idx in range(max_idx + 1):
446
+ fixed_word_range = token_idx_2_word_span.get(token_i_idx, None)
447
+ other_word_range = token_idx_2_word_span.get(token_j_idx, None)
448
+ if not fixed_word_range or not other_word_range:
449
+ continue
450
+ #print(fixed_word_range, ' | ', other_word_range)
451
+ relations = possible_relations.get(fixed_word_range, None)
452
+ if not relations:
453
+ continue
454
+ #print('possible_relations:' , possible_relations)
455
+ relation_kind = relations.get(other_word_range, None)
456
+ if not relation_kind:
457
+ continue
458
+ #print('relation_kind:',relation_kind)
459
+ if self.there_is_difference_between_relations:
460
+ aux_input_relation_kinds[i_example, token_i_idx, token_j_idx] = self.relational_kind_to_index[relation_kind]
461
+ else:
462
+ # basic relation | only matters that relation exists between tokens
463
+ aux_input_relation_kinds[i_example, token_i_idx, token_j_idx] = 1
464
+ aux_input_relation_kinds = torch.from_numpy(aux_input_relation_kinds)
465
+ return aux_input_relation_kinds
466
+
467
+ def create_commonsense_mask(self, tokenizer_outputs, commonsense_matrix, num_heads=16, specific_head=0):
468
+ bsz = len(tokenizer_outputs['input_ids'])
469
+ n_tokens = len(tokenizer_outputs['input_ids'][0])
470
+ commonsense_mask = np.zeros(
471
+ ((bsz, num_heads, n_tokens, n_tokens)),
472
+ dtype=np.int64
473
+ )
474
+ if commonsense_matrix is None:
475
+ commonsense_matrix = np.zeros(
476
+ ((bsz, n_tokens, n_tokens)),
477
+ dtype=np.int64
478
+ )
479
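+         # Place the commonsense matrix on a single attention head (specific_head); the remaining
+         # heads keep an all-zero mask.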
+ commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
480
+ # commonsense_matrix.shape: (bsz, src_len, tgt_len)
481
+ #print('commonsense_matrix:', commonsense_matrix)
482
+ commonsense_mask[specific_head] = commonsense_matrix
483
+ commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
484
+ return commonsense_mask
data/__init__.py ADDED
File without changes
data/relation_utils.py ADDED
@@ -0,0 +1,53 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ from collections import deque
8
+ from ast import literal_eval
9
+
10
+ # Remote modules
11
+ import torch
12
+
13
+ # Local modules
14
+
15
+ #############################
16
+ # Constants
17
+ #############################
18
+
19
+ ##########################################################
20
+ # Helper functions for Relations in dict format
21
+ ##########################################################
22
+
23
+ def clean_relations(word_relations):
24
+ new_relations = deque()
25
+ for r in word_relations:
26
+ rel = {}
27
+ for r_key, r_value in r.items():
28
+ normal_k = literal_eval(r_key)
29
+ rel_d = {}
30
+ for r_d_key, r_d_value in r_value.items():
31
+ normal_d_k = literal_eval(r_d_key)
32
+ rel_d[normal_d_k] = r_d_value
33
+ rel[normal_k] = rel_d
34
+ new_relations.append(rel)
35
+ list_new_relations = list(new_relations)
36
+ return list_new_relations
37
+
38
+ ##########################################################
39
+ # Helper functions for Relations in Matrix format
40
+ ##########################################################
41
+
42
+ def relation_binary_2d_to_1d(relations_binary_mask, dim=1):
43
+ relations_binary_mask = relations_binary_mask.sum(dim=dim)
44
+ relations_binary_mask[relations_binary_mask > 1] = 1
45
+ return relations_binary_mask
46
+
47
+ def tokens_with_relations(relations_binary_mask):
48
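+     # A token "has a relation" if it appears as either the source (row) or the target (column)
+     # of any entry in the relation matrix.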
+ relations_binary_mask_dim1 = relations_binary_mask.sum(dim=0)
49
+ relations_binary_mask_dim2 = relations_binary_mask.sum(dim=1)
50
+ tokens_with_rels = relations_binary_mask_dim1 + relations_binary_mask_dim2
51
+ tokens_with_rels[tokens_with_rels > 1] = 1
52
+ mask_rels = torch.tensor(tokens_with_rels, dtype=torch.bool)
53
+ return mask_rels
inference.py ADDED
@@ -0,0 +1,349 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import List
7
+
8
+ # Remote modules
9
+ import numpy as np
10
+ import torch
11
+
12
+ # Local modules
13
+ from kgs_binding.relation_mapper_builder import RelationsMapperBuilder
14
+ from kgs_binding.kg_qa_binding_utils import load_kg_handler
15
+ from data.relation_utils import clean_relations
16
+ from model_utils import create_layers_head_mask
17
+
18
+ from transformers import (
19
+ BartForConditionalGeneration,
20
+ BartTokenizer,
21
+ BartConfig,
22
+ DisjunctiveConstraint,
23
+ )
24
+
25
+ from utils import get_jump_chunks
26
+
27
+ #############################
28
+ # Constants
29
+ #############################
30
+
31
+ #############################
32
+ # Custom model, tokenizer and KG bindings
33
+ #############################
34
+ from custom_tokenizer import BartCustomTokenizerFast
35
+ from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration
36
+ from utils import get_device, KGType, Model_Type
37
+
38
+ from kgs_binding.kg_base_wrapper import KGBaseHandler
39
+ from kgs_binding.swow_handler import SwowHandler
40
+ from kgs_binding.conceptnet_handler import ConceptNetHandler
41
+
42
+ class Inference:
43
+ def __init__(self, model_path:str, max_length=32):
44
+ self.device = get_device()
45
+ self.tokenizer = self.prepare_tokenizer()
46
+ self.model = self.prepare_model(model_path)
47
+ self.max_length = max_length
48
+
49
+ def prepare_tokenizer(self):
50
+ tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
51
+ return tokenizer
52
+
53
+ def prepare_model(self, model_path):
54
+ config = BartConfig.from_pretrained(model_path)
55
+ model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
56
+ model.eval()
57
+ return model
58
+
59
+ def pre_process_context(self, context):
60
+ context = context.lower()
61
+ context_tokenized = self.tokenizer(context, padding='max_length',
62
+ truncation='longest_first', max_length=self.max_length,
63
+ return_tensors="pt",
64
+ )
65
+ return context_tokenized
66
+
67
+ def generate_based_on_context(self, context):
68
+ model_input = self.pre_process_context(context)
69
+ generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
70
+ attention_mask=model_input["attention_mask"].to(self.device),
71
+ min_length=1,
72
+ max_length=self.max_length,
73
+ do_sample=True,
74
+ early_stopping=True,
75
+ num_beams=4,
76
+ temperature=1.0,
77
+ top_k=None,
78
+ top_p=None,
79
+ # eos_token_id=tokenizer.eos_token_id,
80
+ no_repeat_ngram_size=2,
81
+ num_return_sequences=1,
82
+ return_dict_in_generate=True,
83
+ output_attentions=True,
84
+ output_scores=True)
85
+ # print(f'Scores: {generated_answers_encoded}')
86
+ response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
87
+ clean_up_tokenization_spaces=True)
88
+ encoder_attentions = generated_answers_encoded['encoder_attentions']
89
+ return response, encoder_attentions, model_input
90
+
91
+ def prepare_context_for_visualization(self, context):
92
+ examples = []
93
+ response, encoder_outputs, model_input = self.generate_based_on_context(context)
94
+ encoder_outputs = torch.stack(encoder_outputs)
95
+ n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
96
+ print(encoder_outputs.size())
97
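+         # Note: this `view` only swaps the layer/batch axes correctly because the demo generates
+         # one example at a time (batch_size == 1); larger batches would need a permute instead.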
+ encoder_attentions = encoder_outputs.view(batch_size, n_layers, n_heads, src, tgt)
98
+ for i, ex in enumerate(encoder_attentions):
99
+ d = {}
100
+ indices = model_input['input_ids'][i].detach().cpu()
101
+ all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
102
+             useful_indices = indices != self.tokenizer.pad_token_id
103
+             all_tokens = np.array(all_tokens)[useful_indices]
104
+ all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
105
+ d['words'] = all_tokens
106
+ d['attentions'] = ex.detach().cpu().numpy()
107
+ examples.append(d)
108
+ print(d['words'])
109
+ return response, examples
110
+
111
+ class RelationsInference:
112
+ def __init__(self, model_path:str, kg_type: KGType, model_type:Model_Type, max_length=32):
113
+ self.device = get_device()
114
+ kg_handler: KGBaseHandler = load_kg_handler(kg_type)
115
+ self.kg_handler = kg_handler
116
+ relation_names = kg_handler.get_relation_types()
117
+ self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
118
+ self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
119
+ self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
120
+ self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
121
+ self.max_length = max_length
122
+
123
+ def prepare_tokenizer(self, relation_names: List[str], model_type:Model_Type):
124
+ tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
125
+ tokenizer.set_known_relation_names(relation_names)
126
+ tokenizer.set_operation_mode(there_is_difference_between_relations=model_type.there_is_difference_between_relations())
127
+ return tokenizer
128
+
129
+ def prepare_model(self, relation_names: List[str], model_path, model_type:Model_Type):
130
+ config = BartCustomConfig.from_pretrained(model_path, revision='master')
131
+ print('config.heads_mask:', config.heads_mask)
132
+ if config.num_relation_kinds is None:
133
+ config.num_relation_kinds = len(relation_names)
134
+ if config.is_simple_mask_commonsense is None:
135
+ config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
136
+ if config.heads_mask is None:
137
+ config.heads_mask = create_layers_head_mask(config)#, heads_mask_type, specific_heads)
138
+ model = BartCustomForConditionalGeneration.from_pretrained(model_path, config=config, revision='master').to(self.device)
139
+ model.eval()
140
+ return model, config
141
+
142
+ def pre_process_context(self, context):
143
+ context = context.lower()
144
+ # process context in search for relations
145
+ commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(context=[context], clear_common_wds=True)
146
+ # clean relation
147
+ commonsense_relation = clean_relations(commonsense_relations)[0]
148
+ # convert this relations to matrices
149
+ print(commonsense_relation)
150
+ context_tokenized = self.tokenizer(context, padding='max_length',
151
+ truncation='longest_first', max_length=self.max_length,
152
+ return_tensors="pt", return_offsets_mapping=True,
153
+ input_commonsense_relations=commonsense_relation,
154
+ )
155
+ return context_tokenized
156
+
157
+ def get_relations_information(self, phrase_generated):
158
+ all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context([phrase_generated], clear_common_wds=True)[0]
159
+ words = phrase_generated.strip().split(' ') # all words
160
+ concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(phrase_generated, clear_common_wds=True)
161
+ concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))
162
+ #print('without_relations:', concepts_with_no_relations)
163
+ print("====== RELATIONS SUMMARY ======")
164
+ print('phrase_generated:', phrase_generated)
165
+ print('words:', words)
166
+ print('all_concepts:', all_concepts)
167
+ print('concepts_with_relations:', concepts_with_relations)
168
+ print('without_relations:', concepts_with_no_relations)
169
+ print("\n== STATS:")
170
+ print('n_words:', len(words))
171
+ print('n_concepts:', len(all_concepts))
172
+ print('n_concepts_with_relations:', len(concepts_with_relations))
173
+ print('n_c_without_relations:', len(concepts_with_no_relations))
174
+ print("====== ================= ======")
175
+ return words, all_concepts, concepts_with_relations, concepts_with_no_relations
176
+
177
+ def remove_subsets(self, l):
178
+ l2 = l[:]
179
+ for m in l:
180
+ for n in l:
181
+ if set(m).issubset(set(n)) and m != n:
182
+ l2.remove(m)
183
+ break
184
+ return l2
185
+
186
+ def generate_based_on_context(self, context, use_kg=False):
187
+ model_input = self.pre_process_context(context)
188
+ #print(model_input)
189
+ gen_kwargs = {}
190
+ if "input_commonsense_relations" in model_input:
191
+ #print(model_input['input_commonsense_relations'].sum())
192
+ gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)
193
+
194
+ constraints = None
195
+ if use_kg:
196
+ constraints = []
197
+ concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True)
198
+ useful_concepts = [self.relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context]
199
+ if not useful_concepts:
200
+ useful_concepts = [self.kg_handler.get_related_concepts(concept) for concept in concepts_from_context]
201
+             useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts]  # leading spaces are handled via add_prefix_space when tokenizing below
202
+ #useful_concepts = [[phrase for phrase in concepts if len(phrase.split(' ')) == 1] for concepts in useful_concepts]
203
+ #useful_concepts = list(itertools.chain.from_iterable(useful_concepts))
204
+ #print('useful_concepts:', useful_concepts[:5])
205
+ if concepts_from_context:
206
+ for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
207
+ print('neighbour:', neighbour_concepts[:20])
208
+ #flexible_words = self.most_similar_words(context_concept, neighbour_concepts) # limit the upperbound
209
+ #flexible_words = [word for word in flexible_words if word not in context_concept] # remove input concepts
210
+ flexible_words = [word for word in neighbour_concepts if word not in context_concept] # remove input concepts
211
+ flexible_words_ids: List[List[int]] = self.simple_tokenizer(flexible_words, add_prefix_space=True,add_special_tokens=False).input_ids
212
+ flexible_words_ids = self.remove_subsets(flexible_words_ids)
213
+ #add_prefix_space=True
214
+ #flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets
215
+ flexible_words_ids = flexible_words_ids[:10]
216
+ print('flexible_words_ids:', flexible_words_ids[:3])
217
+ constraint = DisjunctiveConstraint(flexible_words_ids)
218
+ constraints.append(constraint)
219
+ else:
220
+ constraints = None
221
+
222
+ generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
223
+ attention_mask=model_input["attention_mask"].to(self.device),
224
+ constraints=constraints,
225
+ min_length=1,
226
+ max_length=self.max_length,
227
+ do_sample=False,
228
+ early_stopping=True,
229
+ num_beams=8,
230
+ temperature=1.0,
231
+ top_k=None,
232
+ top_p=None,
233
+ # eos_token_id=tokenizer.eos_token_id,
234
+ no_repeat_ngram_size=2,
235
+ num_return_sequences=1,
236
+ return_dict_in_generate=True,
237
+ output_attentions=True,
238
+ output_scores=True,
239
+ **gen_kwargs,
240
+ )
241
+ # print(f'Scores: {generated_answers_encoded}')
242
+ response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
243
+ clean_up_tokenization_spaces=True)
244
+ encoder_attentions = generated_answers_encoded['encoder_attentions']
245
+ return response, encoder_attentions, model_input
246
+
247
+ def get_related_concepts_list(self, knowledge, list_concepts):
248
+ other_concepts = []
249
+ for concept in list_concepts:
250
+ other_near_concepts = knowledge.get_related_concepts(concept)
251
+ other_concepts.extend(other_near_concepts)
252
+ return other_concepts
253
+
254
+
255
+ def generate_contrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
256
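+         # For every context, turn KG neighbours of its concepts into DisjunctiveConstraint groups so
+         # that beam search is forced to include at least one concept from each group in the output.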
+ model_inputs = [self.pre_process_context(context) for context in contexts]
257
+ constraints = None
258
+ if use_kg:
259
+ constraints = []
260
+ concepts_from_contexts = [self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True) for context in contexts]
261
+ neighbours_contexts = []#[self.get_related_concepts_list(self.relation_mapper_builder.knowledge, context) for context in concepts_from_contexts]
262
+ if not neighbours_contexts:
263
+ neighbours_contexts = [self.get_related_concepts_list(self.kg_handler, context) for context in concepts_from_contexts]
264
+ all_constraints = []
265
+ for context_neighbours in neighbours_contexts:
266
+ # context_neighbours is a collection of concepts
267
+ # lets create sub collections of concepts
268
+ context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
269
+                 n_size_chunks = len(context_neighbours) // max_concepts
270
+                 n_size_chunks = n_size_chunks if n_size_chunks > 0 else 1
271
+                 sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chunks))
272
+ constraints = []
273
+ for sub_concepts in sub_concepts_collection[:max_concepts]:
274
+ flexible_words_ids: List[List[int]] = self.tokenizer(sub_concepts,
275
+ add_special_tokens=False).input_ids # add_prefix_space=True,
276
+ # flexible_words_ids = self.remove_subsets(flexible_words_ids)
277
+ flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
278
+ disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
279
+ if not any(disjunctive_set):
280
+ continue
281
+ constraint = DisjunctiveConstraint(disjunctive_set)
282
+ constraints.append(constraint)
283
+ if not any(constraints):
284
+ constraints = None
285
+ all_constraints.append(constraints)
286
+ else:
287
+ all_constraints = None
288
+ if not all_constraints:
289
+ all_constraints = None
290
+
291
+ generated_answers_encoded = []
292
+ encoder_attentions_list = []
293
+         for i, constraints in enumerate(all_constraints):
294
+             #print('constraints.token_ids:', [x.token_ids for x in constraints])
295
+ gen_kwargs = {}
296
+ inputs = model_inputs[i]
297
+ if "input_commonsense_relations" in inputs:
298
+ # print(model_input['input_commonsense_relations'].sum())
299
+ gen_kwargs["relation_inputs"] = inputs.get("input_commonsense_relations").to(self.device)
300
+ #print('model_kwargs.get("attention_mask"):', model_kwargs.get("attention_mask"))
301
+ gen = self.model.generate(input_ids=inputs["input_ids"].to(self.device),
302
+ attention_mask=inputs["attention_mask"].to(self.device),
303
+ constraints=constraints,
304
+ min_length=1,
305
+ max_length=self.max_length,
306
+ do_sample=False,
307
+ early_stopping=True,
308
+ num_beams=8,
309
+ temperature=1.0,
310
+ top_k=None,
311
+ top_p=None,
312
+ # eos_token_id=tokenizer.eos_token_id,
313
+ no_repeat_ngram_size=2,
314
+ num_return_sequences=1,
315
+ return_dict_in_generate=True,
316
+ output_attentions=True,
317
+ output_scores=True,
318
+ **gen_kwargs,
319
+ )
320
+ # print('[gen]:', gen)
321
+ # print(tokenizer.batch_decode(gen))
322
+ generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
323
+ encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())
324
+ # print(f'Scores: {generated_answers_encoded}')
325
+ text_results = self.tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
326
+ clean_up_tokenization_spaces=True)
327
+ return text_results, encoder_attentions_list, model_inputs
328
+
329
+ def prepare_context_for_visualization(self, context):
330
+ examples, relations = [], []
331
+ response, encoder_outputs, model_input = self.generate_based_on_context(context)
332
+ input_commonsense_relations = model_input.get("input_commonsense_relations")
333
+ encoder_outputs = torch.stack(encoder_outputs)
334
+ n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
335
+ print(encoder_outputs.size())
336
+ encoder_attentions = encoder_outputs.view(batch_size, n_layers, n_heads, src, tgt)
337
+ for i, ex in enumerate(encoder_attentions):
338
+ d = {}
339
+ indices = model_input['input_ids'][i].detach().cpu()
340
+ all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
341
+             useful_indices = indices != self.tokenizer.pad_token_id
342
+             all_tokens = np.array(all_tokens)[useful_indices]
343
+ all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
344
+ d['words'] = all_tokens
345
+ d['attentions'] = ex.detach().cpu().numpy()
346
+ examples.append(d)
347
+ relations.append(input_commonsense_relations[i])
348
+ print(d['words'])
349
+ return response, examples, relations
kgs_binding/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .kg_base_wrapper import KGBaseHandler
2
+ from .relation_mapper_builder import RelationsMapperBuilder
3
+ from . import *
kgs_binding/conceptnet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import *
kgs_binding/conceptnet/conceptnet_english_noun_2_noun_relations.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b82686e2cb4a32a827d3c0a0c63a91d5d102fe5813fe898cabd9a117aa7374c
3
+ size 186932142
kgs_binding/conceptnet/conceptnet_english_nouns.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b90ab07ca7445623bcd90489367c4016ca3b4ed743816a99b730f22e13ac339c
3
+ size 140804377
kgs_binding/conceptnet/conceptnet_english_nouns_simple.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad6e76d470432dc3c9c7c0ebbf340eda3c4f69008c9f8a27df97f8e005e5db02
3
+ size 22419586
kgs_binding/conceptnet_handler.py ADDED
@@ -0,0 +1,61 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import Tuple, Optional, List
7
+ # Remote modules
8
+
9
+ # Local modules
10
+ from .kg_base_wrapper import KGBaseHandler
11
+ from utils import read_json_file_2_dict
12
+
13
+ #############################
14
+ # Constants
15
+ #############################
16
+
17
+ #############################
18
+ # Handler
19
+ #############################
20
+
21
+ class ConceptNetHandler(KGBaseHandler):
22
+ def __init__(self, database=""):
23
+ super(ConceptNetHandler, self).__init__()
24
+ _store_dir = 'kgs_binding/conceptnet'
25
+ self.conceptnet_concepts = read_json_file_2_dict('conceptnet_english_nouns_simple.json', store_dir=_store_dir)
26
+ self.relations_concepts = read_json_file_2_dict('conceptnet_english_noun_2_noun_relations.json', store_dir=_store_dir)
27
+ self.concept_2_concepts = read_json_file_2_dict('conceptnet_english_nouns.json', store_dir=_store_dir)
28
+
29
+ def get_relation_types(self) -> List[str]:
30
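+     # Fixed list of ConceptNet relation labels; their order defines the relation-kind indices
+     # used by the custom tokenizer (index 0 is reserved for "no relation").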
+ updated_relation_names = ['not_has_property', 'not_desires', 'external_u_r_l', 'created_by',
31
+ 'not_capable_of', 'antonym', 'has_first_subevent', 'located_near',
32
+ 'desires', 'has_prerequisite', 'has_last_subevent', 'synonym', 'is_a',
33
+ 'manner_of', 'has_a', 'motivated_by_goal', 'instance_of',
34
+ 'etymologically_derived_from', 'capable_of', 'for', 'at_location',
35
+ 'has_subevent', 'causes', 'has_context', 'symbol_of', 'derived_from',
36
+ 'made_of', 'causes_desire', 'has_property', 'similar_to', 'used_for', 'by',
37
+ 'entails', 'form_of', 'receives_action', 'distinct_from', 'related_to',
38
+ 'part_of', 'defined_as', 'etymologically_related_to']
39
+ return updated_relation_names
40
+
41
+ def exists_relation_between(self, concept, other_concept) -> bool:
42
+ left_2_right, right_2_left = self.relation_between(concept, other_concept)
43
+ return left_2_right is not None or right_2_left is not None
44
+
45
+ def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
46
+ left_2_right_txt = f'{concept}|{other_concept}'
47
+ right_2_left_txt = f'{other_concept}|{concept}'
48
+ left_2_right_relations = self.relations_concepts.get(left_2_right_txt, None)
49
+ right_2_left_relations = self.relations_concepts.get(right_2_left_txt, None)
50
+ left_2_right_relation, right_2_left_relation = None, None
51
+ if left_2_right_relations:
52
+ left_2_right_relation = self.ignore_less_relevant_connection(left_2_right_relations)
53
+ if right_2_left_relations:
54
+ right_2_left_relation = self.ignore_less_relevant_connection(right_2_left_relations)
55
+ return left_2_right_relation, right_2_left_relation
56
+
57
+ def get_related_concepts(self, concept) -> Optional[List[str]]:
58
+ return self.concept_2_concepts.get(concept, [])
59
+
60
+ def does_concept_exist(self, concept) -> bool:
61
+ return concept in self.conceptnet_concepts
kgs_binding/english_stopwords.txt ADDED
@@ -0,0 +1,1126 @@
1
+ 'll
2
+ 'tis
3
+ 'twas
4
+ 've
5
+ a
6
+ a's
7
+ able
8
+ ableabout
9
+ about
10
+ above
11
+ abroad
12
+ abst
13
+ accordance
14
+ according
15
+ accordingly
16
+ across
17
+ act
18
+ actually
19
+ ad
20
+ added
21
+ adj
22
+ adopted
23
+ ae
24
+ af
25
+ affected
26
+ affecting
27
+ ag
28
+ ah
29
+ ai
30
+ ain't
31
+ aint
32
+ al
33
+ all
34
+ almost
35
+ along
36
+ alongside
37
+ also
38
+ although
39
+ am
40
+ amid
41
+ amidst
42
+ among
43
+ amongst
44
+ amoungst
45
+ an
46
+ and
47
+ another
48
+ any
49
+ anybody
50
+ anyhow
51
+ anymore
52
+ anyone
53
+ anything
54
+ anyway
55
+ anyways
56
+ anywhere
57
+ ao
58
+ apart
59
+ apparently
60
+ appear
61
+ appreciate
62
+ appropriate
63
+ approximately
64
+ aq
65
+ ar
66
+ are
67
+ aren
68
+ aren't
69
+ arent
70
+ arise
71
+ around
72
+ arpa
73
+ as
74
+ aside
75
+ ask
76
+ asked
77
+ asking
78
+ asks
79
+ associated
80
+ at
81
+ au
82
+ auth
83
+ aw
84
+ away
85
+ awfully
86
+ az
87
+ b
88
+ ba
89
+ back
90
+ backed
91
+ backing
92
+ backs
93
+ bb
94
+ bd
95
+ be
96
+ became
97
+ because
98
+ become
99
+ becomes
100
+ becoming
101
+ been
102
+ beforehand
103
+ began
104
+ beginning
105
+ beginnings
106
+ begins
107
+ behind
108
+ being
109
+ beings
110
+ believe
111
+ below
112
+ beside
113
+ besides
114
+ best
115
+ better
116
+ between
117
+ beyond
118
+ bf
119
+ bg
120
+ bh
121
+ bi
122
+ biol
123
+ bj
124
+ bm
125
+ bn
126
+ bo
127
+ both
128
+ br
129
+ brief
130
+ briefly
131
+ bs
132
+ bt
133
+ but
134
+ buy
135
+ bv
136
+ bw
137
+ by
138
+ bz
139
+ c
140
+ c'mon
141
+ c's
142
+ ca
143
+ call
144
+ came
145
+ can
146
+ can't
147
+ cannot
148
+ cant
149
+ caption
150
+ case
151
+ cases
152
+ cause
153
+ causes
154
+ cc
155
+ cd
156
+ certain
157
+ certainly
158
+ cf
159
+ cg
160
+ ch
161
+ changes
162
+ ci
163
+ ck
164
+ cl
165
+ clear
166
+ clearly
167
+ cm
168
+ cmon
169
+ cn
170
+ co
171
+ co.
172
+ com
173
+ come
174
+ comes
175
+ con
176
+ consequently
177
+ contain
178
+ containing
179
+ contains
180
+ copy
181
+ corresponding
182
+ could
183
+ could've
184
+ couldn
185
+ couldn't
186
+ couldnt
187
+ cr
188
+ cs
189
+ cu
190
+ currently
191
+ cv
192
+ cx
193
+ cy
194
+ cz
195
+ d
196
+ dare
197
+ daren't
198
+ darent
199
+ de
200
+ dear
201
+ definitely
202
+ describe
203
+ described
204
+ despite
205
+ detail
206
+ did
207
+ didn
208
+ didn't
209
+ didnt
210
+ differ
211
+ different
212
+ differently
213
+ directly
214
+ dj
215
+ dk
216
+ dm
217
+ do
218
+ does
219
+ doesn
220
+ doesn't
221
+ doesnt
222
+ doing
223
+ don
224
+ don't
225
+ done
226
+ dont
227
+ downed
228
+ downing
229
+ due
230
+ during
231
+ dz
232
+ e
233
+ each
234
+ ec
235
+ ed
236
+ edu
237
+ ee
238
+ eg
239
+ eh
240
+ either
241
+ else
242
+ elsewhere
243
+ enough
244
+ entirely
245
+ er
246
+ es
247
+ especially
248
+ et
249
+ et-al
250
+ etc
251
+ even
252
+ evenly
253
+ ever
254
+ evermore
255
+ every
256
+ everybody
257
+ everyone
258
+ everything
259
+ everywhere
260
+ ex
261
+ exactly
262
+ example
263
+ except
264
+ f
265
+ fairly
266
+ far
267
+ farther
268
+ felt
269
+ few
270
+ fewer
271
+ ff
272
+ fi
273
+ fify
274
+ fj
275
+ fk
276
+ fm
277
+ fo
278
+ for
279
+ forever
280
+ formerly
281
+ forth
282
+ found
283
+ fr
284
+ from
285
+ front
286
+ full
287
+ fully
288
+ further
289
+ furthered
290
+ furthering
291
+ furthermore
292
+ furthers
293
+ fx
294
+ g
295
+ ga
296
+ gave
297
+ gb
298
+ gd
299
+ ge
300
+ generally
301
+ gf
302
+ gg
303
+ gh
304
+ gi
305
+ gl
306
+ gm
307
+ gmt
308
+ gn
309
+ go
310
+ got
311
+ gotten
312
+ gov
313
+ gp
314
+ gq
315
+ gr
316
+ great
317
+ greater
318
+ greatest
319
+ greetings
320
+ group
321
+ grouped
322
+ grouping
323
+ groups
324
+ gs
325
+ gt
326
+ gu
327
+ gw
328
+ gy
329
+ h
330
+ had
331
+ hadn't
332
+ hadnt
333
+ half
334
+ happens
335
+ hardly
336
+ has
337
+ hasn
338
+ hasn't
339
+ hasnt
340
+ have
341
+ haven
342
+ haven't
343
+ havent
344
+ having
345
+ he
346
+ he'd
347
+ he'll
348
+ he's
349
+ hed
350
+ hell
351
+ hello
352
+ help
353
+ hence
354
+ her
355
+ here
356
+ here's
357
+ hereafter
358
+ hereby
359
+ herein
360
+ heres
361
+ hereupon
362
+ hers
363
+ herself
364
+ herse”
365
+ hes
366
+ hi
367
+ hid
368
+ high
369
+ higher
370
+ highest
371
+ him
372
+ himself
373
+ himse”
374
+ his
375
+ hither
376
+ hk
377
+ hm
378
+ hn
379
+ hopefully
380
+ how
381
+ how'd
382
+ how'll
383
+ how's
384
+ howbeit
385
+ however
386
+ hr
387
+ ht
388
+ htm
389
+ hu
390
+ i
391
+ i'd
392
+ i'll
393
+ i'm
394
+ i've
395
+ i.e.
396
+ id
397
+ ie
398
+ if
399
+ ignored
400
+ ii
401
+ il
402
+ ill
403
+ im
404
+ immediate
405
+ immediately
406
+ importance
407
+ important
408
+ in
409
+ inasmuch
410
+ inc
411
+ inc.
412
+ indeed
413
+ index
414
+ indicate
415
+ indicated
416
+ indicates
417
+ information
418
+ inner
419
+ inside
420
+ insofar
421
+ instead
422
+ int
423
+ interest
424
+ interested
425
+ interesting
426
+ interests
427
+ into
428
+ inward
429
+ io
430
+ iq
431
+ ir
432
+ is
433
+ isn
434
+ isn't
435
+ isnt
436
+ it
437
+ it'd
438
+ it'll
439
+ it's
440
+ itd
441
+ itll
442
+ its
443
+ itself
444
+ itse”
445
+ ive
446
+ j
447
+ je
448
+ jm
449
+ jo
450
+ join
451
+ jp
452
+ just
453
+ k
454
+ ke
455
+ keep
456
+ keeps
457
+ kept
458
+ kg
459
+ kh
460
+ ki
461
+ kind
462
+ km
463
+ kn
464
+ knew
465
+ know
466
+ known
467
+ knows
468
+ kp
469
+ kr
470
+ kw
471
+ ky
472
+ kz
473
+ l
474
+ la
475
+ large
476
+ largely
477
+ last
478
+ lately
479
+ later
480
+ latest
481
+ latter
482
+ latterly
483
+ lb
484
+ lc
485
+ least
486
+ less
487
+ lest
488
+ let
489
+ let's
490
+ lets
491
+ li
492
+ like
493
+ liked
494
+ likely
495
+ likewise
496
+ line
497
+ lk
498
+ ll
499
+ look
500
+ looking
501
+ looks
502
+ lower
503
+ lr
504
+ ls
505
+ lt
506
+ ltd
507
+ lu
508
+ lv
509
+ ly
510
+ m
511
+ ma
512
+ made
513
+ mainly
514
+ make
515
+ makes
516
+ making
517
+ many
518
+ may
519
+ maybe
520
+ mayn't
521
+ maynt
522
+ mc
523
+ md
524
+ me
525
+ mean
526
+ means
527
+ meantime
528
+ meanwhile
529
+ member
530
+ members
531
+ merely
532
+ mg
533
+ mh
534
+ might
535
+ might've
536
+ mightn't
537
+ mightnt
538
+ mil
539
+ mill
540
+ mine
541
+ miss
542
+ mk
543
+ ml
544
+ mm
545
+ mn
546
+ mo
547
+ more
548
+ moreover
549
+ most
550
+ mostly
551
+ move
552
+ mp
553
+ mq
554
+ mr
555
+ mrs
556
+ ms
557
+ msie
558
+ mt
559
+ mu
560
+ much
561
+ mug
562
+ must
563
+ must've
564
+ mustn't
565
+ mustnt
566
+ mv
567
+ mw
568
+ mx
569
+ my
570
+ myself
571
+ myse”
572
+ mz
573
+ n
574
+ na
575
+ namely
576
+ nay
577
+ nc
578
+ nd
579
+ ne
580
+ nearly
581
+ necessarily
582
+ necessary
583
+ need
584
+ needed
585
+ needing
586
+ needn't
587
+ neednt
588
+ needs
589
+ neither
590
+ net
591
+ never
592
+ neverf
593
+ neverless
594
+ nevertheless
595
+ newer
596
+ newest
597
+ nf
598
+ ng
599
+ ni
600
+ nl
601
+ no
602
+ no-one
603
+ nobody
604
+ non
605
+ none
606
+ nonetheless
607
+ noone
608
+ nor
609
+ normally
610
+ nos
611
+ not
612
+ noted
613
+ nothing
614
+ notwithstanding
615
+ nowhere
616
+ np
617
+ nr
618
+ nu
619
+ null
620
+ nz
621
+ o
622
+ obtain
623
+ obtained
624
+ obviously
625
+ of
626
+ off
627
+ often
628
+ oh
629
+ ok
630
+ okay
631
+ om
632
+ omitted
633
+ on
634
+ once
635
+ one
636
+ one's
637
+ ones
638
+ only
639
+ onto
640
+ open
641
+ opened
642
+ opening
643
+ opens
644
+ opposite
645
+ or
646
+ ord
647
+ order
648
+ ordered
649
+ ordering
650
+ orders
651
+ org
652
+ other
653
+ others
654
+ otherwise
655
+ ought
656
+ oughtn't
657
+ oughtnt
658
+ our
659
+ ours
660
+ ourselves
661
+ out
662
+ over
663
+ overall
664
+ owing
665
+ own
666
+ p
667
+ pa
668
+ part
669
+ parted
670
+ particular
671
+ particularly
672
+ parting
673
+ parts
674
+ past
675
+ pe
676
+ per
677
+ perhaps
678
+ pf
679
+ pg
680
+ ph
681
+ pk
682
+ pl
683
+ place
684
+ placed
685
+ places
686
+ please
687
+ pm
688
+ pmid
689
+ pn
690
+ pointed
691
+ pointing
692
+ poorly
693
+ possible
694
+ possibly
695
+ potentially
696
+ pp
697
+ pr
698
+ predominantly
699
+ present
700
+ presented
701
+ presenting
702
+ presents
703
+ presumably
704
+ previously
705
+ primarily
706
+ probably
707
+ problem
708
+ problems
709
+ promptly
710
+ proud
711
+ provided
712
+ provides
713
+ pt
714
+ put
715
+ puts
716
+ pw
717
+ py
718
+ q
719
+ qa
720
+ que
721
+ quickly
722
+ quite
723
+ qv
724
+ r
725
+ rather
726
+ rd
727
+ re
728
+ readily
729
+ really
730
+ reasonably
731
+ recent
732
+ recently
733
+ ref
734
+ refs
735
+ regarding
736
+ regardless
737
+ regards
738
+ related
739
+ relatively
740
+ reserved
741
+ respectively
742
+ resulted
743
+ resulting
744
+ results
745
+ ro
746
+ ru
747
+ rw
748
+ s
749
+ sa
750
+ said
751
+ same
752
+ saw
753
+ saying
754
+ says
755
+ sb
756
+ sc
757
+ sd
758
+ se
759
+ sec
760
+ section
761
+ see
762
+ seeing
763
+ seem
764
+ seemed
765
+ seeming
766
+ seems
767
+ seen
768
+ sees
769
+ self
770
+ selves
771
+ sensible
772
+ sent
773
+ serious
774
+ seriously
775
+ several
776
+ sg
777
+ sh
778
+ shall
779
+ shan't
780
+ shant
781
+ she
782
+ she'd
783
+ she'll
784
+ she's
785
+ shed
786
+ shell
787
+ shes
788
+ should
789
+ should've
790
+ shouldn
791
+ shouldn't
792
+ shouldnt
793
+ showed
794
+ showing
795
+ shown
796
+ showns
797
+ si
798
+ side
799
+ sides
800
+ significant
801
+ significantly
802
+ similar
803
+ similarly
804
+ since
805
+ sincere
806
+ site
807
+ sj
808
+ sk
809
+ sl
810
+ slightly
811
+ sm
812
+ sn
813
+ so
814
+ some
815
+ somebody
816
+ someday
817
+ somehow
818
+ someone
819
+ somethan
820
+ something
821
+ sometime
822
+ sometimes
823
+ somewhat
824
+ somewhere
825
+ specifically
826
+ specified
827
+ specify
828
+ specifying
829
+ sr
830
+ st
831
+ state
832
+ states
833
+ still
834
+ stop
835
+ strongly
836
+ su
837
+ sub
838
+ substantially
839
+ successfully
840
+ such
841
+ sufficiently
842
+ suggest
843
+ sup
844
+ sure
845
+ sv
846
+ sy
847
+ sz
848
+ t
849
+ t's
850
+ take
851
+ taken
852
+ taking
853
+ tc
854
+ td
855
+ tell
856
+ tends
857
+ tf
858
+ tg
859
+ th
860
+ than
861
+ thank
862
+ thanks
863
+ thanx
864
+ that
865
+ that'll
866
+ that's
867
+ that've
868
+ thatll
869
+ thats
870
+ thatve
871
+ the
872
+ their
873
+ theirs
874
+ them
875
+ themselves
876
+ then
877
+ thence
878
+ there
879
+ there'd
880
+ there'll
881
+ there're
882
+ there's
883
+ there've
884
+ thereafter
885
+ thereby
886
+ thered
887
+ therefore
888
+ therein
889
+ therell
890
+ thereof
891
+ therere
892
+ theres
893
+ thereto
894
+ thereupon
895
+ thereve
896
+ these
897
+ they
898
+ they'd
899
+ they'll
900
+ they're
901
+ they've
902
+ theyd
903
+ theyll
904
+ theyre
905
+ theyve
906
+ thick
907
+ thin
908
+ thing
909
+ things
910
+ think
911
+ thinks
912
+ third
913
+ thirty
914
+ this
915
+ thorough
916
+ thoroughly
917
+ those
918
+ thou
919
+ though
920
+ thoughh
921
+ thought
922
+ thoughts
923
+ thousand
924
+ throug
925
+ through
926
+ throughout
927
+ thru
928
+ thus
929
+ til
930
+ till
931
+ tis
932
+ tj
933
+ tk
934
+ tm
935
+ tn
936
+ to
937
+ today
938
+ together
939
+ too
940
+ took
941
+ tp
942
+ tr
943
+ tried
944
+ tries
945
+ truly
946
+ trying
947
+ ts
948
+ tt
949
+ turn
950
+ turned
951
+ turning
952
+ turns
953
+ tw
954
+ twas
955
+ tz
956
+ u
957
+ ua
958
+ ug
959
+ uk
960
+ um
961
+ un
962
+ underneath
963
+ undoing
964
+ unfortunately
965
+ unless
966
+ unlike
967
+ unlikely
968
+ until
969
+ unto
970
+ upon
971
+ ups
972
+ us
973
+ use
974
+ used
975
+ useful
976
+ usefully
977
+ usefulness
978
+ uses
979
+ using
980
+ usually
981
+ uucp
982
+ uy
983
+ uz
984
+ v
985
+ va
986
+ value
987
+ various
988
+ vc
989
+ ve
990
+ versus
991
+ very
992
+ vg
993
+ vi
994
+ via
995
+ viz
996
+ vn
997
+ vol
998
+ vols
999
+ vs
1000
+ vu
1001
+ w
1002
+ want
1003
+ wanted
1004
+ wanting
1005
+ wants
1006
+ was
1007
+ wasn
1008
+ wasn't
1009
+ wasnt
1010
+ way
1011
+ ways
1012
+ we
1013
+ we'd
1014
+ we'll
1015
+ we're
1016
+ we've
1017
+ web
1018
+ wed
1019
+ welcome
1020
+ well
1021
+ wells
1022
+ went
1023
+ were
1024
+ weren
1025
+ weren't
1026
+ werent
1027
+ weve
1028
+ wf
1029
+ what
1030
+ what'd
1031
+ what'll
1032
+ what's
1033
+ what've
1034
+ whatever
1035
+ whatll
1036
+ whats
1037
+ whatve
1038
+ when
1039
+ when'd
1040
+ when'll
1041
+ when's
1042
+ whence
1043
+ whenever
1044
+ where
1045
+ where'd
1046
+ where'll
1047
+ where's
1048
+ whereafter
1049
+ whereas
1050
+ whereby
1051
+ wherein
1052
+ wheres
1053
+ whereupon
1054
+ wherever
1055
+ whether
1056
+ which
1057
+ whichever
1058
+ while
1059
+ whilst
1060
+ whim
1061
+ whither
1062
+ who
1063
+ who'd
1064
+ who'll
1065
+ who's
1066
+ whod
1067
+ whoever
1068
+ whole
1069
+ wholl
1070
+ whom
1071
+ whomever
1072
+ whos
1073
+ whose
1074
+ why
1075
+ why'd
1076
+ why'll
1077
+ why's
1078
+ widely
1079
+ width
1080
+ will
1081
+ willing
1082
+ with
1083
+ within
1084
+ without
1085
+ won
1086
+ won't
1087
+ wonder
1088
+ wont
1089
+ words
1090
+ worked
1091
+ working
1092
+ works
1093
+ world
1094
+ would
1095
+ would've
1096
+ wouldn
1097
+ wouldn't
1098
+ wouldnt
1099
+ ws
1100
+ www
1101
+ x
1102
+ y
1103
+ ye
1104
+ year
1105
+ years
1106
+ yes
1107
+ yet
1108
+ you
1109
+ you'd
1110
+ you'll
1111
+ you're
1112
+ you've
1113
+ youd
1114
+ youll
1115
+ your
1116
+ youre
1117
+ yours
1118
+ yourself
1119
+ yourselves
1120
+ youve
1121
+ yt
1122
+ yu
1123
+ z
1124
+ za
1125
+ zm
1126
+ zr
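The list above is plain text with one stopword per line; ParsingUtils (added below) loads it through read_simple_text_file_2_vec with store_dir='kgs_binding'. A minimal sanity check, assuming the repository root as the working directory:

with open('kgs_binding/english_stopwords.txt', 'r', encoding='utf-8') as f:
    stopwords = f.read().split('\n')
print(len(stopwords))       # roughly 1126 entries
print('the' in stopwords)   # True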
kgs_binding/kg_base_wrapper.py ADDED
@@ -0,0 +1,80 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ from abc import ABC, abstractmethod
8
+ from typing import Tuple, Optional, List
9
+
10
+ # Remote modules
11
+ from nltk.stem import WordNetLemmatizer
12
+
13
+ # Local modules
14
+
15
+ #############################
16
+ # Constants
17
+ #############################
18
+
19
+ class KGBaseHandler(ABC):
20
+ def __init__(self):
21
+ super().__init__()
22
+ self.st = WordNetLemmatizer()
23
+
24
+ def normalize_noun(self, ent):
25
+ try:
26
+ noun = self.st.lemmatize(ent, pos='n')
27
+ noun = self.st.lemmatize(noun, pos='v')
28
+ except Exception as _:
29
+ noun = ent[:-1] if ent[-1] == 's' else ent
30
+ return noun
31
+
32
+ def normalize_nouns(self, ent):
33
+ local_ent = ent[:]
34
+ nouns = local_ent.split(' ')
35
+ if len(nouns) == 1:
36
+ return ' '.join([self.normalize_noun(e) for e in nouns])
37
+ return local_ent
38
+
39
+ def ignore_less_relevant_connection(self, relations):
40
+ if len(relations) >= 2:
41
+ for r in relations:
42
+ if r != 'related_to':
43
+ return r
44
+ return relations[0]
45
+
46
+ @abstractmethod
47
+ def get_relation_types(self) -> List[str]:
48
+ pass
49
+
50
+ @abstractmethod
51
+ def exists_relation_between(self, concept, other_concept) -> bool:
52
+ pass
53
+
54
+ @abstractmethod
55
+ def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
56
+ pass
57
+
58
+ @abstractmethod
59
+ def get_related_concepts(self, concept) -> Optional[List[str]]:
60
+ pass
61
+
62
+ @abstractmethod
63
+ def does_concept_exist(self, concept) -> bool:
64
+ pass
65
+
66
+ class NoKnowledge(KGBaseHandler):
67
+ def __init__(self):
68
+ super(NoKnowledge, self).__init__()
69
+
70
+ def get_relation_types(self) -> List[str]:
71
+ return []
72
+
73
+ def exists_relation_between(self, concept, other_concept) -> bool:
74
+ return False
75
+
76
+ def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
77
+ return (None, None)
78
+
79
+ def does_concept_exist(self, concept) -> bool:
80
+ return False
+
+ # get_related_concepts is abstract on KGBaseHandler, so NoKnowledge must implement it as well;
+ # otherwise instantiating NoKnowledge raises a TypeError.
+ def get_related_concepts(self, concept) -> Optional[List[str]]:
+ return []
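A minimal sketch of how a new knowledge source would plug into the abstract interface above; the ToyHandler name and its hard-coded facts are illustrative only and not part of the commit. Note that KGBaseHandler.__init__ builds a WordNetLemmatizer, so the nltk wordnet data must be installed before any subclass can be instantiated.

from typing import Tuple, Optional, List
from kgs_binding.kg_base_wrapper import KGBaseHandler

class ToyHandler(KGBaseHandler):
    # Illustrative in-memory handler; the real handlers load ConceptNet or SWOW data.
    FACTS = {'dog': ['animal'], 'animal': []}

    def get_relation_types(self) -> List[str]:
        return ['related_to']

    def exists_relation_between(self, concept, other_concept) -> bool:
        return other_concept in self.FACTS.get(concept, [])

    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        l2r = 'related_to' if self.exists_relation_between(concept, other_concept) else None
        r2l = 'related_to' if self.exists_relation_between(other_concept, concept) else None
        return l2r, r2l

    def get_related_concepts(self, concept) -> Optional[List[str]]:
        return self.FACTS.get(concept, [])

    def does_concept_exist(self, concept) -> bool:
        return concept in self.FACTS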
kgs_binding/kg_qa_binding_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ #############################
2
+ # Imports
3
+ #############################
4
+
5
+ # Python modules
6
+ from typing import List, Tuple
7
+ from enum import Enum
8
+
9
+ # Remote modules
10
+
11
+ # Local modules
12
+ from .kg_base_wrapper import KGBaseHandler
13
+ from .swow_handler import SwowHandler
14
+ from .conceptnet_handler import ConceptNetHandler
15
+ from utils import read_json_file_2_dict, Data_Type
16
+
17
+ #############################
18
+ # Constants
19
+ #############################
20
+
21
+ #############################
22
+ # Stuff
23
+ #############################
24
+
25
+ class KGType(Enum):
26
+ SWOW = 'swow'
27
+ CSKG = 'cskg'
28
+ CONCEPTNET = 'conceptnet'
29
+
30
+ def load_kg_handler(kg_type: KGType):
31
+ if kg_type.value == KGType.SWOW.value:
32
+ return SwowHandler()
33
+ elif kg_type.value == KGType.CONCEPTNET.value:
34
+ return ConceptNetHandler()
35
+ else:
36
+ raise NotImplementedError()
37
+
38
+ def _load_data_paths_metadata():
39
+ try:
40
+ data = read_json_file_2_dict('data_config.json', store_dir='run_config')
41
+ except Exception:
42
+ data = None
43
+ return data
44
+
45
+ def from_relations_path_2_relations(dataset_types: List[Data_Type], metadata):
46
+ relations = []
47
+ print('metadata:', metadata)
48
+ for dataset_type in dataset_types:
49
+ qa_meta_data = metadata[dataset_type.value]
50
+ filename_path, dir_data = qa_meta_data['local']
51
+ print(filename_path, dir_data)
52
+ data = read_json_file_2_dict(filename_path, dir_data)
53
+ relations.extend(data)
54
+ return relations
55
+
56
+ def KGHandler_to_str(kg_handler: KGBaseHandler) -> str:
57
+ if isinstance(kg_handler, SwowHandler):
58
+ return 'swow'
59
+ elif isinstance(kg_handler, ConceptNetHandler):
60
+ return 'conceptnet'
61
+ else:
62
+ raise NotImplementedError()
63
+
64
+ def get_kg_qa_data_metadata(kg_handler: KGBaseHandler) -> Tuple[str, str]:
65
+ kg_qa_data_path = _load_data_paths_metadata()
66
+ if isinstance(kg_handler, SwowHandler):
67
+ swow = kg_qa_data_path["swow"]
68
+ return swow
69
+ elif isinstance(kg_handler, ConceptNetHandler):
70
+ conceptnet = kg_qa_data_path["conceptnet"]
71
+ return conceptnet
72
+ else:
73
+ raise NotImplementedError()
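A hedged sketch of the intended call path through the helpers above. get_kg_qa_data_metadata reads run_config/data_config.json via _load_data_paths_metadata, so the lookup is wrapped here for illustration; when that file is missing the helper receives None and raises a TypeError.

from kgs_binding.kg_qa_binding_utils import (
    KGType, load_kg_handler, KGHandler_to_str, get_kg_qa_data_metadata
)

handler = load_kg_handler(KGType.CONCEPTNET)   # or KGType.SWOW
print(KGHandler_to_str(handler))               # 'conceptnet'
try:
    print(get_kg_qa_data_metadata(handler))    # paths declared in run_config/data_config.json
except TypeError:
    print('no data_config.json available')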
kgs_binding/parsing_utils.py ADDED
@@ -0,0 +1,86 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ import re
8
+ import string
9
+
10
+ # Remote modules
11
+
12
+ # Local modules
13
+ from utils import (
14
+ read_simple_text_file_2_vec
15
+ )
16
+
17
+ #############################
18
+ # Utils
19
+ #############################
20
+
21
+ class ParsingUtils:
22
+
23
+ STOPWORDS = read_simple_text_file_2_vec('english_stopwords.txt', store_dir='kgs_binding')
24
+
25
+ @staticmethod
26
+ def remove_pontuation(text):
27
+ text = re.sub(r"[^a-zA-Z]", " ", text)
28
+ return text.translate(str.maketrans('', '', string.punctuation))
29
+
30
+ @staticmethod
31
+ def clear_common_words(index_with_words):
32
+ return [(word, (s, e)) for (word, (s, e)) in index_with_words if word not in ParsingUtils.STOPWORDS]
33
+
34
+ @staticmethod
35
+ def is_word_a_relevant_one(ignore_common_words, word):
36
+ if ignore_common_words:
37
+ return word not in ParsingUtils.STOPWORDS
38
+ else:
39
+ return True
40
+
41
+ @staticmethod
42
+ def get_word_range_mapping(context, word_token):
43
+ word_token_splitted = word_token.split(' ')
44
+ if len(word_token_splitted) == 1:
45
+ word_token_start = context.index(word_token)
46
+ word_token_end = word_token_start + len(word_token) - 1 # inclusive end
47
+ else:
48
+ word_token_start = context.index(word_token_splitted[0])
49
+ word_token_end = word_token_start + len(word_token) - 1 # inclusive end
50
+ return word_token_start, word_token_end
51
+
52
+ @staticmethod
53
+ def n_grams(words_vector, n):
54
+ grams = [words_vector[i:i + n] for i in range(len(words_vector) - n + 1)]
55
+ print(grams)
56
+ return [' '.join(x) for x in grams]
57
+
58
+ @staticmethod
59
+ def n_grams_with_idx(words_vector, n):
60
+ grams = [words_vector[i:i + n] for i in range(len(words_vector) - n + 1)]
61
+ return [(' '.join([pair[0] for pair in x]), (x[0][1], x[-1][1]+len(x[-1][0]))) for x in grams]
62
+
63
+ @staticmethod
64
+ def n_grams_context_producer_simple(context, n_gram=2):
65
+ context_tokens = context.strip().split(' ')
66
+ #context_tokens = [w for w in context_tokens if w not in STOPWORDS]
67
+ n_grams_context = []
68
+ for i in range(n_gram):
69
+ n_gram_content = ParsingUtils.n_grams(context_tokens, n_gram-i)
70
+ n_grams_context.append(n_gram_content)
71
+ return n_grams_context
72
+
73
+ @staticmethod
74
+ def n_grams_n_words_extractor(context, n_gram=3):
75
+ context_tokens = context.strip().split(' ')
76
+ context_tokens_with_index_info=[]
77
+ word_idx=0
78
+ for word in context_tokens:
79
+ context_tokens_with_index_info.append((word, word_idx))
80
+ word_idx += len(word) + 1
81
+ #context_tokens = [w for w in context_tokens if w not in STOPWORDS]
82
+ n_grams_context = []
83
+ for i in range(n_gram):
84
+ n_gram_content = ParsingUtils.n_grams_with_idx(context_tokens_with_index_info, n_gram-i)
85
+ n_grams_context.extend(n_gram_content)
86
+ return n_grams_context
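To make the span bookkeeping concrete, here is what n_grams_n_words_extractor returns for a short sentence (a sketch; it must run from the repository root so the class-level STOPWORDS file resolves). Each entry pairs an n-gram with its character offsets in the input, bigrams listed before unigrams when n_gram=2:

from kgs_binding.parsing_utils import ParsingUtils

spans = ParsingUtils.n_grams_n_words_extractor('boat floats on water', n_gram=2)
print(spans)
# [('boat floats', (0, 11)), ('floats on', (5, 14)), ('on water', (12, 20)),
#  ('boat', (0, 4)), ('floats', (5, 11)), ('on', (12, 14)), ('water', (15, 20))]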
kgs_binding/relation_mapper_builder.py ADDED
@@ -0,0 +1,164 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ from collections import deque
8
+ from collections import defaultdict
9
+ from typing import List, Dict, Optional
10
+ from ast import literal_eval
11
+ from random import sample
12
+
13
+ # Remote modules
14
+
15
+ # Local modules
16
+ from .kg_base_wrapper import KGBaseHandler
17
+ from .swow_handler import SwowHandler
18
+
19
+ from utils import (
20
+ read_json_file_2_dict,
21
+ Data_Type,
22
+ )
23
+ from .parsing_utils import ParsingUtils
24
+
25
+ #############################
26
+ # Constants
27
+ #############################
28
+
29
+ #############################
30
+ # Stuff
31
+ #############################
32
+
33
+ class RelationsMapperBuilder:
34
+ def __init__(self, knowledge: KGBaseHandler,
35
+ filename: Optional[str] = None,
36
+ file_dir: Optional[str] = None,
37
+ datatype: Optional[Data_Type] = None,
38
+ tok_sep:str = '</s>',
39
+ use_extra_relations=True):
40
+ self.tok_sep = tok_sep
41
+ self.knowledge = knowledge
42
+ self.swow_knowledge = SwowHandler()
43
+ self.use_extra_relations = use_extra_relations
44
+ if filename and file_dir and datatype:
45
+ full_context = self.load_data(filename, file_dir)
46
+ self.relevant_context = self.fetch_relevant_context_from_data(data=full_context, datatype=datatype)
47
+
48
+ def load_data(self, filename='commongen_qa_final.json', store_dir='./'):
49
+ data = read_json_file_2_dict(filename=filename, store_dir=store_dir)
50
+ print('data[0]:', data[0])
51
+ return data
52
+
53
+ def fetch_relevant_context_from_data(self, data: List[Dict], datatype:Data_Type = Data_Type.COMMONGEN_QA):
54
+ if datatype == Data_Type.COMMONGEN_QA:
55
+ model_input = [data_unit.get('title').lower() for data_unit in data]
56
+ elif datatype in [Data_Type.ELI5, Data_Type.STACK_EXCHANGE]:
57
+ model_input = [data_unit.get('question').lower() for data_unit in data]
58
+ elif datatype in [Data_Type.COMMONSENSE_QA]:
59
+ #questions = [data_unit.get('question').lower() for data_unit in data]
60
+ #model_input = datasets_parsing_utils.compose_commonsenseqa_data(data)
61
+ model_input = [data_unit.get('input_data') for data_unit in data]
62
+ elif datatype in [Data_Type.COMMONGEN]:
63
+ #questions = [data_unit.get('input_data').lower() for data_unit in data]
64
+ #model_input = datasets_parsing_utils.compose_commongen_data(data)
65
+ model_input = [data_unit.get('input_data') for data_unit in data]
66
+ else:
67
+ model_input = []
68
+ return model_input
69
+
70
+ def get_kg_concepts_from_context(self, context=None, clear_common_wds=False):
71
+ if not context:
72
+ context = self.relevant_context
73
+ context_words = []
74
+ for q_id, question in enumerate(context):
75
+ simple_question = ParsingUtils.remove_pontuation(question)
76
+ n_grams = ParsingUtils.n_grams_n_words_extractor(simple_question)
77
+ words = self.relevant_entities_extractor(n_grams)
78
+ if clear_common_wds:
79
+ words = ParsingUtils.clear_common_words(words)
80
+ simple_words = [word[0] for word in words]
81
+ context_words.append(simple_words)
82
+ return context_words
83
+
84
+ def obtain_concept_neighbours(self, context_concepts:List[str], n_neighbours = 20):
85
+ """
86
+ Use swow to get connected concepts, but then refer back to conceptnet for rich relations
87
+ """
88
+ neighbours = []
89
+ for concept in context_concepts:
90
+ external_neighbour_concepts = self.swow_knowledge.get_related_concepts(concept)
91
+ relevant_concepts = external_neighbour_concepts
92
+ #local_neighbour_concepts = self.knowledge.get_related_concepts(concept)
93
+ #relevant_concepts = [ext_concept for ext_concept in external_neighbour_concepts if ext_concept in local_neighbour_concepts]
94
+ neighbours.extend(relevant_concepts)
95
+ n_neighbours = min(n_neighbours, len(neighbours))
96
+ some_neighbours = sample(neighbours, n_neighbours)
97
+ #print('context_concepts:', context_concepts)
98
+ #print('some_neighbours:', some_neighbours)
99
+ return some_neighbours
100
+
101
+
102
+ def get_relations_mapping_complex(self, context=None, clear_common_wds=False):
103
+ if not context:
104
+ context = self.relevant_context
105
+ relations_info = deque()
106
+ for q_id, question in enumerate(context):
107
+ simple_question = ParsingUtils.remove_pontuation(question)
108
+ n_grams = ParsingUtils.n_grams_n_words_extractor(simple_question)
109
+ words = self.relevant_entities_extractor(n_grams)
110
+ if clear_common_wds:
111
+ words = ParsingUtils.clear_common_words(words)
112
+ #print(f'question: {question}')
113
+ #print(f'words: {words}')
114
+ relation_context_between_words = defaultdict(dict)
115
+ known_tokens = set()
116
+ for token_i, (first_word_token, first_word_range) in enumerate(words[:-1]):
117
+ known_tokens.add(first_word_token)
118
+ first_word_range_str = str(first_word_range)
119
+ # normalize
120
+ first_word_phrase_normalized = self.knowledge.normalize_nouns(first_word_token)
121
+ for (second_word_token, second_word_range) in [w for w in words[token_i + 1:] if w[0] not in known_tokens]:  # w[0] is the token string; known_tokens stores strings
122
+ second_word_range_str = str(second_word_range)
123
+ second_word_phrase_normalized = self.knowledge.normalize_nouns(second_word_token)
124
+ left_2_right, right_2_left = self.knowledge.relation_between(first_word_phrase_normalized, second_word_phrase_normalized)
125
+ #print(first_word_token, second_word_token, left_2_right, right_2_left)
126
+ if left_2_right:
127
+ relation_context_between_words[first_word_range_str][second_word_range_str] = left_2_right
128
+ if right_2_left:
129
+ relation_context_between_words[second_word_range_str][first_word_range_str] = right_2_left
130
+ relations_info.append(dict(relation_context_between_words))
131
+ return list(relations_info)
132
+
133
+ def get_concepts_from_context(self, context=None, clear_common_wds=False,alignment=0):
134
+ relations_info = self.get_relations_mapping_complex(context=[context], clear_common_wds=clear_common_wds)
135
+ words = []
136
+ #print('relations_info here:', relations_info)
137
+ for rels in relations_info:
138
+ for coords, v in rels.items():
139
+ coords_tuple = literal_eval(coords)
140
+ i,j = coords_tuple
141
+ words.append(context[i+alignment:j+alignment])
142
+ for coords_other, rel in v.items():
143
+ coords_other_tuple = literal_eval(coords_other)
144
+ i_other, j_other = coords_other_tuple
145
+ words.append(context[i_other+alignment: j_other+alignment])
146
+ returning_words = list(set(words))
147
+ #print('returning_words:', returning_words)
148
+ return returning_words
149
+
150
+ def relevant_entities_extractor(self, n_grams_n_words, verbose_output=True):
151
+ non_overlapping_knowledge = {}
152
+ # print(n_grams_n_words)
153
+ for concept, (idx_start, idx_end) in n_grams_n_words:
154
+ normalized_concept = self.knowledge.normalize_nouns(concept)
155
+ exists = self.knowledge.does_concept_exist(normalized_concept)
156
+ #print('exists: ', concept, normalized_concept, exists)
157
+ if exists and idx_start not in non_overlapping_knowledge and \
158
+ idx_end not in non_overlapping_knowledge:
159
+ non_overlapping_knowledge[idx_start] = (concept, idx_start, idx_end, 'start_idx')
160
+ non_overlapping_knowledge[idx_end] = (concept, idx_end, idx_end, 'end_idx')
161
+ if verbose_output:
162
+ return [(value[0], (value[1], value[2])) for k, value in sorted(non_overlapping_knowledge.items()) if value[-1] == 'start_idx']
163
+ else:
164
+ return [value[0] for k, value in sorted(non_overlapping_knowledge.items()) if value[-1] == 'start_idx']
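An end-to-end sketch of the mapper above, pairing it with the ConceptNet handler. The sentence, the extracted concepts and the relation label shown are illustrative only, since the actual output depends on the bundled ConceptNet and SWOW files, and the script must run from the repository root so those files resolve.

from kgs_binding.conceptnet_handler import ConceptNetHandler
from kgs_binding.relation_mapper_builder import RelationsMapperBuilder

mapper = RelationsMapperBuilder(knowledge=ConceptNetHandler())
context = ['a boat floats on the water']

# concepts recognised in each input sentence, with stopwords filtered out
print(mapper.get_kg_concepts_from_context(context, clear_common_wds=True))   # e.g. [['boat', 'water']]

# relations keyed by character spans, e.g. [{'(2, 6)': {'(21, 26)': 'at_location'}}]
print(mapper.get_relations_mapping_complex(context, clear_common_wds=True))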
kgs_binding/swow/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import *
kgs_binding/swow/swow_knowledge.json ADDED
The diff for this file is too large to render. See raw diff
 
kgs_binding/swow_handler.py ADDED
@@ -0,0 +1,75 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ import random
8
+ from typing import Tuple, Optional, List
9
+
10
+ # Remote modules
11
+
12
+ # Local modules
13
+ from .kg_base_wrapper import KGBaseHandler
14
+
15
+
16
+ from utils import read_json_file_2_dict
17
+
18
+ #############################
19
+ # Constants
20
+ #############################
21
+
22
+ #############################
23
+ # Stuff
24
+ #############################
25
+
26
+ class SwowHandler(KGBaseHandler):
27
+ def __init__(self, store_dir='kgs_binding/swow'):
28
+ super(SwowHandler, self).__init__()
29
+ self.swow: dict = self.load_stored_data(store_dir=store_dir)
30
+
31
+ def get_relation_types(self) -> List[str]:
32
+ return ['related_to']
33
+
34
+ def load_stored_data(self, filename='swow_knowledge.json', store_dir='kgs_binding/swow'):
35
+ self.swow = read_json_file_2_dict(filename, store_dir)
36
+ return self.swow
37
+
38
+ def exists_relation_between(self, concept, other_concept):
39
+ connections = self.swow.get(concept)
40
+ if not connections:
41
+ return False
42
+ for connection in connections:
43
+ if connection == other_concept:
44
+ return True
45
+ return False
46
+
47
+ def does_concept_exist(self, concept):
48
+ return self.swow.get(concept, None) is not None
49
+
50
+ def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
51
+ exists_left_right = self.exists_relation_between(concept, other_concept)
52
+ exists_right_left = self.exists_relation_between(other_concept, concept)
53
+ relation = None
54
+ if exists_left_right or exists_right_left:
55
+ relation = 'related_to'
56
+ return relation, relation
57
+
58
+ def get_related_concepts(self, concept) -> Optional[List[str]]:
59
+ return self.swow.get(concept, [])
60
+
61
+ def simple_knowledge_prediction(self, knowledge):
62
+ kw = list(knowledge)
63
+ idx = random.randint(0, len(knowledge)-1) # 0-1-2
64
+ kw[idx] = '<mask>'
65
+ textual_knowledge_input = f'{kw[0]} {kw[1]} {kw[2]}'
66
+ label = f'{knowledge[0]} {knowledge[1]} {knowledge[2]}'
67
+ return f'{textual_knowledge_input},{label}\n', label
68
+
69
+ def create_mask_knowledge_for_model(self):
70
+ with open(f'bart_input/swow_bart.txt', 'w') as f:
71
+ for subject, objects in self.swow.items():
72
+ for obj in objects:
73
+ knowledge = (subject, 'is related to', obj)
74
+ w_kw, label = self.simple_knowledge_prediction(knowledge)
75
+ f.write(w_kw)
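A short sketch of the handler above. SWOW only encodes free-association links, so every connection that exists comes back as the single 'related_to' relation.

from kgs_binding.swow_handler import SwowHandler

swow = SwowHandler()                            # reads kgs_binding/swow/swow_knowledge.json
print(swow.get_relation_types())                # ['related_to']
print(swow.does_concept_exist('water'))         # True if 'water' is a cue word in the SWOW dump
print(swow.relation_between('water', 'boat'))   # ('related_to', 'related_to') or (None, None)
print(swow.get_related_concepts('water')[:5])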
model_utils.py ADDED
@@ -0,0 +1,54 @@
1
+
2
+ #############################
3
+ # Imports
4
+ #############################
5
+
6
+ # Python modules
7
+ from typing import List
8
+ from random import randint
9
+
10
+ # Remote modules
11
+ import torch
12
+
13
+ # Local modules
14
+ from utils import Head_Mask
15
+
16
+ #############################
17
+ # Constants
18
+ #############################
19
+
20
+ #############################
21
+ # Stuff
22
+ #############################
23
+
24
+ def create_layers_head_mask(config, head_mask_type: Head_Mask=Head_Mask.ALL, specific_heads: List[int] = None):
25
+ mask_heads = torch.zeros((config.encoder_layers, config.encoder_attention_heads))
26
+ if head_mask_type == Head_Mask.RANDOM:
27
+ for i in range(config.encoder_layers):
28
+ rand_idx = randint(0, config.encoder_attention_heads-1)
29
+ mask_heads[i, rand_idx] = 1
30
+ elif head_mask_type == Head_Mask.NONE:
31
+ mask_heads[:, :] = 1
32
+ elif head_mask_type == Head_Mask.ALL:
33
+ pass
34
+ elif head_mask_type == Head_Mask.SPECIFIC:
35
+ if specific_heads:
36
+ for layer_i in range(len(mask_heads)):
37
+ specific_head = specific_heads[layer_i] - 1
38
+ mask_heads[layer_i][specific_head] = 1
39
+ else:
40
+ mask_heads = torch.Tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
41
+ [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
42
+ [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
43
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
44
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
45
+ [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
46
+ [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
47
+ [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
48
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
49
+ [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
50
+ [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
51
+ [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1]])
52
+ else:
53
+ raise NotImplementedError()
54
+ return mask_heads.tolist()
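A sketch of calling the helper above for a BART-large-sized encoder; BartConfig's defaults already match the 12-layer, 16-head shape that the hard-coded SPECIFIC fallback assumes. The nested list it returns has the head_mask shape the transformers BART encoder expects, where 1 keeps a head active and 0 masks it.

from transformers import BartConfig
from model_utils import create_layers_head_mask
from utils import Head_Mask

config = BartConfig()   # defaults: encoder_layers=12, encoder_attention_heads=16
mask = create_layers_head_mask(config, head_mask_type=Head_Mask.SPECIFIC, specific_heads=[1] * 12)
print(len(mask), len(mask[0]))   # 12 16, with head 1 of every layer left active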
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ transformers
2
+ torch
3
+ numpy
4
+ matplotlib
+ nltk
utils.py ADDED
@@ -0,0 +1,230 @@
1
+ #############################
2
+ # Imports and Constants    #
3
+ #############################
4
+
5
+ # Python modules
6
+ from enum import Enum
7
+ import os
8
+ import json
9
+ import time
10
+
11
+ # Remote packages
12
+ import torch
13
+
14
+ #############################
15
+ # utilities
16
+ #############################
17
+
18
+ class ScoringType(Enum):
19
+ DEFAULT = 'default'
20
+ MAX_PROB = 'max-prob'
21
+ INTERPOL = 'interpol'
22
+ CONSTRAINT = 'constraint'
23
+ MULTIPLE_CHOICE = 'multiple_choice'
24
+
25
+ class LossType(Enum):
26
+ DEFAULT = 'default'
27
+ CP_RP_DEF = 'cp-rp-def'
28
+ CP_DEF = 'cp-def'
29
+ PRP_NRP_DEF = 'prp-nrp-def'
30
+
31
+ class Head_Mask(Enum):
32
+ ALL = 'all'
33
+ NONE = 'none'
34
+ RANDOM = 'random'
35
+ SPECIFIC = 'specific'
36
+
37
+ class KGType(Enum):
38
+ SWOW = 'swow'
39
+ CSKG = 'cskg'
40
+ CONCEPTNET = 'conceptnet'
41
+
42
+ class Model_Type(Enum):
43
+ RELATIONS = 'relations'
44
+ MASK = 'mask'
45
+ DEFAULT = 'default'
46
+
47
+ def is_simple_mask_commonsense(self):
48
+ return self == Model_Type.MASK
49
+
50
+ def there_is_difference_between_relations(self):
51
+ return self == Model_Type.RELATIONS
52
+
53
+ class Data_Type(Enum):
54
+ ELI5 = 'eli5'
55
+ COMMONSENSE_QA = 'commonsense_qa'
56
+ COMMONGEN_QA = 'commongen_qa'
57
+ STACK_EXCHANGE = 'stackexchange_qa'
58
+ ASK_SCIENCE = 'ask_science_qa'
59
+ NATURAL_QUESTIONS = 'natural_questions'
60
+ LAMA = 'lama'
61
+ CONCEPTNET = 'conceptnet'
62
+ CUSTOM = 'custom'
63
+ COMMONGEN = 'commongen'
64
+
65
+ @staticmethod
66
+ def data_types_to_str(data_types):
67
+ datasets_str = '-'.join([x.value for x in data_types])
68
+ return datasets_str
69
+
70
+ #############################
71
+ # Models
72
+ #############################
73
+
74
+ MODELS_PRETRAINING_NAME = {
75
+ "bart_large": "facebook/bart-large",
76
+ "bart_large_fp32": "patrickvonplaten/bart-large-fp32",
77
+ "bart_large_tweak": "",
78
+ "bart_base": "facebook/bart-base"
79
+ }
80
+
81
+ CURRENT_PRETRAINING_NAME = MODELS_PRETRAINING_NAME.get('bart_large_fp32')
82
+
83
+ #############################
84
+ # Files Management          #
85
+ #############################
86
+
87
+ def create_directory(output_dir):
88
+ # Create output directory if needed
89
+ if not os.path.exists(output_dir):
90
+ try:
91
+ os.makedirs(output_dir)
92
+ except FileExistsError as _:
93
+ return
94
+ else:
95
+ print(f"Output directory {output_dir} already exists")
96
+
97
+ def read_simple_text_file_2_vec(filename, store_dir='.'):
98
+ with open(f'{store_dir}/{filename}', 'r') as f:
99
+ return f.read().split('\n')
100
+
101
+ def write_dict_2_json_file(json_object, filename, store_dir='.'):
102
+ create_directory(store_dir)
103
+ with open(f'{store_dir}/{filename}', 'w', encoding='utf-8') as file:
104
+ json.dump(json_object, file, ensure_ascii=False, indent=4)
105
+
106
+
107
+ def read_json_file_2_dict(filename, store_dir='.'):
108
+ with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
109
+ return json.load(file)
110
+
111
+ def read_jsonl_file_2_dict(filename, store_dir='.'):
112
+ elements = []
113
+ with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
114
+ for line in file:
115
+ elements.append(json.loads(line))
116
+ return elements
117
+
118
+ def read_txt_2_list(filename, store_dir='.'):
119
+ with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
120
+ return file.read().split('\n')
121
+
122
+ #############################
123
+ # Data Structures helper functions
124
+ #############################
125
+
126
+ def get_chunks(lst, n):
127
+ """Split lst into n chunks of roughly equal size and yield them."""
128
+ jump = len(lst)//n
129
+ for i in range(0, len(lst), jump):
130
+ yield lst[i:i + jump]
131
+
132
+ def get_jump_chunks(lst, jump):
133
+ """Yield successive jump-sized chunks from lst."""
134
+ for i in range(0, len(lst), jump):
135
+ yield lst[i:i + jump]
136
+
137
+ def join_str_first(sep_str, lis):
138
+ return '{1}{0}'.format(sep_str.join(lis), sep_str).strip()
139
+
140
+ #############################
141
+ # Huggingface
142
+ #############################
143
+
144
+ def inputs_introspection_print(tokenizer, inputs):
145
+ input_ids = inputs.get('input_ids', None)
146
+ input_text = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
147
+ labels_ids = inputs.get('labels', None)
148
+ labels_text = tokenizer.batch_decode(labels_ids, skip_special_tokens=False)
149
+ print('original input:', input_text[:2])
150
+ print("::::::::::::::::::::::::::")
151
+ print('original labels:', labels_text[:2])
152
+ print("==========|||||==========")
153
+
154
+ def tok_data_2_text(tokenizer, all_inputs):
155
+ def clean_input_text(text):
156
+ real_text = text.split(tokenizer.eos_token)[0]
157
+ real_text = real_text.replace(tokenizer.bos_token, '').strip()
158
+ return real_text
159
+ all_input_text, all_labels_text = [], []
160
+ for inputs in all_inputs:
161
+ input_ids = inputs.get('input_ids', None)
162
+ input_text = tokenizer.decode(input_ids, skip_special_tokens=False)
163
+ labels_ids = inputs.get('labels', None)
164
+ labels_text = tokenizer.decode(labels_ids, skip_special_tokens=True)
165
+ #print('input_text:', input_text)
166
+ #print('labels_text:', labels_text)
167
+ input_text = clean_input_text(input_text)
168
+ all_input_text.append(input_text)
169
+ all_labels_text.append(labels_text)
170
+ return all_input_text, all_labels_text
171
+
172
+ #############################
173
+ # Torch
174
+ #############################
175
+
176
+ def get_device(verbose:bool=True):
177
+ # If there's a GPU available...
178
+ if torch.cuda.is_available():
179
+ device = torch.device("cuda")
180
+ n_gpus = torch.cuda.device_count()
181
+ first_gpu = torch.cuda.get_device_name(0)
182
+ if verbose:
183
+ print(f'There are {n_gpus} GPU(s) available.')
184
+ print(f'GPU gonna be used: {first_gpu}')
185
+ else:
186
+ if verbose:
187
+ print('No GPU available, using the CPU instead.')
188
+ device = torch.device("cpu")
189
+ return device
190
+
191
+ #############################
192
+ # Timing
193
+ #############################
194
+
195
+ def timing_decorator(func):
196
+ def wrapper(*args, **kwargs):
197
+ start = time.time()
198
+ original_return_val = func(*args, **kwargs)
199
+ end = time.time()
200
+ print("time elapsed in ", func.__name__, ": ", end - start, sep='')
201
+ return original_return_val
202
+
203
+ return wrapper
204
+
205
+ #############################
206
+ # PRINTING UTILS
207
+ #############################
208
+
209
+ class LOGGER_COLORS:
210
+ HEADER = '\033[95m'
211
+ OKBLUE = '\033[94m'
212
+ INFOCYAN = '\033[96m'
213
+ OKGREEN = '\033[92m'
214
+ WARNING = '\033[93m'
215
+ FAIL = '\033[91m'
216
+ ENDC = '\033[0m'
217
+ BOLD = '\033[1m'
218
+ UNDERLINE = '\033[4m'
219
+
220
+ def print_info(logger, message):
221
+ logger.info(f'{LOGGER_COLORS.INFOCYAN}[INFO]{LOGGER_COLORS.ENDC}: {message}')
222
+
223
+ def print_success(logger, message):
224
+ logger.info(f'{LOGGER_COLORS.OKGREEN}[SUCCESS]{LOGGER_COLORS.ENDC}: {message}')
225
+
226
+ def print_warning(logger, message):
227
+ logger.info(f'{LOGGER_COLORS.WARNING}[WARNING]{LOGGER_COLORS.ENDC}: {message}')
228
+
229
+ def print_fail(logger, message):
230
+ logger.info(f'{LOGGER_COLORS.FAIL}[FAIL]{LOGGER_COLORS.ENDC}: {message}')
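A small sketch exercising the chunking and timing helpers defined above (illustrative only):

from utils import get_chunks, get_jump_chunks, timing_decorator

@timing_decorator
def chunk_demo():
    data = list(range(10))
    print(list(get_chunks(data, 2)))       # 2 chunks of 5 elements each
    print(list(get_jump_chunks(data, 3)))  # chunks of size 3: [0,1,2], [3,4,5], [6,7,8], [9]

chunk_demo()                               # the decorator also prints the elapsed time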