# Spaces: Running on A10G
import json
import os
from typing import List, Optional

import torch
import torch.nn as nn
from fairscale.nn.model_parallel import initialize as fs_init

from . import LLM
from .tokenizer import Tokenizer
class MetaModel(nn.Module):
    """Wrapper that pairs a LLaMA-style ``Transformer`` with its tokenizer.

    The ``ModelArgs``/``Transformer`` classes are looked up by name
    (``llama_type``) in the package-local ``LLM`` module, and the wrapped model
    is stored as ``self.llma``. ``forward`` returns a next-token cross-entropy
    loss; ``generate`` and ``stream_generate`` decode autoregressively through
    ``self.llma.forward_inference``.
    """

    def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None):
        """Build the tokenizer and Transformer, optionally loading weights.

        Args:
            llama_type: Name of the entry in ``LLM`` that provides
                ``ModelArgs`` and ``Transformer``.
            llama_config: Path to a JSON file whose keys are passed as keyword
                arguments to ``ModelArgs``.
            llama_ckpt_dir: Optional directory containing
                ``consolidated.<rank:02d>.pth`` shards. The shard for this
                process's model-parallel rank is loaded with ``strict=False``.
            tokenizer_path: Path passed to ``Tokenizer(model_path=...)``.
        """
        super().__init__()
        # Target id 0 is excluded from the loss (presumably padding / prompt
        # positions -- confirm against the data pipeline).
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

        llama_impl = LLM.__dict__[llama_type]
        ModelArgs = llama_impl.ModelArgs
        Transformer = llama_impl.Transformer

        with open(llama_config, "r", encoding="utf-8") as f:
            params = json.load(f)
        # Defaults come first so values present in the JSON config override
        # them; passing both as separate keywords raised TypeError on duplicates.
        model_args: ModelArgs = ModelArgs(**{"max_seq_len": 2048, "max_batch_size": 32, **params})

        self.tokenizer = Tokenizer(model_path=tokenizer_path)
        # The vocabulary size always follows the tokenizer, not the config.
        model_args.vocab_size = self.tokenizer.n_words
        model = Transformer(model_args)

        if llama_ckpt_dir is not None:
            # Each model-parallel rank loads its own shard; the rank is only
            # queried when a checkpoint is actually requested.
            mp_rank = fs_init.get_model_parallel_rank()
            ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth")
            if os.path.exists(ckpt_path):
                # NOTE(review): torch.load unpickles the file -- only load
                # checkpoints from trusted sources.
                checkpoint = torch.load(ckpt_path, map_location="cpu")
                # strict=False tolerates missing/unexpected keys; the returned
                # report is printed so mismatches remain visible.
                msg = model.load_state_dict(checkpoint, strict=False)
                print(msg)
            else:
                print(f'Checkpoint not found at {ckpt_path}')
        self.llma = model

        for name, param in self.named_parameters():
            if param.requires_grad:
                print(f"Trainable param: {name}, {param.shape}, {param.dtype}")
        count = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Parameter count : {count}")

    def forward(self, examples, labels, image=None, modal='image'):
        """Return the next-token cross-entropy loss for a batch.

        The output of ``self.llma`` is indexed as (batch, seq, classes): the
        logits at position ``t`` are scored against ``labels[:, t + 1]``.
        Labels equal to the criterion's ``ignore_index`` (0) do not contribute.
        """
        output = self.llma(examples, image=image, modal=modal)
        output = output[:, :-1, :]
        labels = labels[:, 1:]
        if (labels == self.criterion.ignore_index).all():
            # Every target is ignored: return a zero that is still attached to
            # the graph so backward() works (CrossEntropyLoss would give NaN).
            c_loss = output.mean() * 0
        else:
            # Use the logits' own class dimension instead of a hard-coded
            # 32000; the vocab size comes from the tokenizer in __init__.
            c_loss = self.criterion(output.reshape(-1, output.shape[-1]), labels.flatten())
        return c_loss

    def _context_budget(self, images) -> int:
        """Return how many token positions are available for text.

        When ``images`` is not None, ``self.llma.image_words`` positions of
        ``self.llma.params.max_seq_len`` are reserved for the image tokens.
        """
        max_seq_len = self.llma.params.max_seq_len
        if images is not None:
            max_seq_len -= self.llma.image_words
        return max_seq_len

    @staticmethod
    def _truncate_prompt(prompt_tokens, max_seq_len, max_gen_len):
        """Left-truncate ``prompt_tokens`` so ``max_gen_len`` new tokens still fit.

        At least one prompt token is always kept. Without the clamp, a
        non-positive budget turned ``[-size:]`` into ``[-0:]`` (no truncation)
        or a forward slice that dropped tokens from the front.
        """
        max_prompt_size = max(1, max_seq_len - max_gen_len)
        return prompt_tokens[-max_prompt_size:]

    def _sample_next(self, logits, temperature, top_p):
        """Choose next-token ids from ``logits``.

        Uses temperature-scaled top-p sampling when ``temperature > 0`` and
        greedy argmax over the last dimension otherwise.
        """
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            return self.sample_top_p(probs, top_p)
        return torch.argmax(logits, dim=-1)

    @torch.inference_mode()
    def generate(
        self,
        prompts: List[str],
        images,
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        modal: Optional[List[str]] = None,
    ) -> List[str]:
        """Generate continuations for a batch of prompts.

        Args:
            prompts: Text prompts; each is encoded with a leading BOS token.
            images: Passed to ``self.llma.forward_inference`` on the first
                decoding step only (``None`` for text-only decoding).
            max_gen_len: Maximum number of new tokens per prompt.
            temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
            top_p: Nucleus-sampling threshold used when ``temperature > 0``.
            modal: Modality tag(s) forwarded to ``forward_inference``;
                ``None`` means ``['image']``.

        Returns:
            One decoded string per prompt (prompt text included), cut at the
            first EOS token. An empty ``prompts`` list yields ``[]``.

        Raises:
            AssertionError: If there are more prompts than ``max_batch_size``.
        """
        if modal is None:
            modal = ['image']
        bsz = len(prompts)
        if bsz == 0:
            return []
        params = self.llma.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_size = min(prompt_lens)
        max_prompt_size = max(prompt_lens)

        # Reserve room for image tokens exactly as stream_generate does.
        total_len = min(self._context_budget(images), max_gen_len + max_prompt_size)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != pad_id

        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # The first step feeds the shared prompt prefix (plus images);
            # later steps feed only the token produced at the previous step.
            logits = self.llma.forward_inference(
                tokens[:, prev_pos:cur_pos],
                prev_pos,
                images if prev_pos == 0 else None,
                modal=modal,
            )
            next_token = self._sample_next(logits, temperature, top_p).reshape(-1)
            # Rows whose prompt still covers this position keep their prompt
            # token; only padded positions receive the sampled token.
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # Cut to this prompt's length plus max_gen_len ...
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # ... then at the first EOS token, if any.
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded

    @torch.inference_mode()
    def stream_generate(
        self,
        prompt: str,
        images,
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        modal: Optional[List[str]] = None,
    ):
        """Generate a continuation of one prompt, yielding partial text.

        The prompt is left-truncated so that ``max_gen_len`` tokens (and, when
        ``images`` is given, ``self.llma.image_words`` positions) still fit in
        ``max_seq_len``; at least one prompt token is always kept.

        Yields:
            Dicts ``{"text": str, "end_of_content": bool}``. ``text`` is the
            decoded generated text so far (prompt excluded). Every sampled
            token yields one dict with ``end_of_content=False``; a final dict
            with ``end_of_content=True`` follows when EOS is sampled or the
            length budget is exhausted.
        """
        if modal is None:
            modal = ['image']
        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
        max_seq_len = self._context_budget(images)
        prompt_tokens = self._truncate_prompt(prompt_tokens, max_seq_len, max_gen_len)
        prompt_size = len(prompt_tokens)
        total_len = min(max_seq_len, max_gen_len + prompt_size)

        tokens = torch.full([total_len], 0).cuda().long()
        tokens[:prompt_size] = torch.tensor(prompt_tokens).long()

        start_pos = prompt_size
        prev_pos = 0
        generate_until = start_pos
        for cur_pos in range(start_pos, total_len):
            logits = self.llma.forward_inference(
                tokens[None, prev_pos:cur_pos],
                prev_pos,
                images if prev_pos == 0 else None,
                modal=modal,
            )
            next_token = self._sample_next(logits, temperature, top_p).item()
            if next_token == self.tokenizer.eos_id:
                break
            tokens[cur_pos] = next_token
            prev_pos = cur_pos
            generate_until = cur_pos + 1
            yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False}
        yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True}

    def sample_top_p(self, probs, p):
        """Nucleus (top-p) sampling over the last dimension of ``probs``.

        Tokens are sorted by descending probability. A token is kept when the
        cumulative mass of the tokens ranked above it does not exceed ``p``,
        so the highest-probability token is always kept. The kept mass is
        renormalized, one index is drawn per row, and it is mapped back to the
        original vocabulary order.
        Returns indices with a trailing dimension of size 1.
        """
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token

    def get_image_words(self):
        """Return ``self.llma.image_words`` (positions reserved for image tokens)."""
        return self.llma.image_words