#############################
# Imports
#############################

# Python modules
from typing import List, Optional
from random import randint

# Remote modules
import torch

# Local modules
from utils import Head_Mask

#############################
# Constants
#############################

# Fallback mask returned for Head_Mask.SPECIFIC when no ``specific_heads`` are
# given. Rows correspond to encoder layers and columns to attention heads, so
# this literal is 12 layers x 16 heads. Values follow the convention used by
# create_layers_head_mask (Head_Mask.ALL -> all zeros, Head_Mask.NONE -> all
# ones): 1 keeps a head, 0 masks it.
# NOTE(review): the shape is fixed, so it only lines up with configs that have
# 12 encoder layers and 16 attention heads -- confirm the target model.
DEFAULT_SPECIFIC_HEAD_MASK = [
    [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
    [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
    [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
    [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
    [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
    [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
    [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
]

#############################
# Head masks
#############################


def create_layers_head_mask(
    config,
    head_mask_type: Head_Mask = Head_Mask.ALL,
    specific_heads: Optional[List[int]] = None,
) -> List[List[float]]:
    """Build a per-layer encoder attention-head mask as a nested Python list.

    The mask has one row per encoder layer (``config.encoder_layers``) and one
    column per attention head (``config.encoder_attention_heads``). A value of
    1 keeps a head and 0 masks it (``Head_Mask.ALL`` yields all zeros,
    ``Head_Mask.NONE`` yields all ones).

    Args:
        config: Model config exposing ``encoder_layers`` and
            ``encoder_attention_heads``.
        head_mask_type: Masking strategy:
            - ``Head_Mask.RANDOM``: keep exactly one randomly chosen head per
              layer (drawn with the global ``random`` module state).
            - ``Head_Mask.NONE``: keep every head (all ones).
            - ``Head_Mask.ALL``: mask every head (all zeros).
            - ``Head_Mask.SPECIFIC``: keep one chosen head per layer, taken
              from ``specific_heads``. If ``specific_heads`` is ``None`` or
              empty, ``DEFAULT_SPECIFIC_HEAD_MASK`` is returned instead.
        specific_heads: Only used with ``Head_Mask.SPECIFIC``. The **1-based**
            index of the head to keep in each layer, one entry per layer in
            layer order. Entries beyond the number of layers are ignored.

    Returns:
        An ``encoder_layers x encoder_attention_heads`` list of lists of
        floats. The default-SPECIFIC fallback instead has the fixed shape of
        ``DEFAULT_SPECIFIC_HEAD_MASK``.

    Raises:
        IndexError: If ``specific_heads`` has fewer entries than there are
            encoder layers, or an entry lies outside
            ``1..encoder_attention_heads``.
        NotImplementedError: If ``head_mask_type`` is not a supported mode.
    """
    num_layers = config.encoder_layers
    num_heads = config.encoder_attention_heads
    mask_heads = torch.zeros((num_layers, num_heads))

    if head_mask_type == Head_Mask.RANDOM:
        # One randint call per layer, in layer order, so seeded runs are
        # reproducible.
        for layer_i in range(num_layers):
            mask_heads[layer_i, randint(0, num_heads - 1)] = 1
    elif head_mask_type == Head_Mask.NONE:
        mask_heads.fill_(1)
    elif head_mask_type == Head_Mask.ALL:
        pass  # torch.zeros already masks every head.
    elif head_mask_type == Head_Mask.SPECIFIC:
        if specific_heads:
            if len(specific_heads) < num_layers:
                raise IndexError(
                    f"specific_heads has {len(specific_heads)} entries but "
                    f"the encoder has {num_layers} layers"
                )
            for layer_i, head in enumerate(specific_heads[:num_layers]):
                # Heads are 1-based. Without this check, 0 would silently
                # wrap around to the last head via negative indexing.
                if not 1 <= head <= num_heads:
                    raise IndexError(
                        f"specific_heads[{layer_i}] = {head} is outside "
                        f"the valid 1-based range 1..{num_heads}"
                    )
                mask_heads[layer_i, head - 1] = 1
        else:
            # torch.Tensor converts the int literal to the default float
            # dtype, matching the torch.zeros-based modes above.
            mask_heads = torch.Tensor(DEFAULT_SPECIFIC_HEAD_MASK)
    else:
        raise NotImplementedError(
            f"Unsupported head mask type: {head_mask_type!r}"
        )

    return mask_heads.tolist()