# hertz-dev / model.py (commit 824afbf)
from typing import Optional, Tuple
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from ioblocks import GaussianMixtureIOLayer, FSQ
from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
from tokenizer import make_tokenizer
from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
from utils import load_ckpt
@si_module
class LatentQuantizer(nn.Module):
    """Map latents through an optional input projection and an FFNN, then quantize them.

    Submodules are built from the config. When ``from_pretrained`` is set, a
    checkpoint is fetched with ``load_ckpt`` and loaded strictly once every
    submodule exists.
    """

    class Config:
        compressor_config: Optional[FSQ.Config] = None
        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        input_dim: int = None  # None -> ``input`` is an nn.Identity
        from_pretrained: Optional[Tuple[str, str]] = None  # arguments forwarded to load_ckpt

    def __init__(self, c: Config):
        super().__init__()
        pretrained = exists(c.from_pretrained)
        if pretrained:
            # Fetch weights first; they are applied after the submodules are created.
            state = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.compressor_config), f'hmm {c}'

        # Registration order (compressor, ffnn, input) matches the checkpoint layout.
        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        if exists(c.input_dim):
            self.input = nn.Linear(c.input_dim, c.dim)
        else:
            self.input = nn.Identity()

        if pretrained:
            self.load_state_dict(state)

    @T.no_grad()
    def forward(self, x, return_latent=False, known_latent=None):
        """Quantize ``x``, or decode already-known indices.

        x: (B, S, D); ignored when ``known_latent`` is given.
        return_latent: when True, also return the compressor's token output.
        known_latent: indices passed straight to ``compressor.indices_to_codes``.
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)
        codes, tokens = self.compressor(self.ffnn(self.input(x)))
        return (codes, tokens) if return_latent else codes
@si_module
class TransformerVAE(nn.Module):
    """Resynthesizer: a transformer stack with quantized latents injected mid-stack.

    Layers before ``plex_layer`` form a shared trunk. At ``plex_layer`` the
    quantized (and rolled) input is projected and added to the hidden state. In
    split mode, the remaining layers are then run separately on two streams, and
    each stream gets its own KV-cache slots.
    """

    class Config:
        io_config: Optional[GaussianMixtureIOLayer.Config] = None
        stack_config: Optional[Stack.Config] = None
        quantizer_config: Optional[LatentQuantizer.Config] = None
        plex_layer: int = None  # injection layer; None -> stack_config.layers // 2
        plex_roll: int = 1
        split: bool = True
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
        self.io = c.io_config()
        self.stack = c.stack_config()

        # Honour the configured injection layer; the historical ``layers // 2``
        # remains the default, so existing configs (plex_layer=None) are unchanged.
        self.plex_layer = default(c.plex_layer, c.stack_config.layers // 2)
        # If plex_layer were never reached, forward would never project the
        # quantized codes (and in split mode x1/x2 would still be raw inputs).
        assert 0 <= self.plex_layer < c.stack_config.layers, f'plex_layer {self.plex_layer} must be in [0, {c.stack_config.layers})'
        self.plex_roll = c.plex_roll
        self.plex_dim = c.quantizer_config.dim
        assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
        self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
        self.out_norm = Norm(c.stack_config.dim)

        if c.split:
            # Second stream: its own input layer and plex projection. Its
            # mixture-output projections are dropped, and forward uses
            # self.io.output for both streams.
            self.io2 = c.io_config()
            self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
            self.io2.fc_loc = None
            self.io2.fc_scale = None
            self.io2.fc_weight = None

        kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        head_dim = c.stack_config.dim // c.stack_config.n_head
        # Trunk layers take one slot each. In split mode every post-plex layer
        # runs twice (once per stream), which adds (layers - plex_layer) slots.
        self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
        cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * self.cache_num_layers

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.quantizer = c.quantizer_config().eval()
        # Assigning ``requires_grad = False`` on a Module only sets a plain
        # attribute; ``requires_grad_`` actually freezes the parameters.
        self.quantizer.requires_grad_(False)

    @T.no_grad()
    def quantize(self, x):
        """Quantize ``x`` under bf16 autocast; in split mode, quantize each half of the last dim."""
        if self.c.split:
            x1, x2 = x.chunk(2, dim=-1)
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                quantized1 = self.quantizer(x1)
                quantized2 = self.quantizer(x2)
            return quantized1, quantized2
        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            return self.quantizer(x)

    @T.no_grad()
    def untokenize(self, token_data):
        """Convert quantizer indices back into codes."""
        return self.quantizer(None, known_latent=token_data)

    def init_cache(self, bsize, device, dtype, length: int = None):
        """Allocate a KV cache filled with CACHE_FILL_VALUE; ``length`` overrides the sequence length."""
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        # Built batch-first, then transposed so that cache[i] indexes layer slot i.
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        """Drop the KV cache; layers then run without cached keys/values."""
        self.cache = [None] * self.cache_num_layers

    @T.no_grad()
    def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
        """Run the stack over ``data`` and inject quantized codes at ``plex_layer``.

        data: latent sequence (sequence on dim 1). In split mode, its last dim is
            chunked into two streams.
        next_tokens: optional quantizer indices that overwrite the final position
            of the rolled codes (a pair of tensors in split mode).
        temps: None -> return the raw output-head results. Otherwise it is passed
            to ``temp_sample``, and only the last sampled position is returned
            (concatenated across streams in split mode).
        """
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.io.input(x1) + self.io2.input(x2)
        else:
            x = self.io.input(data)

        cache_idx = 0
        for l, layer in enumerate(self.stack.layers):
            if l == self.plex_layer:
                # Roll the codes left by plex_roll along dim 1, so position i holds
                # the code from position i + plex_roll (the tail wraps around). The
                # wrapped tail slot is optionally replaced with caller-supplied tokens.
                if self.c.split:
                    plex1, plex2 = self.quantize(data)
                    plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                    plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex1[:, -1:] = self.untokenize(next_tokens[0])
                        plex2[:, -1:] = self.untokenize(next_tokens[1])
                    x1 = x + self.plex_projection(plex1)
                    x2 = x + self.plex_projection2(plex2)
                else:
                    plex = self.quantize(data)
                    plex = T.roll(plex, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex[:, -1:] = self.untokenize(next_tokens)
                    x = x + self.plex_projection(plex)

            if l < self.plex_layer:
                x = layer(x, kv=self.cache[l])
            elif self.c.split:
                # The same layer weights serve both streams; each stream has its
                # own cache slot after the trunk's slots.
                x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
                x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
            else:
                x = layer(x, kv=self.cache[l])

        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            if self.c.split:
                x1, x2 = self.out_norm(x1), self.out_norm(x2)
                # Both streams use self.io's output head (io2's fc_* heads were set to None).
                out1, out2 = self.io.output(x1), self.io.output(x2)
            else:
                x = self.out_norm(x)
                out = self.io.output(x)

        if isnt(temps):
            return (out1, out2) if self.c.split else out

        if self.c.split:
            next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
            next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
            return T.cat([next_data1, next_data2], dim=-1)
        return self.io.temp_sample(out, temps)[:, -1:, :]
@si_module
class HertzDevModel(nn.Module):
    """Autoregressive token model over audio latents, with resynthesizer and audio tokenizer.

    ``forward`` produces next-token logits (two heads in split mode).
    ``next_latent`` samples tokens and asks the resynthesizer for the next latent
    frame. ``tokenize``/``untokenize`` convert between audio and latents, and can
    keep a rolling context cache for streaming use.
    """

    # Rolling-context sizes used by tokenize/untokenize (presumably 6 s of 16 kHz
    # audio and 6 s of 8-per-second latents -- confirm against the tokenizer).
    _AUDIO_CACHE_SAMPLES = 6 * 16_000
    _LATENT_CACHE_LEN = 6 * 8
    # Audio samples corresponding to one latent frame, as assumed by the trimming below.
    _SAMPLES_PER_LATENT = 2000

    class Config:
        dim: int
        vocab_size: int
        stack_config: Optional[Stack.Config] = None
        latent_size: int = 32
        split: bool = True
        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.stack_config)), f'hmm {c}'
        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            self.input2 = nn.Linear(c.latent_size, c.dim)
        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])
        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)
        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head
        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')
        self.resynthesizer = c.resynthesizer_config().eval()
        # ``requires_grad = False`` on a Module only sets an attribute; this freezes the parameters.
        self.resynthesizer.requires_grad_(False)
        self.audio_tokenizer = make_tokenizer(device='cpu')
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @staticmethod
    def _keep_last(t, n, dim):
        """Return the last ``n`` entries of ``t`` along ``dim``.

        Unlike ``t[..., -n:]``, this is empty for ``n == 0`` instead of returning
        the whole tensor. It clamps to the full length when ``n`` exceeds it.
        """
        length = t.shape[dim]
        start = max(length - n, 0)
        return t.narrow(dim, start, length - start)

    @T.no_grad()
    def tokenize(self, audio_data):
        """Encode audio into latents and return only the frames for the new audio.

        audio_data is indexed as (batch, channels, samples). Two channels are
        encoded separately and concatenated on the last dim. When caching is
        active, previously seen audio is prepended as encoder context.
        """
        new_frames = audio_data.shape[-1] // self._SAMPLES_PER_LATENT
        if exists(self.audio_cache):
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -self._AUDIO_CACHE_SAMPLES:]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -self._AUDIO_CACHE_SAMPLES:]

        if audio_data.shape[1] == 2:
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            latents = T.cat([enc_ch1, enc_ch2], dim=-1)
        else:
            latents = self.audio_tokenizer.latent_from_data(audio_data)
        return self._keep_last(latents, new_frames, dim=1)

    @T.no_grad()
    def untokenize(self, token_data):
        """Decode latents into audio and return only the audio for the new frames.

        token_data is (batch, frames, width). When width == 2 * latent_size, the
        two halves of the LAST dim are decoded separately and stacked on dim 1.
        When caching is active, previously seen latents are prepended as decoder
        context and then trimmed away from the output, mirroring ``tokenize``.
        """
        new_frames = token_data.shape[1]
        if exists(self.audio_latent_cache):
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -self._LATENT_CACHE_LEN:]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -self._LATENT_CACHE_LEN:]

        size = self.c.latent_size
        if token_data.shape[-1] == 2 * size:
            # Split on the channel (last) axis; the old ``[:, :size]`` sliced the frame axis.
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., size:])
            audio = T.cat([dec_ch1, dec_ch2], dim=1)
        else:
            audio = self.audio_tokenizer.data_from_latent(token_data)
        return self._keep_last(audio, new_frames * self._SAMPLES_PER_LATENT, dim=-1)

    def init_cache(self, bsize, device, dtype, length: int = None):
        """Allocate KV caches here and in the resynthesizer, and enable audio caching."""
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        """Drop all KV and audio caches and disable audio caching."""
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def forward(self, data):
        """Return next-token logits for ``data`` (a pair of heads in split mode)."""
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)
        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])
        if self.c.split:
            return self.output(x), self.output2(x)
        return self.output(x)

    @T.no_grad()
    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Tokenize ``audio_data``, generate one latent frame, and decode it to audio."""
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        # Split mode decodes only the second half of the latent. Non-split latents
        # are exactly latent_size wide, so slicing them would leave nothing.
        next_model_latent = next_latents[..., self.c.latent_size:] if self.c.split else next_latents
        return self.untokenize(next_model_latent)[..., -self._SAMPLES_PER_LATENT:]

    @staticmethod
    def _sample_token(logits, temp):
        """Sample one token id per batch row from the final position of ``logits``."""
        return F.softmax(logits[:, -1] / temp, dim=-1).multinomial(1)

    @T.no_grad()
    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Sample next tokens at temperature ``temps[0]`` and resynthesize the next latent with ``temps[1]``."""
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_tokens = (self._sample_token(logits1, temps[0]), self._sample_token(logits2, temps[0]))
        else:
            next_tokens = self._sample_token(self.forward(model_input), temps[0])
        return self.resynthesizer(model_input, next_tokens=next_tokens, temps=temps[1])

    @T.no_grad()
    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """
        only accepts latent-space data.

        Appends up to ``gen_len`` frames (default: as many as ``data`` has) and
        caps the total length at ``stack_config.seq_len``. With ``use_cache``,
        only the newest frame is fed per step, and the caches are always released
        afterwards, even if generation raises.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)
        try:
            next_input = generated = data
            target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
            for _ in tqdm0(range(data.shape[1], target_len)):
                model_input = next_input if use_cache else generated
                next_input = self.next_latent(model_input, temps)
                generated = T.cat([generated, next_input], dim=1)
        finally:
            if use_cache:
                self.deinit_cache()
        return generated
def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
    """Build the ``HertzDevModel.Config`` for one of the released checkpoint sets.

    is_split: selects the split-mode checkpoints. ``use_pure_audio_ablation``
        is ignored in that case.
    use_pure_audio_ablation: with ``is_split=False``, selects the
        'inference_syrup_110000' model checkpoint instead of 'inference_caraway_112000'.

    The resynthesizer checkpoint goes to ``TransformerVAE.Config.from_pretrained``.
    The model checkpoint goes to ``HertzDevModel.Config.from_pretrained``.
    """
    whip = ('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7')
    if is_split:
        resynth_ckpt = ('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935')
        model_ckpt = ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')
    elif use_pure_audio_ablation:
        resynth_ckpt = whip
        model_ckpt = ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')
    else:
        resynth_ckpt = whip
        model_ckpt = ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')

    fsq = FSQ.Config(
        levels=[8, 8, 8, 8, 8],
        dim=2048,
        num_codebooks=1,
        keep_num_codebooks_dim=None,
        scale=None,
        allowed_dtypes=['float32', 'float64', 'bfloat16'],
        channel_first=False,
        projection_has_bias=True,
        return_indices=True,
        force_quantization_f32=True,
        use_rms=False,
    )
    quantizer = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=fsq,
        dim=2048,
        ff_dim=8192,
        input_dim=32,
    )

    resynth_stack = Stack.Config(
        layers=8,
        dim=4096,
        seq_len=8192,
        n_head=16,
        ff_dim=11008,
        kv_heads=16,
        eps=1e-5,
        theta=10_000,
    )
    resynthesizer = TransformerVAE.Config(
        io_config=GaussianMixtureIOLayer.Config(latent_dim=32, dim=4096, num_components=8),
        stack_config=resynth_stack,
        quantizer_config=quantizer,
        plex_layer=None,
        plex_roll=1,
        split=is_split,
        from_pretrained=resynth_ckpt,
    )

    model_stack = Stack.Config(
        layers=32,
        dim=4096,
        seq_len=2048,
        n_head=32,
        ff_dim=None,
        kv_heads=None,
        eps=1e-5,
        theta=10_000,
    )
    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=model_stack,
        quantizer_config=quantizer,
        resynthesizer_config=resynthesizer,
        split=is_split,
        from_pretrained=model_ckpt,
    )