# lumenspark/modeling_lumenspark.py
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .configuration_lumenspark import LumensparkConfig
from torch import nn
import torch
import math
# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------
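# A full linear layer from `in_features` to `out_features` stores
# in_features * out_features weights; factoring it as V(U(x)) with inner dimension
# `rank` stores rank * (in_features + out_features) weights instead. For example
# (illustrative numbers, not this model's config): with 768 -> 768 and rank = 64,
# that is 64 * (768 + 768) = 98,304 weights versus 768 * 768 = 589,824, about 6x fewer.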
class LowRankLinear(nn.Module):
def __init__(self, in_features, out_features, rank, init_std=0.02):
super().__init__()
self.U = nn.Linear(in_features, rank, bias=False)
self.V = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(self.U.weight, std=init_std)
nn.init.normal_(self.V.weight, std=init_std)
def forward(self, x):
return self.V(self.U(x))
# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------
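# Standard multi-head scaled dot-product attention:
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(head_dim)) V
# computed per head, with an optional mask that sets disallowed positions to -inf
# before the softmax. The softmax is implemented manually (max subtraction plus a
# small epsilon in the denominator) for numerical stability.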
class LumensparkSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
super().__init__()
assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'
self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.dropout_layer = nn.Dropout(dropout)
self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)
def stable_softmax(self, x, dim=-1):
x_max = torch.max(x, dim=dim, keepdim=True)[0]
exp_x = torch.exp(x - x_max)
return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)
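    # Shape conventions in forward():
    #   inputs            -> (batch, seq_len, embed_dim)
    #   q, k, v           -> (batch, num_heads, seq_len, head_dim)
    #   attention_scores  -> (batch, num_heads, seq_len, seq_len)
    #   returned output   -> (batch, seq_len, embed_dim)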
def forward(self, inputs, attention_mask=None):
batch_size, seq_len, _ = inputs.shape
q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
attention_weights = self.stable_softmax(attention_scores, dim=-1)
attention_weights = self.dropout_layer(attention_weights)
attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
return self.output_transform(attention_output)
# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------
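# Each block applies pre-normalization with LayerScale-weighted residual updates:
#   x = x + layer_scale_attn * Attention(LayerNorm(x))
#   x = x + layer_scale_ffn  * FFN(LayerNorm(x))
# where FFN is two LowRankLinear projections (embed_dim -> 4 * embed_dim -> embed_dim)
# with GELU and dropout in between, and both LayerScale vectors are initialized to 1e-2.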
class LumensparkModel(PreTrainedModel):
config_class = LumensparkConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# Token and position embeddings
self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)
# Lumenspark transformer encoder layers with prenormalization and LayerScale
self.layers = nn.ModuleList()
for _ in range(config.depth):
layer = nn.ModuleDict({
"norm1": nn.LayerNorm(config.embed_dim),
"attn": LumensparkSelfAttention(
embed_dim=config.embed_dim,
num_heads=config.heads,
head_dim=config.embed_dim // config.heads,
dropout=config.dropout
),
"norm2": nn.LayerNorm(config.embed_dim),
"ffn": nn.Sequential(
LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
nn.GELU(),
nn.Dropout(config.dropout),
LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
nn.Dropout(config.dropout)
),
})
# Assign the parameters directly as attributes
layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
self.layers.append(layer)
# Final LayerNorm layer
self.final_norm = nn.LayerNorm(config.embed_dim)
# Feed-forward output layer
self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
self.dropout = nn.Dropout(config.dropout)
# Initialize model weights
self.init_weights()
@staticmethod
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
"""
Filter a distribution of logits using top-k and/or top-p filtering.
"""
top_k = min(top_k, logits.size(-1))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
            # Scatter the removal mask back to the original (unsorted) token order
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
return logits
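    # Example of the filtering above (illustrative values): with logits [[2.0, 1.0, 0.5]]
    # and top_k=2, the 0.5 entry is set to -inf; with top_p, the lowest-probability tail
    # whose cumulative mass exceeds top_p is removed, but the single most likely token
    # is always kept.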
def generate(self, input_ids, attention_mask=None, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
"""
Text generation method that handles auto-regressive generation with repetition penalty.
Input `input_ids` should be a tensor. Returns generated tokens.
"""
self.eval()
device = input_ids.device
generated_tokens = input_ids
for _ in range(max_length - input_ids.size(1)):
# Forward pass for logits
outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
logits = outputs["logits"][:, -1, :]
# Adjust logits by temperature
logits = logits / temperature
            # Apply repetition penalty: discourage tokens already generated (divide positive
            # logits, multiply negative ones, so the penalty always lowers their probability)
            for token in set(generated_tokens.view(-1).tolist()):
                token_logits = logits[:, token]
                logits[:, token] = torch.where(token_logits < 0, token_logits * repetition_penalty, token_logits / repetition_penalty)
# Apply sampling with top-k and top-p
if do_sample:
filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = torch.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
# Append the generated token
generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
attention_mask = torch.ones_like(generated_tokens).to(device)
            # Stop at the end-of-sequence (EOS) token, but only once min_length has been reached
            if next_token.item() == self.config.eos_token_id and generated_tokens.size(1) >= min_length:
                break
return generated_tokens
def forward(self, input_ids, attention_mask=None, labels=None):
"""
Forward pass of the model. If `labels` are provided, computes the loss.
"""
batch_size, seq_length = input_ids.size()
# Generate position ids for input tokens
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
# Embed tokens and positions
token_embeddings = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
# Combine token and position embeddings
embeddings = token_embeddings + position_embeddings
embeddings = self.dropout(embeddings)
# Create causal mask for self-attention to ensure autoregressive behavior
device = embeddings.device
causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)
# Combine with attention mask if provided
combined_mask = causal_mask if attention_mask is None else attention_mask[:, None, None, :].float() * causal_mask
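        # The causal mask broadcasts as (1, 1, seq_len, seq_len); a provided padding mask is
        # reshaped to (batch, 1, 1, seq_len), so their product hides both future positions
        # and padded tokens for every attention head.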
# Pass through transformer layers
for layer in self.layers:
embeddings_norm = layer["norm1"](embeddings)
attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
embeddings = embeddings + layer.layer_scale_attn * attn_output
embeddings_norm = layer["norm2"](embeddings)
ffn_output = layer["ffn"](embeddings_norm)
embeddings = embeddings + layer.layer_scale_ffn * ffn_output
# Final normalization and output projection to logits
embeddings = self.final_norm(embeddings)
logits = self.fc_out(embeddings)
# Compute loss if labels are provided
loss = None
if labels is not None:
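            # Standard causal-LM shift: the logit at position i is scored against token i + 1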
shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
shift_labels = labels[:, 1:].contiguous().view(-1)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
return {"loss": loss, "logits": logits}
# Register LumensparkConfig and LumensparkModel with the Auto classes
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
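# Example usage (a minimal sketch; the checkpoint name "anto18671/lumenspark" and the
# availability of a matching tokenizer are assumptions, and loading custom code from the
# Hub may additionally require trust_remote_code=True):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("anto18671/lumenspark")
#   model = AutoModelForCausalLM.from_pretrained("anto18671/lumenspark")
#   inputs = tokenizer("Once upon a time", return_tensors="pt")
#   output_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=80)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))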