|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch PagnolXl model.""" |
|
|
|
import math |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss |
|
from torch.nn import functional as F |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
CausalLMOutputWithCrossAttentions, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_code_sample_docstrings, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
) |
|
|
|
from .configuration_pagnolxl import PagnolXlConfig |
|
|
|
logger = logging.get_logger(__name__)

# Hub identifiers of the published PagnolXl checkpoints.
PAGNOLXL_PRETRAINED_MODEL_ARCHIVE_LIST = ["XXXX/pagnol-xl"]

# Checkpoint and config-class names intended for documentation helpers.
_CHECKPOINT_FOR_DOC = PAGNOLXL_PRETRAINED_MODEL_ARCHIVE_LIST[0]
_CONFIG_FOR_DOC = "PagnolXlConfig"
|
|
|
|
|
class PagnolXlEmbeddings(nn.Module):
    """Token embedding layer of PagnolXl.

    Wraps a single ``nn.Embedding`` table of shape
    ``(config.vocab_size, config.d_model)``. No positional information is added
    here; PagnolXl applies rotary embeddings inside its attention layers instead.
    Weight initialization is not performed in this module.

    Parameters
    ----------
    config: PagnolXlConfig
        Model configuration; ``vocab_size`` and ``d_model`` are read.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """Look up the embedding vector of every token id in ``input_ids``.

        The output has the shape of ``input_ids`` with a trailing ``d_model`` axis.
        """
        return self.embedding(input_ids)
|
|
|
|
|
|
|
def rotate_half(x):
    """Rotate the last axis of ``x`` by half its length, negating the moved part.

    With ``x = [a, b]`` split at ``x.shape[-1] // 2``, returns ``[-b, a]``.
    This is the helper used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    first = x[..., :half]
    second = x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
|
class PagnoXlRotaryEmbeddings(nn.Module):
    """Rotary position embeddings (RoPE), adapted from GPT-NeoX and Falcon.

    Operates on queries and keys laid out as
    ``[batch_size, n_heads, seq_len, head_dim]`` (the layout produced by the
    MinGPT-style attention in this file). The cos/sin tables have shape
    ``[1, seq_len, head_dim]`` and broadcast over the batch and head axes.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert (
            config.d_model % config.n_heads == 0
        ), "d_model must be divisible by n_heads. Currently d_model: {}, n_heads: {}".format(
            config.d_model, config.n_heads
        )

        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # Rotary base, read from the config dict with a fallback of 10000.
        self.base = config.to_dict().get("base", 10000)
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Lazily built cos/sin tables covering positions [0, seq_len_cached).
        self.seq_len_cached = -1
        self.cos_cached: Optional[torch.Tensor] = None
        self.sin_cached: Optional[torch.Tensor] = None

    def cos_sin(
        self,
        seq_len: int,
        past_key_values_length: int,
        device="cpu",
        dtype=torch.bfloat16,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` tables for positions ``[past, past + seq_len)``.

        The tables are cached. They are rebuilt when more positions are needed
        than are cached, or when the requested ``device``/``dtype`` differs from
        the cached one. Otherwise a call after e.g. moving the model to another
        device would receive tensors on the old device.

        Returns:
            Two tensors of shape ``[1, seq_len, head_dim]`` on ``device`` with ``dtype``.
        """
        total_length = seq_len + past_key_values_length
        target_device = torch.device(device)
        needs_rebuild = (
            self.cos_cached is None
            or total_length > self.seq_len_cached
            or self.cos_cached.device != target_device
            or self.cos_cached.dtype != dtype
        )
        if needs_rebuild:
            cache_length = max(total_length, self.seq_len_cached)
            # Work in float32: if the module was cast to half precision,
            # inv_freq would be float16, which cannot represent every integer
            # position above 2048 exactly.
            t = torch.arange(cache_length, device=target_device, dtype=torch.float32)
            inv_freq = self.inv_freq.to(device=target_device, dtype=torch.float32)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)

            self.cos_cached = emb.cos()[None, :, :].to(dtype)
            self.sin_cached = emb.sin()[None, :, :].to(dtype)
            self.seq_len_cached = cache_length

        end = past_key_values_length + seq_len
        return (
            self.cos_cached[:, past_key_values_length:end],
            self.sin_cached[:, past_key_values_length:end],
        )

    def forward(self, query, key, past_key_values_length=0):
        """Apply rotary embeddings to ``query`` and ``key``.

        Both tensors are ``[batch, n_heads, seq_len, head_dim]``. The new tokens
        are treated as occupying absolute positions starting at
        ``past_key_values_length``.
        """
        batch, num_heads, seq_len, head_dim = query.shape
        cos, sin = self.cos_sin(
            seq_len, past_key_values_length, query.device, query.dtype
        )
        rotated_query = (query * cos) + (rotate_half(query) * sin)
        rotated_key = (key * cos) + (rotate_half(key) * sin)
        return rotated_query, rotated_key
|
|
|
|
|
def _make_causal_mask( |
|
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int |
|
) -> torch.BoolTensor: |
|
""" |
|
Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it |
|
just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1, |
|
target_length, target_length+past_key_values_length]`. |
|
""" |
|
batch_size, target_length = input_ids_shape |
|
|
|
mask = torch.triu( |
|
torch.ones((target_length, target_length), dtype=torch.bool, device=device), |
|
diagonal=1, |
|
) |
|
|
|
|
|
|
|
past_mask = torch.zeros( |
|
(target_length, past_key_values_length), dtype=torch.bool, device=device |
|
) |
|
mask = torch.cat([past_mask, mask], dim=-1) |
|
expanded_mask = mask[None, None, :, :].expand( |
|
batch_size, 1, target_length, target_length + past_key_values_length |
|
) |
|
return expanded_mask |
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor: |
|
""" |
|
Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`. |
|
""" |
|
batch_size, total_length = mask.shape |
|
seq_length = ( |
|
total_length - past_key_values_length |
|
if past_key_values_length is not None |
|
else total_length |
|
) |
|
|
|
expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) |
|
return expanded_mask.expand(batch_size, 1, seq_length, total_length) |
|
|
|
|
|
class PagnolXlAttention(nn.Module):
    """Pagnol's multi-head self-attention, following `Karpathy's MinGPT <https://github.com/karpathy/minGPT>`_.

    Queries and keys are rotated with rotary position embeddings before the
    scaled dot product. The forward pass accepts an optional boolean
    ``attention_mask`` (``True`` = position must not be attended to) and an
    optional key/value cache (``layer_past``).
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.dropout = config.dropout
        self.sigma = config.sigma
        self.n_layers = config.n_layers

        # Separate key / query / value projections.
        self.key = nn.Linear(config.d_model, config.d_model)
        self.query = nn.Linear(config.d_model, config.d_model)
        self.value = nn.Linear(config.d_model, config.d_model)

        # Dropout on the attention probabilities and on the projected output.
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        # Output projection back to d_model.
        self.proj = nn.Linear(config.d_model, config.d_model)

        self.rotary_embedding = PagnoXlRotaryEmbeddings(config)

    def init_weights(self):
        """Normal(0, sigma) for Q/K/V and a depth-scaled std for the output projection."""
        # GPT-2 style residual scaling: the projection feeding the residual stream
        # gets std sigma / sqrt(2 * n_layers).
        std = self.sigma / math.sqrt(2.0 * self.n_layers)
        torch.nn.init.normal_(self.key.weight, mean=0.0, std=self.sigma)
        torch.nn.init.normal_(self.query.weight, mean=0.0, std=self.sigma)
        torch.nn.init.normal_(self.value.weight, mean=0.0, std=self.sigma)

        torch.nn.init.constant_(self.key.bias, 0.0)
        torch.nn.init.constant_(self.query.bias, 0.0)
        torch.nn.init.constant_(self.value.bias, 0.0)

        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=std)
        torch.nn.init.constant_(self.proj.bias, 0.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run self-attention over ``hidden_states`` of shape ``(N, L, D)``.

        Args:
            hidden_states: input of shape ``(N, L, D)``.
            layer_past: optional ``(past_key, past_value)``, each
                ``(N, n_heads, past_len, D // n_heads)``, prepended to the new keys/values.
            attention_mask: optional boolean mask broadcastable to the score
                tensor ``(N, n_heads, L, past_len + L)``; ``True`` entries are blocked.
            head_mask: optional multiplier applied to the attention probabilities.
            use_cache: when true, also return the concatenated ``(key, value)``.
            output_attentions: when true, also return the probabilities averaged over heads.

        Returns:
            ``(output, present)`` or ``(output, present, head_averaged_attention)``,
            where ``present`` is ``None`` unless ``use_cache`` is set.
        """
        N, L, D = hidden_states.size()
        head_dim = D // self.n_heads

        def _split_heads(x: torch.Tensor) -> torch.Tensor:
            # (N, L, D) -> (N, n_heads, L, head_dim)
            return x.view(N, L, self.n_heads, head_dim).transpose(1, 2)

        key = _split_heads(self.key(hidden_states))
        query = _split_heads(self.query(hidden_states))
        value = _split_heads(self.value(hidden_states))

        if self.rotary_embedding is not None:
            # Cached keys are (N, n_heads, past_len, head_dim): the number of
            # cached positions is on the sequence axis (-2), not axis 1 (n_heads).
            past_kv_length = 0 if layer_past is None else layer_past[0].shape[-2]
            query, key = self.rotary_embedding(query, key, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # Append the new positions after the cached ones along the sequence axis.
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache else None

        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf. A fully masked
            # row (e.g. a left-padding query) then gets a finite softmax instead
            # of NaNs, which would otherwise leak into every position of later
            # layers through `0 * NaN` in the value product.
            scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
        probs = F.softmax(scores, dim=-1)
        probs = self.attn_drop(probs)

        if head_mask is not None:
            probs = probs * head_mask

        outputs = (probs @ value).transpose(1, 2).contiguous().view(N, L, D)
        outputs = self.resid_drop(self.proj(outputs))

        if output_attentions:
            return outputs, present, probs.sum(dim=1) / self.n_heads
        return outputs, present
|
|
|
|
|
class PagnolXlStandardMLP(nn.Module):
    """Pagnol's position-wise feed-forward network.

    ``Linear(d_model -> d_feedforward)``, the configured activation, then
    ``Linear(d_feedforward -> d_model)``, held in ``self.mlp`` as an ``nn.Sequential``.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.d_feedforward = config.d_feedforward
        self.n_layers = config.n_layers
        self.activation = ACT2FN[config.activation_function]

        up_projection = nn.Linear(config.d_model, config.d_feedforward, bias=True)
        down_projection = nn.Linear(config.d_feedforward, config.d_model, bias=True)
        self.mlp = nn.Sequential(up_projection, self.activation, down_projection)

        self.init_weights()

    def init_weights(self):
        """Normal(0, sigma) for the first layer; depth-scaled std for the second; zero biases."""
        sigma = self.config.sigma
        residual_std = sigma / math.sqrt(2.0 * self.n_layers)
        for layer, std in ((self.mlp[0], sigma), (self.mlp[2], residual_std)):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
            torch.nn.init.zeros_(layer.bias)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the feed-forward network to the last axis of ``hidden_states``."""
        return self.mlp(hidden_states)
|
|
|
|
|
class PagnolXlLayerNorm(nn.Module):
    """Pagnol's LayerNorm over ``d_model`` with ``eps = config.layer_norm_epsilon``."""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_epsilon)

        self.init_weights()

    def init_weights(self):
        """Reset the affine parameters to the identity transform (scale 1, shift 0)."""
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Normalize ``hidden_states`` over its last axis."""
        return self.norm(hidden_states)
|
|
|
|
|
class PagnoXlBlock(nn.Module):
    """Pre-norm GPT-3 style decoder layer.

    Two residual sub-layers, each computed as
    ``x + dropout(sublayer(layer_norm(x)))``: self-attention first, then the
    feed-forward network.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_layers = config.n_layers

        # Attention sub-layer.
        self.self_attention = PagnolXlAttention(config)
        self.attn_norm = PagnolXlLayerNorm(config)
        self.attn_dropout = nn.Dropout(config.dropout)

        # Feed-forward sub-layer.
        self.mlp = PagnolXlStandardMLP(config)
        self.mlp_norm = PagnolXlLayerNorm(config)
        self.mlp_dropout = nn.Dropout(config.dropout)

        self.init_weights()

    def init_weights(self):
        """Delegate to the attention and feed-forward sub-modules, in that order."""
        for sublayer in (self.self_attention, self.mlp):
            sublayer.init_weights()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        """Return ``(hidden_states, [present,] [attention])``.

        ``present`` is included only when ``use_cache`` is truthy, and the
        attention map only when ``output_attentions`` is requested.
        """
        attention_result = self.self_attention(
            self.attn_norm(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.attn_dropout(attention_result[0])

        mlp_result = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + self.mlp_dropout(mlp_result)

        # attention_result is (output, present[, attention]); drop the `present`
        # slot when caching is off.
        extras = attention_result[1:] if use_cache else attention_result[2:]
        return (hidden_states,) + extras
|
|
|
|
|
class PagnolXlPreTrainedModel(PreTrainedModel):
    """Common base for PagnolXl models: config class, weight init and gradient checkpointing."""

    config_class = PagnolXlConfig
    base_model_prefix = "pagnolxl"
    supports_gradient_checkpointing = True
    # Must match the actual decoder-block class name (`PagnoXlBlock`); a name that
    # matches no class means device-map placement may split a block across devices.
    _no_split_modules = ["PagnoXlBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        Attention and MLP modules re-run their own ``init_weights``.
        ``nn.Module.apply`` visits children before their parent, so the generic
        ``nn.Linear`` init below has already run on their projections. Re-running
        restores the depth-scaled std intended for the residual projections.
        """
        if isinstance(module, (PagnolXlAttention, PagnolXlStandardMLP)):
            module.init_weights()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(
        self,
        enable: bool = True,
        gradient_checkpointing_func=torch.utils.checkpoint.checkpoint,
    ):
        """Toggle gradient checkpointing on every ``PagnolXlTransformer`` sub-module.

        Uses the ``(enable, gradient_checkpointing_func)`` hook signature. The
        transformer's forward calls ``self._gradient_checkpointing_func``, which
        this sets; it is the transformer, not ``PagnolXlModel``, that reads the
        ``gradient_checkpointing`` flag.
        """
        for module in self.modules():
            if isinstance(module, PagnolXlTransformer):
                module.gradient_checkpointing = enable
                module._gradient_checkpointing_func = gradient_checkpointing_func
|
|
|
|
|
class PagnolXlTransformer(PagnolXlPreTrainedModel):
    """Pagnol's Transformer: a stack of ``config.n_layers`` decoder blocks.

    Consumes already-embedded inputs and returns the hidden states produced by
    the last block. No final normalization is applied here.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [PagnoXlBlock(config) for _ in range(config.n_layers)]
        )
        self.gradient_checkpointing = False
        self.init_weights()

    def init_weights(self):
        """Re-run every block's own initialization."""
        for layer in self.layers:
            layer.init_weights()

    @staticmethod
    def _prepare_attn_mask(
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, int],
        past_key_values_length: int,
    ) -> torch.BoolTensor:
        """Combine the padding mask with a causal mask.

        Returns a boolean tensor of shape
        ``[batch_size, 1, seq_length, seq_length + past_key_values_length]`` where
        ``True`` marks positions that must not be attended to.

        Raises:
            ValueError: if ``attention_mask`` does not cover past + current tokens.
        """
        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
            raise ValueError(
                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
                f" {past_key_values_length}."
            )
        device = attention_mask.device
        _, seq_length = input_shape

        combined_attention_mask = None
        # With a single new token, every earlier position is visible anyway, so
        # the causal mask is only needed for multi-token inputs.
        if seq_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        expanded_attn_mask = _expand_mask(
            attention_mask, past_key_values_length=past_key_values_length
        )
        if combined_attention_mask is None:
            return expanded_attn_mask
        return expanded_attn_mask | combined_attention_mask

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Run all blocks over ``inputs_embeds`` of shape ``(batch, seq, d_model)``.

        Unset flags fall back to the corresponding config values. When
        ``attention_mask`` is omitted, every position (past and current) is attended.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length, _ = inputs_embeds.shape

        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.layers))
        else:
            # Cached keys are (batch, n_heads, past_len, head_dim).
            past_length = past_key_values[0][0].size(-2)

        hidden_states = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_length),
                device=hidden_states.device,
            )
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_length,
        )

        # Resolve the checkpointing/cache conflict *before* creating the cache
        # accumulator, so `presents` is None (not an empty tuple) once caching
        # has been turned off.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    None,
                    causal_mask,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = layer(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Block outputs are (hidden, [present,] attention).
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
class PagnolXlModel(PagnolXlPreTrainedModel):
    """Bare PagnolXl decoder: token embeddings followed by the transformer stack.

    ``forward`` returns the transformer's outputs directly. ``final_norm`` and
    ``projector`` are instantiated but not applied in ``forward``; presumably they
    are kept so checkpoint keys line up (TODO confirm).
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        self.embedding.embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Embed ``input_ids`` (or use ``inputs_embeds``) and run the transformer.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # The embedding module is `self.embedding`; there is no `word_embeddings`.
            inputs_embeds = self.embedding(input_ids)

        return self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
|
|
|
class PagnolXlLMHead(nn.Module):
    """Pagnol's language-model head: a bias-free ``d_model -> vocab_size`` projection."""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        # Stored so that `init_weights` can read `config.sigma`; without it that
        # method raised AttributeError.
        self.config = config
        self.proj = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def init_weights(self):
        """Draw the projection weights from Normal(0, config.sigma)."""
        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=self.config.sigma)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Map hidden states (last axis ``d_model``) to vocabulary logits."""
        return self.proj(hidden_states)
|
|
|
|
|
class PagnolXlForCausalLM(PagnolXlPreTrainedModel):
    """PagnolXl with a language-modeling head.

    Pipeline: embedding, transformer, final LayerNorm, vocabulary projection.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        self.embedding.embedding = value

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        **kwargs,
    ) -> dict:
        """Build model kwargs for one generation step.

        When a cache is present, drops the prefix of ``input_ids`` already covered
        by it, always keeping at least the last token.
        """
        if past_key_values:
            # Cached keys are (batch, n_heads, past_len, head_dim).
            past_length = past_key_values[0][0].shape[2]
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Labels are shifted internally: the logits at position ``t`` are scored
        against ``labels[..., t + 1]``. ``CrossEntropyLoss`` uses its default
        ``ignore_index`` (-100).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        transformer_outputs = self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.final_norm(transformer_outputs[0])
        lm_logits = self.projector(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the logits (e.g. when the
            # model is dispatched across devices); align them before the loss.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            vocab_size = shift_logits.size(-1)

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|