# MicroRWKV / generate.py
import argparse
import os
import time
from dataclasses import dataclass

import tiktoken
import torch

from modelGenerate import GPT

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--prompt', type=str, required=True,
                    help='Prompt for generation')
parser.add_argument('--max_num_tokens', type=int, default=100,
                    help='Maximum number of tokens to generate')
parser.add_argument('--model_name', type=str, required=True,
                    help='Name of the model checkpoint in ./out (without the .pt extension)')
args = parser.parse_args()
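
# Example invocation (the checkpoint name 'microrwkv' is illustrative; any
# ./out/<model_name>.pt produced by training works):
#   python generate.py --prompt "Once upon a time" --max_num_tokens 100 --model_name microrwkv
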
# Default model hyperparameters; the values actually used are the model_args
# loaded from the checkpoint below
@dataclass
class GPTConfig:
    block_size: int = 1024
    # GPT-2 vocab_size of 50257, padded up to the nearest multiple of 64 for efficiency
    vocab_size: int = 50304
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 768
    # Mixture-of-experts settings; num_active_experts == num_experts means
    # every expert is used for each token (dense rather than sparse routing)
    num_experts: int = 4
    num_active_experts: int = 4
    expert_dim: int = 512
    dim: int = 768
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    bias: bool = False

# Load the model checkpoint from ./out/<model_name>.pt onto the CPU.
# weights_only=False is needed on recent PyTorch (2.6+), where it defaults to
# True, because the checkpoint stores plain Python objects, not just tensors.
ckpt_path = os.path.join('./out', f'{args.model_name}.pt')
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
print(checkpoint['config'])
model_args = checkpoint['model_args']
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
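# nanoGPT-style training loops often save the state dict of a torch.compile()'d
# model, whose keys carry an '_orig_mod.' prefix; stripping it here is a
# defensive assumption and a no-op for checkpoints without the prefix
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)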
model.load_state_dict(state_dict)
# model.cuda()  # uncomment to run on a GPU (also change device='cpu' below to 'cuda')
model.eval()
# Encode the prompt using tiktoken
enc = tiktoken.get_encoding("gpt2")
prompt_ids = enc.encode_ordinary(args.prompt)
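# Note: encode_ordinary() skips special-token handling, so a string such as
# '<|endoftext|>' in the prompt is tokenized as ordinary text rather than
# mapped to the single special token id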
# Measure inference time; torch.no_grad() avoids building the autograd graph
start_time = time.time()
with torch.no_grad():
    generated = model.generate(torch.tensor([prompt_ids], device='cpu'),
                               max_new_tokens=args.max_num_tokens)
end_time = time.time()
inference_time = end_time - start_time  # inference time in seconds
# Convert seconds to more readable format
if inference_time >= 3600:
hours = int(inference_time // 3600)
minutes = int((inference_time % 3600) // 60)
seconds = int(inference_time % 60)
inference_time_str = f"{hours} hours {minutes} minutes {seconds} seconds"
elif inference_time >= 60:
minutes = int(inference_time // 60)
seconds = int(inference_time % 60)
inference_time_str = f"{minutes} minutes {seconds} seconds"
else:
    # keep sub-second precision so fast runs don't report "0 seconds"
    inference_time_str = f"{inference_time:.2f} seconds"
# Decode token ids back to text; generate() returns the prompt followed by
# the newly generated tokens (nanoGPT convention)
output = enc.decode(generated[0].tolist())
print(f"Prompt: {args.prompt}")
print(f"Generated text: {output}")
print(f"Generated text length: {len(output)}")
print(f"Inference time: {inference_time_str}")