# coding=utf-8 # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Flax IndicTrans model.""" import math from typing import List, Optional, Tuple, Union, Callable from functools import partial import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey from transformers.modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, ) from transformers.modeling_flax_utils import ( ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, append_replace_return_docstrings, overwrite_call_docstring, ) from configuration_indictrans import IndicTransConfig from transformers.utils import logging logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "IndicTransConfig" INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""] def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ shifted_input_ids = jnp.zeros_like(input_ids) shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids class FlaxIndicTransSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" num_positions: int embedding_dim: int padding_idx: Optional[int] = None # IndicTrans is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack offset: int = 2 def setup(self) -> None: self.weights = self._make_weights(self.num_positions + self.offset, self.embedding_dim, padding_idx=self.padding_idx) def _make_weights( self, num_embeddings: int, embedding_dim: int, existing_weights: Optional[jnp.array] = None, padding_idx: Optional[int] = None ): emb_weights = self._get_embedding(num_embeddings, embedding_dim, padding_idx) if existing_weights is not None: # Convert emb_weights to the same dtype as existing_weights emb_weights = emb_weights.astype(existing_weights.dtype) return emb_weights def _get_embedding( self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None ): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32)) emb = jnp.arange(num_embeddings, dtype=jnp.float32).reshape(-1, 1) * emb.reshape(1, -1) emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1).reshape(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = jnp.concatenate([emb, jnp.zeros((num_embeddings, 1), dtype=emb.dtype)], axis=1) if padding_idx is not None: emb = emb.at[padding_idx].set(0) return emb def __call__( self, input_ids: jnp.array = None, inputs_embeds: jnp.array = None, past_key_values_length: int = 0 ): if input_ids is not None: bsz, seq_len = input_ids.shape position_ids = self._create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) else: bsz, seq_len = inputs_embeds.shape[:-1] position_ids = self._create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) # Expand embeddings if needed max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.shape[0]: self.weights = self.make_weights(max_pos + self.offset, self.embedding_dim, self.weights, self.padding_idx) return self.weights[position_ids.ravel()].reshape(bsz, seq_len, -1) def _create_position_ids_from_input_ids( self, input_ids, padding_idx, past_key_values_length=0 ): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is a JAX conversion of the PyTorch function. """ mask = (input_ids != padding_idx) incremental_indices = (jnp.cumsum(mask, axis=1) + past_key_values_length) * mask return incremental_indices + padding_idx def _create_position_ids_from_inputs_embeds( self, inputs_embeds, past_key_values_length ): """ Generate sequential position ids from input embeddings. Args: inputs_embeds: jnp.array (JAX array) past_key_values_length: int Returns: jnp.array: Position IDs corresponding to the inputs. """ input_shape = inputs_embeds.shape[:-1] sequence_length = input_shape[1] position_ids = jnp.arange(self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=jnp.int64) return jnp.expand_dims(position_ids, axis=0).repeat(input_shape[0], axis=0) + past_key_values_length class FlaxIndicTransAttention(nn.Module): config: IndicTransConfig embed_dim: int num_heads: int dropout: float = 0.0 causal: bool = False bias: bool = True dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self) -> None: self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {self.num_heads})." ) # Not required in Flax Module as `dot_product_attention_weights` handles scaling internally. # For more details, check: https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html#dot_product_attention_weights # self.scaling dense = partial( nn.Dense, self.embed_dim, use_bias=self.bias, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.out_proj = dense() self.dropout_layer = nn.Dropout(rate=self.dropout) if self.causal: self.causal_mask = make_causal_mask( jnp.ones((1, self.config.max_source_positions), dtype="bool"), dtype="bool" ) def _split_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) @nn.compact def _concatenate_to_cache(self, key, value, query, attention_mask): """ This function takes projected key, value states from a single input token and concatenates the states to cached states from previous steps. This function is slighly adapted from the official Flax repository: https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 """ # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape # update key, value caches with our new 1d spatial slices cur_index = cache_index.value indices = (0,) * len(batch_dims) + (cur_index, 0, 0) key = lax.dynamic_update_slice(cached_key.value, key, indices) value = lax.dynamic_update_slice(cached_value.value, value, indices) cached_key.value = key cached_value.value = value num_updated_cache_vectors = query.shape[1] cache_index.value = cache_index.value + num_updated_cache_vectors # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), ) attention_mask = combine_masks(pad_mask, attention_mask) return key, value, attention_mask def __call__( self, hidden_states: jnp.ndarray, key_value_states: Optional[jnp.ndarray] = None, attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None batch_size = hidden_states.shape[0] # get query proj query_states = self.q_proj(hidden_states) # Scaling is handled internally by `dot_product_attention_weights`. # get key, value proj if is_cross_attention: # cross_attentions key_states = self.k_proj(key_value_states) value_states = self.v_proj(key_value_states) else: # self_attention key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = self._split_heads(query_states) key_states = self._split_heads(key_states) value_states = self._split_heads(value_states) # handle cache prepare causal attention mask if self.causal: query_length, key_length = query_states.shape[1], key_states.shape[1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_mask = lax.dynamic_slice( self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) ) else: causal_mask = self.causal_mask[:, :, :query_length, :key_length] causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) # combine masks if needed if attention_mask is not None and self.causal: attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) elif self.causal: attention_mask = causal_mask elif attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.dropout > 0.0: dropout_rng = self.make_rng("dropout") attn_weights = dot_product_attention_weights( query_states, key_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.dropout, broadcast_dropout=True, deterministic=deterministic, dtype=self.dtype, precision="high", ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class FlaxIndicTransEncoderLayer(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 def setup(self) -> None: self.embed_dim = self.config.encoder_embed_dim self.self_attn = FlaxIndicTransAttention( config=self.config, embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, dtype=self.dtype, ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.normalize_before = self.config.encoder_normalize_before def __call__( self, hidden_states: jnp.ndarray, attention_mask: jnp.ndarray, output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states if not self.normalize_before: hidden_states = self.attn_layer_norm(hidden_states) residual = hidden_states if self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states if not self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class FlaxIndicTransEncoderLayerCollection(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layers = [ FlaxIndicTransEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) ] self.layerdrop = self.config.encoder_layerdrop def __call__( self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = jax.random.normal(jax.random.PRNGKey(0), []) if not deterministic and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions, deterministic, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states += (hidden_states,) outputs = (hidden_states, all_hidden_states, all_attentions) if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions ) class FlaxIndicTransDecoderLayer(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 def setup(self) -> None: self.embed_dim = self.config.decoder_embed_dim self.self_attn = FlaxIndicTransAttention( config=self.config, embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn = FlaxIndicTransAttention( config=self.config, embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=False, dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.normalize_before = self.config.decoder_normalize_before def __call__( self, hidden_states: jnp.ndarray, attention_mask: jnp.ndarray, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states if not self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states if self.normalize_before: hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, # init_cache=init_cache ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states if not self.normalize_before: hidden_states = self.encoder_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states if self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states if not self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs class FlaxIndicTransDecoderLayerCollection(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layers = [ FlaxIndicTransDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) ] self.layerdrop = self.config.decoder_layerdrop def __call__( self, hidden_states, attention_mask, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = jax.random.normal(jax.random.PRNGKey(0), []) if not deterministic and (dropout_probability < self.layerdrop): layer_outputs = (None, None, None) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, init_cache=init_cache, output_attentions=output_attentions, deterministic=deterministic, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) class FlaxIndicTransEncoder(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dropout_layer = nn.Dropout(rate=self.config.dropout) embed_dim = self.config.encoder_embed_dim self.padding_idx = self.config.pad_token_id self.max_source_positions = self.config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 self.embed_tokens = nn.Embed( self.config.encoder_vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding( self.config.max_source_positions, embed_dim, self.padding_idx, ) self.layers = FlaxIndicTransEncoderLayerCollection(self.config, self.dtype) self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.encoder_normalize_before else None self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None def __call__( self, input_ids, attention_mask, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): input_shape = input_ids.shape input_ids = input_ids.reshape(-1, input_shape[-1]) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_ids, inputs_embeds) hidden_states = inputs_embeds + embed_pos if self.layernorm_embedding is not None: hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_states = outputs[0] if self.layer_norm is not None: last_hidden_states = self.layer_norm(last_hidden_states) # update the last element in `hidden_states` after applying `layernorm` above hidden_states = None if output_hidden_states: hidden_states = outputs[1] hidden_states = hidden_states[:-1] + (last_hidden_states,) if not return_dict: outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( last_hidden_state=last_hidden_states, hidden_states=hidden_states, attentions=outputs.attentions, ) class FlaxIndicTransDecoder(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.dropout_layer = nn.Dropout(rate=self.config.dropout) embed_dim = self.config.encoder_embed_dim self.padding_idx = self.config.pad_token_id self.max_target_positions = self.config.max_target_positions self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 self.embed_tokens = nn.Embed( self.config.decoder_vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding( self.config.max_target_positions, embed_dim, self.padding_idx, ) self.layers = FlaxIndicTransDecoderLayerCollection(self.config, self.dtype) self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.decoder_normalize_before else None self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None def __call__( self, input_ids, attention_mask, position_ids, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): input_shape = input_ids.shape input_ids = input_ids.reshape(-1, input_shape[-1]) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale # embed positions positions = self.embed_positions(input_ids, inputs_embeds) hidden_states = inputs_embeds + positions if self.layernorm_embedding is not None: hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_states = outputs[0] if self.layer_norm is not None: last_hidden_states = self.layer_norm(last_hidden_states) # update the last element in `hidden_states` after applying `layernorm` above hidden_states = None if output_hidden_states: hidden_states = outputs[1] hidden_states = hidden_states[:-1] + (last_hidden_states,) if not return_dict: outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=last_hidden_states, hidden_states=hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) class FlaxIndicTransModule(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.encoder = FlaxIndicTransEncoder(self.config, dtype=self.dtype) self.decoder = FlaxIndicTransDecoder(self.config, dtype=self.dtype) def _get_encoder_module(self): return self.encoder def _get_decoder_module(self): return self.decoder def __call__( self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) if not return_dict: return decoder_outputs + encoder_outputs return FlaxSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) class FlaxIndicTransPreTrainedModel(FlaxPreTrainedModel): config_class = IndicTransConfig base_model_prefix: str = "model" module_class: nn.Module = None def __init__( self, config: IndicTransConfig, input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) batch_size, sequence_length = input_ids.shape position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init( rngs, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, )["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params def init_cache(self, batch_size, max_length, encoder_outputs): # init input variables to retrieve cache decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") decoder_attention_mask = jnp.ones_like(decoder_input_ids) decoder_position_ids = jnp.broadcast_to( jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape ) def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs, ) init_variables = self.module.init( jax.random.PRNGKey(0), decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, decoder_position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], init_cache=True, method=_decoder_forward, # we only need to call the decoder to init the cache ) return unfreeze(init_variables["cache"]) def encode( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, position_ids: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if attention_mask is None: attention_mask = jnp.ones_like(input_ids) if position_ids is None: batch_size, sequence_length = input_ids.shape position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): encode_module = module._get_encoder_module() return encode_module(input_ids, attention_mask, position_ids, **kwargs) return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, method=_encoder_forward, ) def decode( self, decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, decoder_position_ids: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_hidden_states = encoder_outputs[0] if encoder_attention_mask is None: batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) if decoder_position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxMBartAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs, ) outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) # add updated cache to model output if past_key_values is not None and return_dict: outputs, past = outputs outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past = outputs outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, decoder_input_ids: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, position_ids: Optional[jnp.ndarray] = None, decoder_position_ids: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) if position_ids is None: batch_size, sequence_length = input_ids.shape position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) # prepare decoder inputs if decoder_input_ids is None: decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id, self.config.decoder_start_token_id) if decoder_attention_mask is None: decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), position_ids=jnp.array(position_ids, dtype="i4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, ) class FlaxIndicTransModel(FlaxIndicTransPreTrainedModel): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation module_class = FlaxIndicTransModule class FlaxIndicTransForConditionalGenerationModule(nn.Module): config: IndicTransConfig dtype: jnp.dtype = jnp.float32 bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros def setup(self): self.model = FlaxIndicTransModule(config=self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.decoder_vocab_size, use_bias=False, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) def _get_encoder_module(self): return self.model.encoder def _get_decoder_module(self): return self.model.decoder def __call__( self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, position_ids=position_ids, decoder_position_ids=decoder_position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) hidden_states = outputs[0] if self.config.share_decoder_input_output_embed: shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] lm_logits = jax.lax.stop_gradient(self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)) else: lm_logits = jax.lax.stop_gradient(self.lm_head(hidden_states)) if not return_dict: output = (lm_logits,) + outputs[1:] return output return FlaxSeq2SeqLMOutput( logits=lm_logits, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) class FlaxIndicTransForConditionalGeneration(FlaxIndicTransPreTrainedModel): module_class = FlaxIndicTransForConditionalGenerationModule dtype: jnp.dtype = jnp.float32 def decode( self, decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, decoder_position_ids: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_hidden_states = encoder_outputs[0] if encoder_attention_mask is None: batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) if decoder_position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxMBartAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): decoder_module = module._get_decoder_module() outputs = decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs, ) hidden_states = outputs[0] if self.config.share_decoder_input_output_embed: shared_embedding = module.model.variables["params"]["decoder"]["embed_tokens"]["embedding"] lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) else: lm_logits = module.lm_head(hidden_states) return lm_logits, outputs outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) if past_key_values is None: lm_logits, decoder_outputs = outputs else: (lm_logits, decoder_outputs), past = outputs if return_dict: outputs = FlaxCausalLMOutputWithCrossAttentions( logits=lm_logits, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, ) else: outputs = (lm_logits,) + decoder_outputs[1:] # add updated cache to model output if past_key_values is not None and return_dict: outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs def prepare_inputs_for_generation( self, decoder_input_ids, max_length, attention_mask: Optional[jax.Array] = None, decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): # initializing the cache batch_size, seq_length = decoder_input_ids.shape past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. # But since the decoder uses a causal mask, those positions are masked anyways. # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, "encoder_outputs": encoder_outputs, "encoder_attention_mask": attention_mask, "decoder_attention_mask": extended_attention_mask, "decoder_position_ids": position_ids, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 return model_kwargs