import os

# Select the TPU as the PJRT device before the XLA runtime is initialized.
os.environ["PJRT_DEVICE"] = "TPU"

import gradio as gr
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, StaticCache
from optimum.tpu.modeling import AutoModelForCausalLM


def sample_greedy(logits):
    # Greedy sampling: pick the highest-scoring token at the last position.
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    # Single decoding step: feed the last generated token and reuse the static KV cache.
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        past_key_values=past_key_values,
    )[0]
    new_token = sample_greedy(logits)
    return new_token


def conditional_compile(func):
    # Compile with the OpenXLA backend only when the DBG_COMPILE env var is set.
    if "DBG_COMPILE" in os.environ:
        return torch.compile(func, backend="openxla")
    return func


model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)


def summarize(inp, model=model, tokenizer=tokenizer, device=device):
    with torch.no_grad():
        # Collapse newlines so the prompt is a single line of text.
        inp = inp.replace("\n", " ")
        inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
        batch_size, sequence_length = inputs["input_ids"].shape
        max_cache_length = 1024
        max_new_tokens = 64

        # Set up a static KV cache so tensor shapes stay fixed across decoding steps.
        past_key_values = StaticCache(
            config=model.config,
            max_batch_size=batch_size,
            max_cache_len=max_cache_length,
            device=model.device,
            dtype=model.dtype,
        )
        cache_position = torch.arange(sequence_length, device=device)
        generated_ids = torch.zeros(
            (batch_size, sequence_length + max_new_tokens + 1),
            dtype=torch.int,
            device=device,
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

        # Prefill: run the full prompt once to populate the KV cache.
        attention_mask = inputs["attention_mask"]
        pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        logits = model(
            **inputs,
            cache_position=cache_position,
            return_dict=False,
            use_cache=True,
            position_ids=pos_ids,
            past_key_values=past_key_values,
        )[0]
        next_token = sample_greedy(logits)
        xm.mark_step()
        generated_ids[:, sequence_length] = next_token[:, 0]
        pos_ids = pos_ids.max(dim=-1)[0].unsqueeze(1) + 1

        # Decode token by token, advancing the cache position at each step.
        model = conditional_compile(model)
        cache_position = torch.tensor([sequence_length], device=device)
        for _ in range(max_new_tokens):
            next_token = decode_one_tokens(
                model, next_token.clone(), pos_ids, cache_position, past_key_values
            )
            cache_position += 1
            generated_ids[:, cache_position] = next_token
            pos_ids += 1
            xm.mark_step()

        decoded_texts = tokenizer.batch_decode(generated_ids)
        response = " ".join(decoded_texts)
        return response


gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=7, label="Input Text"),
    outputs="text",
    title="gemma-2b simple TPU demo",
).launch(inline=False)