""" PyTorch ProteinGLM model. """

import copy
import math
import os
import pathlib
import random
import re
import sys
import time
import warnings
from collections import namedtuple
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn.utils import skip_init
from tqdm.auto import tqdm
from transformers import PreTrainedModel
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import (
    GenerationConfig,
    LogitsProcessorList,
    ModelOutput,
    StoppingCriteriaList,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    MaskedLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.utils import logging

try:
    import deepspeed
except ImportError:
    # deepspeed is only used below to select its activation-checkpointing
    # backend; without it we fall back to torch.utils.checkpoint.
    deepspeed = None

from .configuration_proteinglm import ProteinGLMConfig
from .quantization import quantize


def get_checkpoint_fn():
    """Return the activation-checkpointing function to use.

    DeepSpeed's ``checkpointing.checkpoint`` is returned when DeepSpeed is
    installed and its checkpointing has been configured; otherwise
    ``torch.utils.checkpoint.checkpoint`` is returned.
    """
    if deepspeed is not None and deepspeed.checkpointing.is_configured():
        return deepspeed.checkpointing.checkpoint
    return torch.utils.checkpoint.checkpoint


# Flags required to enable JIT fusion kernels (skipped on macOS).
if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "proteinglm-7b-clm"
_CONFIG_FOR_DOC = "ProteinGLMConfig"

# (alpha, beta) pair used by DeepNorm residual scaling.
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def default_init(cls, *args, **kwargs):
    """Construct ``cls(*args, **kwargs)`` normally.

    This is the eager counterpart of ``torch.nn.utils.skip_init`` and is used
    when parameter initialisation should not be skipped.
    """
    return cls(*args, **kwargs)


def get_deepnorm_coefficients(config: ProteinGLMConfig):
    """
    DeepNorm coefficients from : https://kexue.fm/archives/8978

    Returns ``alpha = (2 * num_layers) ** 0.5`` and
    ``beta = (2 * num_layers) ** -0.5`` for ``config.num_layers``.
    """
    num_layers = config.num_layers
    return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5, beta=(2 * num_layers) ** -0.5)
class InvalidScoreLogitsProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() scores[..., 5] = 5e4 return scores def split_tensor_along_last_dim( tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, ) -> List[torch.Tensor]: """Split a tensor along its last dimension. Arguments: tensor: input tensor. num_partitions: number of partitions to split the tensor contiguous_split_chunks: If True, make each chunk contiguous in memory. Returns: A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 last_dim_size = tensor.size()[last_dim] // num_partitions # Split. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) # Note: torch.split does not create contiguous tensors by default. if contiguous_split_chunks: return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000, precision=torch.half, learnable=False): super().__init__() inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision) self.dim = dim self.base = base self.learnable = learnable if learnable: self.inv_freq = torch.nn.Parameter(inv_freq) self.max_seq_len_cached = None else: self.register_buffer('inv_freq', inv_freq) self.max_seq_len_cached = None self.cos_cached = None self.sin_cached = None self.precision = precision def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if f'{prefix}inv_freq' in state_dict: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) else: self.inv_freq.copy_(1. 
/ (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(self.precision)) def forward(self, x, seq_dim=1, seq_len=None): if seq_len is None: seq_len = x.shape[seq_dim] if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else seq_len t = torch.arange(seq_len, device=x.device, dtype=torch.float32) freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(x.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16 or self.precision == torch.half: emb = emb.float() # [sx, 1 (b * np), hn] cos_cached = emb.cos()[:, None, :] sin_cached = emb.sin()[:, None, :] if self.precision == torch.bfloat16: cos_cached = cos_cached.bfloat16() sin_cached = sin_cached.bfloat16() elif self.precision == torch.half: cos_cached = cos_cached.half() sin_cached = sin_cached.half() if self.learnable: return cos_cached, sin_cached self.cos_cached, self.sin_cached = cos_cached, sin_cached return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions def assert_dim_check(tensor, ndim=None, shape=None): if ndim is not None: assert tensor.ndim == ndim, f"Exepct tensor.ndim={ndim}. gut got tensor.shape={tensor.shape}" if shape is not None: assert list(tensor.shape) == list(shape), f"Exepct tensor.shape={shape}. 
gut got tensor.shape={tensor.shape}" def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id): # jitting fails with bf16 # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) return q, k class RMSNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) self.eps = eps def forward(self, hidden_states: torch.Tensor): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) return (self.weight * hidden_states).to(input_dtype) class CoreAttention(torch.nn.Module): def __init__(self, config: ProteinGLMConfig, layer_number): super(CoreAttention, self).__init__() self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True self.layer_number = max(1, layer_number) projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. 
self.hidden_size_per_partition = projection_size self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.coeff = coeff self.attention_dropout = torch.nn.Dropout(config.attention_dropout) self.is_causal = config.is_causal self.use_pytorch_sdpa = config.use_pytorch_sdpa def forward(self, query_layer, key_layer, value_layer, attention_mask): # query_layer, key_layer, value_layer: [seq_len, batch_size, num_heads, head_dim] # import pdb; pdb.set_trace(); pytorch_major_version = int(torch.__version__.split('.')[0]) # assert pytorch_major_version >= 2, f"Expect PyTorch version > 2.0" if pytorch_major_version >= 2 and self.use_pytorch_sdpa: dropout_p = self.attention_dropout.p if self.training else 0 # [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim] query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] # import pdb; pdb.set_trace(); if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: # context_layer: [batch_size, num_heads, seq_len, head_dim] context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=self.is_causal, dropout_p=dropout_p) else: if (attention_mask is not None) and (attention_mask.dtype == torch.bool): attention_mask = attention_mask.logical_not() ## DO NOT inplace operation!!!! 
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, dropout_p=dropout_p) # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim] context_layer = context_layer.permute(2, 0, 1, 3) # [seq_len, batch_size, 2560] new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) else: # Raw attention scores # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, device=query_layer.device ) # Raw attention scores. 
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] if self.attention_softmax_in_fp32: attention_scores = attention_scores.float() if self.coeff is not None: attention_scores = attention_scores * self.coeff if self.is_causal and attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool) attention_mask.tril_() attention_mask = ~attention_mask if attention_mask is not None: attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) attention_probs = F.softmax(attention_scores, dim=-1) attention_probs = attention_probs.type_as(value_layer) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. 
# [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer class SelfAttention(torch.nn.Module): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] and returns output of the same size. """ def __init__(self, config: ProteinGLMConfig, layer_number, device=None): super(SelfAttention, self).__init__() self.layer_number = max(1, layer_number) self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. 
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads self.multi_query_attention = config.multi_query_attention self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = ( self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num ) self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, bias=config.add_bias_linear or config.add_qkv_bias, device=device, **_config_to_kwargs(config) ) self.core_attention = CoreAttention(config, self.layer_number) # Output. self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config)) self.rotary_embedding_2d = config.rotary_embedding_2d # dim, base=10000, precision=torch.half, learnable=False self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head // 2 if self.rotary_embedding_2d else self.hidden_size_per_attention_head, base=10000, precision=config.torch_dtype, learnable=False) def forward( self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True ): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. 
# ================================================= # ===================== # Query, Key, and Value # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, ], dim=-1, ) query_layer = query_layer.view( query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) ) key_layer = key_layer.view( key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) ) value_layer = value_layer.view( value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) # apply relative positional encoding (rotary embedding) if position_ids is not None: # [seq_len, 2, batch_size, 32, 2] if self.rotary_embedding_2d: q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) # 32 k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) # import pdb; pdb.set_trace(); cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) # 32 position_ids, block_position_ids = \ position_ids[:, 0, :].transpose(0, 1).contiguous(), \ position_ids[:, 1, :].transpose(0, 1).contiguous() q1, k1 = apply_rotary_pos_emb_index_torch(q1, k1, cos, sin, position_ids) q2, k2 = apply_rotary_pos_emb_index_torch(q2, k2, cos, sin, block_position_ids) query_layer = 
torch.concat([q1, q2], dim=(q1.ndim - 1)) key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) else: # [b, sq] -> [sq, b] position_ids = position_ids.transpose(0, 1) cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) query_layer, key_layer = apply_rotary_pos_emb_index_torch(query_layer, key_layer, cos, sin, position_ids) # adjust key and value for inference if kv_cache is not None: cache_k, cache_v = kv_cache key_layer = torch.cat((cache_k, key_layer), dim=0) value_layer = torch.cat((cache_v, value_layer), dim=0) if use_cache: kv_cache = (key_layer, value_layer) else: kv_cache = None if self.multi_query_attention: key_layer = key_layer.unsqueeze(-2) key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)) # ================================== # core attention computation # ================================== context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # context_layer: [seq_len, batch_size, num_heads*head_dim] output = self.dense(context_layer) # ================= # Output. [sq, b, h] # ================= # output = context_layer @ self.dense.weight.T + self.dense.bias return output, kv_cache def _config_to_kwargs(args): common_kwargs = { "dtype": args.torch_dtype, } return common_kwargs class MLP(torch.nn.Module): """MLP. 
MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. """ def __init__(self, config: ProteinGLMConfig, device=None): super(MLP, self).__init__() self.add_bias = config.add_bias_linear self.moe = config.moe self.num_experts = config.num_experts self.experts_per_token = config.experts_per_token # 2 # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf self.dense_h_to_4h = nn.Linear( config.hidden_size, config.ffn_hidden_size * 2, bias=self.add_bias, device=device, **_config_to_kwargs(config) ) def swiglu(x): x = torch.chunk(x, 2, dim=-1) return x[0] * F.silu(x[1]) def geglu(x): x = torch.chunk(x, 2, dim=-1) return x[0] * F.gelu(x[1]) if config.glu_activation == 'geglu': self.activation_func = geglu elif config.glu_activation == 'swiglu': self.activation_func = swiglu else: assert RuntimeError(f"Unsupported glu_activation: {config.glu_activation}") # Project back to h. 
self.dense_4h_to_h = nn.Linear( config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config) ) if self.moe: assert self.num_experts > 1 del self.dense_h_to_4h del self.dense_4h_to_h self.router = nn.Linear( config.hidden_size, config.num_experts, bias=False, device=device, dtype=torch.float32 ) for i in range(0, self.num_experts): self.register_module(f"dense_h_to_4h_{i}", nn.Linear( config.hidden_size, config.ffn_hidden_size * 2, bias=self.add_bias, device=device, **_config_to_kwargs(config) )) self.register_module(f"dense_4h_to_h_{i}", nn.Linear( config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config) )) def moe_forward(self, hidden_states, expert_idx): intermediate_parallel = getattr(self, f"dense_h_to_4h_{expert_idx}")(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) output = getattr(self, f"dense_4h_to_h_{expert_idx}")(intermediate_parallel) return output def forward(self, hidden_states): if self.moe: # import pdb; pdb.set_trace(); s, b, n = hidden_states.shape dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h] route = self.router(hidden_states).to(dtype) weights, selected_experts = torch.topk(route, self.experts_per_token) weights = F.softmax(weights, dim=1, dtype=torch.float).to(hidden_states.dtype) output = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) for expert_idx in range(self.num_experts): batch_idx, nth_expert = torch.where(selected_experts == expert_idx) if nth_expert.shape[0] == 0: continue cur_out = self.moe_forward(hidden_states[batch_idx], expert_idx) output[batch_idx] += weights[batch_idx, nth_expert, None] * cur_out output = output.reshape(s, b, n) else: # [s, b, 4hp] intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) # [s, b, h] output = 
self.dense_4h_to_h(intermediate_parallel) return output class ProteinGLMBlock(torch.nn.Module): """A single transformer layer. Transformer layer takes input with size [s, b, h] and returns an output of the same size. """ def __init__(self, config: ProteinGLMConfig, layer_number, device=None): super(ProteinGLMBlock, self).__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.fp32_residual_connection = config.fp32_residual_connection LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon) # Self attention. self.self_attention = SelfAttention(config, layer_number, device=device) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon) # MLP self.mlp = MLP(config, device=device) self.deepnorm_coeff = get_deepnorm_coefficients(config) if config.deepnorm else None def forward( self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True, ): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, attention_mask, position_ids, # [batch_size, 2, seq_len, 32, 2] kv_cache=kv_cache, use_cache=use_cache ) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) if self.deepnorm_coeff is not None: layernorm_input = residual*self.deepnorm_coeff.alpha + layernorm_input else: layernorm_input = residual + layernorm_input # Layer norm post the self attention. 
layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) if self.deepnorm_coeff is not None: output = residual*self.deepnorm_coeff.alpha + output else: #print(f"2 self.deepnorm_coeff is None") output = residual + output return output, kv_cache class ProteinGLMTransformer(torch.nn.Module): """Transformer class.""" def __init__(self, config: ProteinGLMConfig, device=None): super(ProteinGLMTransformer, self).__init__() self.fp32_residual_connection = config.fp32_residual_connection self.post_layer_norm = config.post_layer_norm # Number of layers. self.num_layers = config.num_layers # Transformer layers. def build_layer(layer_number): return ProteinGLMBlock(config, layer_number, device=device) self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon) self.gradient_checkpointing = False def _get_layer(self, layer_number): return self.layers[layer_number] def forward( self, hidden_states, attention_mask, position_ids, kv_caches=None, use_cache: Optional[bool] = True, output_hidden_states: Optional[bool] = False, ): if not kv_caches: kv_caches = [None for _ in range(self.num_layers)] presents = () if use_cache else None if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False all_self_attentions = None all_hidden_states = () if output_hidden_states else None for index in range(self.num_layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer = self._get_layer(index) if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): layer_ret = get_checkpoint_fn()( layer, hidden_states, attention_mask, position_ids, kv_caches[index], use_cache ) else: layer_ret = layer( hidden_states, attention_mask, position_ids, kv_cache=kv_caches[index], use_cache=use_cache ) hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) # Final layer norm. if self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) return hidden_states, presents, all_hidden_states, all_self_attentions class ProteinGLMPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" is_parallelizable = False supports_gradient_checkpointing = True config_class = ProteinGLMConfig base_model_prefix = "transformer" _no_split_modules = ["ProteinGLMBlock"] _quantized = False def get_masks(self, input_ids, past_key_values, padding_mask=None, is_causal=True): batch_size, seq_length = input_ids.shape full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) if is_causal: full_attention_mask.tril_() past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[0] if past_length: full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1) if padding_mask is not None: full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) if not past_length and padding_mask is not None: full_attention_mask -= padding_mask.unsqueeze(-1) - 1 full_attention_mask = (full_attention_mask < 0.5).bool() full_attention_mask.unsqueeze_(1) return full_attention_mask def get_position_ids(self, input_ids, device, context_length=0): batch_size, seq_length = input_ids.shape if self.config.rotary_embedding_2d: if self.config.is_causal: # 100b model position_ids_1 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] position_ids_2 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len] else: position_ids_1 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] position_ids_2 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len] else: position_ids = torch.arange(seq_length, dtype=torch.long, 
device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, 1, seq_len] return position_ids def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, ProteinGLMTransformer): module.gradient_checkpointing = value # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): std = self.config.initializer_range """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def quantize(self, weight_bit_width: int, empty_init=True, device=None): if self._quantized: print(f"Model has been quantized...") return self.transformer.encoder = quantize(self.transformer.encoder, weight_bit_width, empty_init, device) self._quantized = True return self class Embedding(torch.nn.Module): """Language model embeddings.""" def __init__(self, config: ProteinGLMConfig, device=None): super(Embedding, self).__init__() self.hidden_size = config.hidden_size # Word embeddings (parallel). self.word_embeddings = nn.Embedding( config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device ) self.fp32_residual_connection = config.fp32_residual_connection def forward(self, input_ids): # Embeddings. words_embeddings = self.word_embeddings(input_ids) embeddings = words_embeddings # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. embeddings = embeddings.transpose(0, 1).contiguous() # If the input flag for fp32 residual connection is set, convert for float. 
if self.fp32_residual_connection: embeddings = embeddings.float() return embeddings class ProteinGLMModel(ProteinGLMPreTrainedModel): def __init__(self, config: ProteinGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: init_method = skip_init else: init_method = default_init init_kwargs = {} if device is not None: init_kwargs["device"] = device self.embedding = init_method(Embedding, config, **init_kwargs) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels # Rotary positional embeddings self.seq_length = config.seq_length rotary_dim = ( config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels ) # self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, base=10000, precision=config.torch_dtype, learnable=False) self.encoder = init_method(ProteinGLMTransformer, config, **init_kwargs) self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, dtype=config.torch_dtype, **init_kwargs) def get_input_embeddings(self): return self.embedding.word_embeddings def set_input_embeddings(self, value): self.embedding.word_embeddings = value def forward( self, input_ids, position_ids: Optional[torch.Tensor] = None, # position_ids: [batch_size, 2, seq_len] attention_mask: Optional[torch.BoolTensor] = None, full_attention_mask: Optional[torch.BoolTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if self.config.is_causal: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False return_dict = return_dict if return_dict is 
not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Run encoder. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, position_ids=position_ids, kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states ) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class ProteinGLMForMaskedLM(ProteinGLMPreTrainedModel): def __init__(self, config: ProteinGLMConfig, empty_init=True, device=None): super().__init__(config) self.max_sequence_length = config.max_length self.transformer = ProteinGLMModel(config, empty_init=empty_init, device=device) self.config = config if self.config.quantization_bit: print(f"Begin Quantization to {self.config.quantization_bit} bit") self.quantize(self.config.quantization_bit, empty_init=True, device=device) def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_last_logit: Optional[bool] = None, return_last_hidden_state: Optional[bool] = None ): if self.config.is_causal: use_cache = use_cache if 
use_cache is not None else self.config.use_cache else: use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal) transformer_outputs = self.transformer( input_ids=input_ids, position_ids=position_ids, # position_ids: [batch_size, 2, seq_len] full_attention_mask=full_attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] if return_last_logit: hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() masked_lm_loss = None if labels is not None: lm_logits = lm_logits.to(torch.float32) # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-100) # -100 for padding token. 
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) lm_logits = lm_logits.to(hidden_states.dtype) loss = loss.to(hidden_states.dtype) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return MaskedLMOutput( loss = masked_lm_loss, logits=lm_logits, hidden_states=transformer_outputs.last_hidden_state if return_last_hidden_state else transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) class ProteinGLMForSequenceClassification(ProteinGLMPreTrainedModel): def __init__(self, config: ProteinGLMConfig, empty_init=True, device=None): super().__init__(config) self.config = config self.num_labels = config.num_labels self.transformer = ProteinGLMModel(config, empty_init=empty_init, device=device) self.classifier = ProteinGLMClassificationHead(config) if self.config.quantization_bit: print(f"Begin Quantization to {self.config.quantization_bit} bit") self.quantize(self.config.quantization_bit, empty_init=True, device=device) def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_last_logit: Optional[bool] = None, return_last_hidden_state: Optional[bool] = None, **kwargs ) -> Union[Tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ if self.config.is_causal: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal) transformer_outputs = self.transformer( input_ids=input_ids, position_ids=position_ids, # position_ids: [batch_size, 2, seq_len] full_attention_mask=full_attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.config.add_special_tokens: hidden_states = transformer_outputs[0][:-1] # get rid of token else: hidden_states = transformer_outputs[0] logits = self.classifier(hidden_states, add_pooling=True) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = 
(logits,) + transformer_outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) class ProteinGLMForTokenClassification(ProteinGLMPreTrainedModel): def __init__(self, config: ProteinGLMConfig, empty_init=True, device=None): super().__init__(config) self.config = config self.num_labels = config.num_labels self.transformer = ProteinGLMModel(config, empty_init=empty_init, device=device) if config.task_modality == "token": self.classifier = ProteinGLMClassificationHead(config) elif config.task_modality == 'pair': self.classifier = ProteinGLMContactHead(config) self.quantized = False if self.config.quantization_bit: print(f"Begin Quantization to {self.config.quantization_bit} bit") self.quantize(self.config.quantization_bit, empty_init=True, device=device) def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_last_logit: Optional[bool] = None, return_last_hidden_state: Optional[bool] = None, **kwargs ) -> Union[Tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" if self.config.is_causal: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal = self.config.is_causal) transformer_outputs = self.transformer( input_ids=input_ids, position_ids=position_ids, # position_ids: [batch_size, 2, seq_len] full_attention_mask=full_attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.config.add_special_tokens: hidden_states = transformer_outputs[0][:-1] # get rid of token else: hidden_states = transformer_outputs[0] logits = self.classifier(hidden_states, add_pooling=False) loss = None if labels is not None: labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + transformer_outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) class ProteinGLMClassificationHead(nn.Module): """Head for classification tasks.""" def __init__(self, config): super().__init__() self.activation_func = config.activation_func self.layers = torch.nn.ModuleList() last_size = config.hidden_size for sz in config.inter_hidden_size: this_layer = torch.nn.Linear(last_size, sz, bias=config.bias) last_size = sz self.layers.append(this_layer) def forward(self, input_features, add_pooling: Optional[bool] = True ): # [s, b, h] -> [b, s ,h] input_features = input_features.transpose(0,1).contiguous() if add_pooling: # [b, h] input_features = 
torch.mean(input_features, dim = 1) for i, layer in enumerate(self.layers): if i > 0: input_features = self.activation_func(input_features) input_features = layer(input_features) return input_features class ProteinGLMContactHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, config): super().__init__() self.activation_func = config.activation_func self.layers = torch.nn.ModuleList() last_size = config.hidden_size * 2 for sz in config.inter_hidden_size: this_layer = torch.nn.Linear(last_size, sz, bias=config.bias) last_size = sz self.layers.append(this_layer) def outer_concat(self, x): batch_size, seq_len, features = x.shape # Permute to [batch_size, features, seq_len] x = x.permute(0, 2, 1) # Introduce new dimensions for broadcasting x_1 = x[:, None, :, :, None] # [batch_size, 1, features, seq_len, 1] x_2 = x[:, None, :, None, :] # [batch_size, 1, features, 1, seq_len] # Repeat along new dimensions x_1 = x_1.repeat(1, 1, 1, 1, seq_len) # [batch_size, 1, features, seq_len, seq_len] x_2 = x_2.repeat(1, 1, 1, seq_len, 1) # [batch_size, 1, features, seq_len, seq_len] # Concatenate along the second dimension x = torch.cat((x_1, x_2), dim=1) # [batch_size, 2, features, seq_len, seq_len] # Get lower triangular indices I, J = torch.tril_indices(seq_len, seq_len, -1) # Symmetrize x[:, :, :, I, J] = x[:, :, :, J, I] # Permute to desired shape and make contiguous x = x.permute(0, 3, 4, 2, 1).contiguous() # [batch_size, seq_len, seq_len, features, 2] # Reshape to combine the last two dimensions x = x.view(batch_size, seq_len, seq_len, features * 2) # [batch_size, seq_len, seq_len, features * 2] return x def forward(self, input_features, add_pooling: Optional[bool] = True ): # [s, b, h] -> [b, s ,h] input_features = input_features.transpose(0,1).contiguous() input_features = self.outer_concat(input_features) for i, layer in enumerate(self.layers): if i > 0: input_features = self.activation_func(input_features) input_features = 
layer(input_features) return input_features class ProteinGLMForCasualLM(ProteinGLMPreTrainedModel): def __init__(self, config: ProteinGLMConfig, empty_init=True, device=None): super().__init__(config) self.max_sequence_length = config.max_length self.transformer = ProteinGLMModel(config, empty_init=empty_init, device=device) self.config = config if self.config.quantization_bit: print(f"Begin Quantization to {self.config.quantization_bit} bit") self.quantize(self.config.quantization_bit, empty_init=True, device=device) def _update_model_kwargs_for_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, ) -> Dict[str, Any]: # update past_key_values cache_name, cache = self._extract_past_from_model_output(outputs) model_kwargs[cache_name] = cache # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) # update position ids if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] new_position_id = position_ids[..., -1:].clone() # [batch_size, 2, 1] if self.config.rotary_embedding_2d: new_position_id[:, 1] += 1 # Only update the 2nd dimension else: new_position_id[:] += 1 model_kwargs["position_ids"] = torch.cat( [position_ids, new_position_id], dim=-1 ) # [batch_size, 2, seq_len+1] model_kwargs["is_first_forward"] = False return model_kwargs def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, is_first_forward: bool = True, **kwargs ) -> dict: # only last token for input_ids if past is not None if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) # position_ids: [batch_size, 2, seq_len] if not 
is_first_forward: if past_key_values is not None: position_ids = position_ids[..., -1:] input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "past_key_values": past_key_values, "position_ids": position_ids, "attention_mask": attention_mask, "return_last_logit": True, "use_cache": use_cache } def forward( self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_last_logit: Optional[bool] = False ): if self.config.is_causal: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) transformer_outputs = self.transformer( input_ids=input_ids, position_ids=position_ids, # position_ids: [batch_size, 2, seq_len] attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = transformer_outputs[0] if return_last_logit: hidden_states = hidden_states[-1:] lm_logits = self.transformer.output_layer(hidden_states) lm_logits = lm_logits.transpose(0, 1).contiguous() loss = None if labels is not None: lm_logits = lm_logits.to(torch.float32) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) lm_logits = 
lm_logits.to(hidden_states.dtype) loss = loss.to(hidden_states.dtype) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @staticmethod def _reorder_cache( past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. Output shares the same memory storage as `past`. """ return tuple( ( layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), ) for layer_past in past ) @torch.inference_mode() def chat(self, tokenizer, query: str, max_length: int = 256, num_beams=1, do_sample=True, top_p=1.0, temperature=1.0, logits_processor=None, **kwargs): if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, "temperature": temperature, "logits_processor": logits_processor, **kwargs} inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True) position_ids = self.get_position_ids(inputs['input_ids'], device=self.device) # TODO: ADD BATCH eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("")] inputs["position_ids"] = position_ids inputs = inputs.to(self.device) outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) outputs = 
outputs.tolist()[0][3:] # 3 for generation prompt "" if outputs[-1] in eos_token_id: outputs = outputs[:-1] response = tokenizer.decode(outputs) return response # TODO: fix bug in streaming chat @torch.inference_mode() def stream_chat(self, tokenizer, query: str, max_length: int = 56, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, past_key_values = None, **kwargs): if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("")] gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, "temperature": temperature, "logits_processor": logits_processor, **kwargs} inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True) position_ids = self.get_position_ids(inputs['input_ids'], device=self.device) # TODO: ADD BATCH eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("")] inputs["position_ids"] = position_ids inputs = inputs.to(self.device) offset = 3 # 3 for generation prompt for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, eos_token_id=eos_token_id, return_past_key_values=False, **gen_kwargs): outputs = outputs.tolist()[0][3:] if outputs[-1] in eos_token_id: outputs = outputs[:-1] # offset = 3 + len(outputs) response = tokenizer.decode(outputs) if response: yield response @torch.inference_mode() def stream_generate( self, input_ids, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, return_past_key_values=False, **kwargs, ): breakpoint() batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] if generation_config is None: generation_config = 
self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) model_kwargs["use_cache"] = generation_config.use_cache bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: logger.warn( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", UserWarning, ) if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`." ) # 2. 
Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, ) stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) logits_warper = self._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) scores = None while True: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) if generation_config.do_sample: next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(probs, dim=-1) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) if return_past_key_values: yield input_ids, outputs.past_key_values else: yield input_ids # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 
or stopping_criteria(input_ids, scores): break