Spaces:
Sleeping
Sleeping
""" | |
Run simply with | |
$ pytest | |
""" | |
import os | |
import pytest # pip install pytest | |
import subprocess | |
import torch | |
from model import ModelArgs, Transformer | |
def test_argmax_inference():
    """
    Check that the C and PyTorch implementations sample identical tokens.

    Only the simplest test for now: run inference with temperature 0
    (for determinism) in both C and PyTorch, and see that the sampled tokens
    are the same.

    Requires the ``./run`` executable in the current working directory and
    both ``model.bin`` and ``ckpt.pt`` inside the ``out`` directory.

    Raises:
        subprocess.CalledProcessError: if ``./run`` exits with a non-zero
            status, so a crashing binary is reported as such instead of as
            a token mismatch.
        ValueError: if a line printed by ``./run`` is not an integer.
    """
    test_ckpt_dir = "out"  # TODO create a dummy test checkpoint for this?

    # run C version
    model_path = os.path.join(test_ckpt_dir, "model.bin")
    command = ["./run", model_path, "0.0"]
    # subprocess.run waits for the process and closes its pipes for us;
    # check=True turns a non-zero exit status into an immediate failure.
    completed = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    # The binary is expected to print one token id per line.
    c_tokens = [
        int(line.strip())
        for line in completed.stdout.decode("utf-8").splitlines()
    ]

    # run PyTorch version
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = ModelArgs(**checkpoint['model_args'])
    model = Transformer(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names so they match the
    # freshly constructed model (presumably added by torch.compile -- confirm).
    unwanted_prefix = '_orig_mod.'
    # Iterate over a snapshot of the keys because the dict is mutated below.
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    # NOTE(review): strict=False tolerates missing/unexpected keys; confirm
    # this leniency is intended for the test checkpoint.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    x = torch.tensor([[1]], dtype=torch.long, device=device)  # 1 is BOS
    with torch.inference_mode():
        y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0)
    pt_tokens = y[0].tolist()
    pt_tokens = pt_tokens[1:]  # remove BOS

    # compare
    assert c_tokens == pt_tokens, (
        f"C and PyTorch token streams differ:\n"
        f"  C:       {c_tokens}\n"
        f"  PyTorch: {pt_tokens}"
    )