import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


def compute_memory_used_pct(device):
    """Return peak GPU memory allocated on `device` as a percentage of its total memory."""
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct

model_path = "./out"

# Lookahead configuration consumed by the custom model below:
# thought length, talk length, and whether the talk heads are merged.
n_ahead = 8
n_ahead_talk = 4
merged_talk_heads = True

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer
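
# Assumption: the base tokenizer may not define a pad token (as with
# Mistral-style tokenizers); fall back to EOS so the padding logic below
# always has a valid id to work with.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token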

# Inference-time flags read by the model's remote code.
model.use_end_thought_token = True
model.use_start_thought_token = True
model.wandb_enabled = True
model.n_ahead = n_ahead
model.n_passes = 2
model.eval_mode = True
model.first_run = False
model.kill_after = 100
model.rm_initialized = True
model.original_mode = False
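
# Also set standard PyTorch eval mode (disables dropout); the eval_mode flag
# above is assumed to be the remote model's own inference switch.
model.eval()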


def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for cur_token_idx in range(max_new_tokens):
            # Forward pass over the sequences that are still generating.
            new_ids = model(
                input_ids[~finished_generating],
                attention_mask=attention_mask[~finished_generating]
            )['logits']
            # Mask out the special thought tokens appended past the base
            # vocabulary so they are never sampled as output.
            new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
                # Locate the last non-padding position in this sequence.
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

                # Sample the next token from the temperature-scaled distribution.
                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)

                # Grow input_ids and attention_mask by one padded column if the
                # sequence has filled its current length.
                if last_token_idx + 1 >= len(base_answer_ids):
                    new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
                                             device=input_ids.device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled

                # Stop this sequence on EOS, BOS, pad, or the assistant end marker.
                if (new_ids_sampled == model.tokenizer.eos_token_id
                        or new_ids_sampled == model.tokenizer.bos_token_id
                        or new_ids_sampled == model.tokenizer.pad_token_id):
                    finished_generating[answer_idx] = True
                if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
                    finished_generating[answer_idx] = True
            if finished_generating.all():
                break
            # Stream the most recently sampled token (assumes a batch size of 1).
            streamer.put(new_ids_sampled)
    return input_ids, attention_mask


prompt_template = "[INST] {prompt} [/INST]"

prompt = (
    "You're standing on the surface of the Earth. "
    "You walk one mile south, one mile west and one mile north. "
    "You end up exactly where you started. Where are you?"
)

tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)

# Mark non-padding positions; for this single unpadded prompt the mask is all ones.
attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)

streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)

# Decode the full sequence (prompt plus completion) for downstream use.
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print()
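
# Report peak GPU memory with the helper defined above (device 0 assumed).
print(f"Peak GPU memory used: {compute_memory_used_pct(0):.1f}%")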

torch.cuda.empty_cache()