import argparse
import os
import time
from dataclasses import dataclass

import tiktoken
import torch

from modelGenerate import GPT
# Command-line interface: prompt text, generation length, and checkpoint name.
_cli = argparse.ArgumentParser()
_cli.add_argument('--prompt', type=str, required=True,
                  help='Prompt for generation')
_cli.add_argument('--max_num_tokens', type=int, default=100,
                  help='Maximum number of tokens to generate')
_cli.add_argument('--model_name', type=str, required=True,
                  help='Name of the model checkpoint')
args = _cli.parse_args()
@dataclass
class GPTConfig:
    """Model hyperparameters, rebuilt from the values stored in a checkpoint."""

    block_size: int = 1024   # maximum context length in tokens
    vocab_size: int = 50304  # GPT-2 vocab (50257) padded up for efficiency
    n_layer: int = 8         # number of transformer layers
    n_head: int = 8          # attention heads per layer
    n_embd: int = 768        # embedding / residual-stream width

    # Mixture-of-experts settings — presumably consumed by modelGenerate.GPT;
    # TODO(review): confirm their meaning against that module.
    num_experts: int = 4
    num_active_experts: int = 4
    expert_dim: int = 512
    dim: int = 768

    dropout: float = 0.0     # no dropout needed at inference time
    bias: bool = False       # whether linear/norm layers carry bias terms
# Locate the checkpoint written by training and load it onto the CPU.
ckpt_path = os.path.join('./out', f'{args.model_name}.pt')
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints.  map_location (passed explicitly by keyword, not positionally)
# remaps all saved tensors to the CPU regardless of the device they were
# trained on.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
print(checkpoint['config'])

# Rebuild the model with exactly the hyperparameters it was trained with,
# then restore its weights.
model_args = checkpoint['model_args']
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.load_state_dict(checkpoint['model'])
model.eval()  # inference mode: disables dropout and other train-only behavior
# Tokenize the prompt with the GPT-2 BPE vocabulary.  encode_ordinary skips
# special-token handling, matching how plain training text is encoded.
enc = tiktoken.get_encoding("gpt2")
prompt_ids = enc.encode_ordinary(args.prompt)

# Time the generation with perf_counter: it is monotonic, so the interval
# cannot be skewed by wall-clock adjustments the way time.time() can.
start_time = time.perf_counter()
# no_grad: inference needs no autograd graph; skipping it saves memory and
# speeds up the forward passes inside generate().
with torch.no_grad():
    generated = model.generate(
        torch.tensor([prompt_ids], device='cpu'),
        max_new_tokens=args.max_num_tokens,
    )
end_time = time.perf_counter()
inference_time = end_time - start_time
def _format_duration(elapsed: float) -> str:
    """Render a non-negative elapsed time as a human-readable string.

    Produces "H hours M minutes S seconds", "M minutes S seconds", or
    "S seconds" depending on magnitude, truncating fractional seconds —
    identical output to the original inline if/elif/else chain.
    """
    total = int(elapsed)
    if elapsed >= 3600:
        hours = total // 3600
        minutes = (total % 3600) // 60
        seconds = total % 60
        return f"{hours} hours {minutes} minutes {seconds} seconds"
    if elapsed >= 60:
        minutes = total // 60
        seconds = total % 60
        return f"{minutes} minutes {seconds} seconds"
    return f"{total} seconds"


inference_time_str = _format_duration(inference_time)
# Detokenize the single generated sequence in the batch and report results.
output = enc.decode(generated[0].tolist())

for report_line in (
    f"Prompt: {args.prompt}",
    f"Generated text: {output}",
    f"Generated text length: {len(output)}",
    f"Inference time: {inference_time_str}",
):
    print(report_line)