Spaces:
Paused
๊ณ ์ ๊ธธ์ด ๋ชจ๋ธ์ ํํ๋ ์ํฐ(Perplexity)[[perplexity-of-fixedlength-models]]
[[open-in-colab]]
ํํ๋ ์ํฐ(Perplexity, PPL)๋ ๊ฐ์ฅ ์ผ๋ฐ์ ์ธ ์ธ์ด ๋ชจ๋ธ ํ๊ฐ์งํ ์ค ํ๋์ ๋๋ค. ์์ธํ ์์๋ณด๊ธฐ ์ ์ ์ด ํ๊ฐ์งํ๋ ๊ณ ์ ์ ์ธ ์ธ์ด ๋ชจ๋ธ(์๊ธฐํ๊ท ๋๋ ์ธ๊ณผ์ ์ธ์ด ๋ชจ๋ธ์ด๋ผ๊ณ ๋ ํจ)์๋ง ์ ์ฉ๋๋ฉฐ BERT์ ๊ฐ์ ๋ง์คํน๋ ์ธ์ด ๋ชจ๋ธ์๋ ์ ์ ์ฉํ์ง ์์ต๋๋ค (BERT๋ summary of the models ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ธ์).
ํํ๋ ์ํฐ๋ ์ํ์ค์ ์์ ๋ก๊ทธ ์ฐ๋(negative log-likelihood, NLL) ๊ฐ์ ํ๊ท ์ ์ง์(exponentiate)๋ฅผ ์ทจํ ๊ฐ์ผ๋ก ์ ์๋ฉ๋๋ค. ํ ํฐํ๋ ์ํ์ค ๊ฐ ์์ ๋, ์ ํํ๋ ์ํฐ๋ ์๋ ์์๊ณผ ๊ฐ์ด ๊ตฌํ ์ ์์ต๋๋ค.
๋ ๋ชจ๋ธ์ i๋ฒ์งธ ์ด์ ๊น์ง ํ ํฐ์ด ์ฃผ์ด์ก์ ๋ i๋ฒ์งธ ํ ํฐ์ ๋ก๊ทธ ์ฐ๋๊ฐ์ ๋๋ค.
์ง๊ด์ ์ผ๋ก ๋ง๋ญ์น์์ ์ง์ ๋ ํ ํฐ ์งํฉ์ ๊ท ์ผํ๊ฒ ์์ธกํ๋ ๋ชจ๋ธ์ ๋ฅ๋ ฅ์ ๋ํ ํ๊ฐ๋ก ์๊ฐํ ์ ์์ต๋๋ค. ์ค์ํ ์ ์ ํ ํฐํ ๊ณผ์ ์ด ๋ชจ๋ธ์ ํํ๋ ์ํฐ์ ์ง์ ์ ์ธ ์ํฅ์ ๋ฏธ์น๋ฏ๋ก ์๋ก ๋ค๋ฅธ ๋ชจ๋ธ์ ๋น๊ตํ ๋ ํญ์ ์ด๋ฅผ ๊ณ ๋ คํด์ผ ํฉ๋๋ค.
์ด๋ ๋ฐ์ดํฐ์ ๋ชจ๋ธ ์์ธก ๊ฐ์ cross-entropy ๊ฐ์ ์ง์๋ฅผ ์ทจํ ๊ฒ๊ณผ ๋์ผํฉ๋๋ค. ํํ๋ ์ํฐ์ ๋ฌธ์๋น ๋นํธ ์(BPC) ๋ฐ ๋ฐ์ดํฐ ์์ถ๊ณผ์ ๊ด๊ณ์ ๋ํด ๋ ์ง๊ด์ ์ธ ์ดํด๋ฅผ ์ํ์ ๋ค๋ฉด ๋ค์ ๊ธ fantastic blog post on The Gradient์ ํ์ธํ์ธ์.
๊ณ ์ ๊ธธ์ด ๋ชจ๋ธ์ ํํ๋ ์ํฐ(PPL) ๊ณ์ฐํ๊ธฐ[[calculating-ppl-with-fixedlength-models]]
๋ชจ๋ธ์ ์ปจํ ์คํธ ํฌ๊ธฐ๊ฐ ์ ํด์ ธ์์ง ์๋ค๋ฉด, ์๋์ ๊ฐ์ด ์ํ์ค๋ฅผ ์๋ ํ๊ท์ ์ผ๋ก ๋ถํดํ๊ณ ๊ฐ ๋จ๊ณ์์ ์ ํ ํ๋ ์ ์ฒด ์ํ์ค๋ฅผ ์กฐ๊ฑด๋ถ ํ๋ฅ ์ ๋ฃ์ด ๋ชจ๋ธ์ ํํ๋ ์ํฐ๋ฅผ ๊ณ์ฐํ ๊ฒ์ ๋๋ค.
๊ทธ๋ฌ๋ ๋ชจ๋ธ์ ๊ทผ์ฌ์น๋ฅผ ๊ตฌํ ๋๋ ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ์ด ์ฒ๋ฆฌํ ์ ์๋ ํ ํฐ ์์ ์ ํ์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๊ฐ์ฅ ํฐ ๋ฒ์ ์ GPT-2๋ ํ ํฐ์ ๊ธธ์ด๊ฐ 1024๋ก ๊ณ ์ ๋์ด ์์ต๋๋ค. ๋ฐ๋ผ์ ๊ฐ 1024๋ณด๋ค ํฐ ๊ฒฝ์ฐ์ ์ ๊ณ์ฐํ ์ ์์ต๋๋ค.
๋์ ์ํ์ค๋ ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ์ ์ต๋ ์ ๋ ฅ ํฌ๊ธฐ์ ๋์ผํ ๊ธธ์ด๋ ๊ฐ์ง๋ ๋ถ๋ถ ์ํ์ค๋ก ์ชผ๊ฐญ๋๋ค. ๋ง์ฝ ๋ชจ๋ธ์ ์ต๋ ์ ๋ ฅ ๊ธธ์ด๊ฐ ๋ผ๋ฉด, ํ ํฐ ์ ์ฐ๋ ๊ฐ์ ๊ณ์ฐํ ๋ ์ด์ ํ ํฐ์ ๋ชจ๋ ์ฌ์ฉํ์ง ์๊ณ , ํ ํฐ๊น์ง ์ฌ์ฉํด ๋๋ต์ ์ธ ์ฐ๋ ๊ฐ์ ์ถ์ ํฉ๋๋ค.
๋ชจ๋ธ์ ์ํ์ค์ ๋ํ ํํ๋ ์ํฐ๋ฅผ ๊ณ์ฐํ ๋, ์์ํ์ง๋ง ์ฐจ์ ์ฑ ์ ์ํ์ค๋ฅผ ์ฒญํฌ๋ก ์ชผ๊ฐ๊ณ ๋ถํด๋ ๊ฐ ๋ถ๋ถ์ ๋ก๊ทธ ์ฐ๋ ๊ฐ์ ๋ ๋ฆฝ์ ์ผ๋ก ํฉ์ฐํ๋ ๊ฒ์ ๋๋ค.
์ด ๋ฐฉ๋ฒ์ ๊ฐ ๋ถ๋ถ์ ํํ๋ ์ํฐ๋ฅผ ํ ๋ฒ์ ํฌ์๋ ํจ์ค๋ก ๊ณ์ฐํ ์ ์์ด ๋น ๋ฅด์ง๋ง ์ผ๋ฐ์ ์ผ๋ก ๋ ๋์(๋ ๋์) PPL์ ์ฐ์ถํฉ๋๋ค. ์๋ํ๋ฉด ๋๋ถ๋ถ์ ์์ธก ๋จ๊ณ์์ ๋ชจ๋ธ์ ์ปจํ ์คํธ๊ฐ ์ ๊ธฐ ๋๋ฌธ์ ๋๋ค.
๋์ , ๊ณ ์ ๊ธธ์ด ๋ชจ๋ธ์ PPL์ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ ๋ต์ผ๋ก ํ๊ฐํด์ผ ํฉ๋๋ค. ์ด ์ ๋ต์๋ ์ปจํ ์คํธ ์๋์ฐ์ ๋ฐ๋ณต์ ์ผ๋ก ์ฌ๋ผ์ด๋ฉํด ๋ชจ๋ธ์ด ๊ฐ ์์ธก์ ์ํํ ๋ ๋ ๋ง์ ์ปจํ ์คํธ๋ฅผ ๊ฐ๋๋ก ํ๋ ์์ ์ด ํฌํจ๋ฉ๋๋ค.
์ด๋ ์ํ์ค ํ๋ฅ ์ ์ค์ ๋ถํด์ ๋ ๊ฐ๊น์ด ๊ทผ์ฌ์น์ด๋ฉฐ ์ผ๋ฐ์ ์ผ๋ก ๋ ์ ๋ฆฌํ ์ ์๋ฅผ ์ฐ์ถํฉ๋๋ค. ๋จ์ ์ ๋ง๋ญ์น์ ๊ฐ ํ ํฐ์ ๋ํด ๋ณ๋์ ํฌ์๋ ํจ์ค๊ฐ ํ์ํ๋ค๋ ๊ฒ์ ๋๋ค. ํ์ค์ ์ผ๋ก ์ข์ ์ ์ถฉ์์ ํ ๋ฒ์ ํ ํ ํฐ์ฉ ์ฌ๋ผ์ด๋ฉํ๋ ๊ฒ์ด ์๋๋ผ ๋ ํฐ ๊ฐ๊ฒฉ์ผ๋ก ์ปจํ ์คํธ๋ฅผ ์ด๋ํ๋ ์คํธ๋ผ์ด๋๊ฐ ์ ์ฉ๋ ์ฌ๋ผ์ด๋ฉ ์๋์ฐ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๊ณ์ฐ์ ํจ์ฌ ๋ ๋น ๋ฅด๊ฒ ์งํํ๋ฉด์๋ ๋ชจ๋ธ์ ๊ฐ ๋จ๊ณ์์ ์์ธก์ ์ํํ ์ ์๋ ๊ธด ์ปจํ ์คํธ๋ฅผ ์ ๊ณตํ ์ ์์ต๋๋ค.
์์ : ๐ค Transformers์์ GPT-2๋ก ํํ๋ ์ํฐ(perplexity) ๊ณ์ฐํ๊ธฐ[[example-calculating-perplexity-with-gpt2-in-transformers]]
์ด์ GPT-2๋ก ์์ ๊ณผ์ ์ ์์ฐํด ๋ณด๊ฒ ์ต๋๋ค.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
device = "cuda"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
WikiText-2 ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๊ฐ์ ธ์ค๊ณ ๋ช ๊ฐ์ง ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ ๋ต์ ์ฌ์ฉํด ํํ๋ ์ํฐ๋ฅผ ๊ณ์ฐํด๋ณด๊ฒ ์ต๋๋ค. ์ด ๋ฐ์ดํฐ ์ธํธ๋ ํฌ๊ธฐ๊ฐ ์๊ณ ํฌ์๋ ํจ์ค ํ ๋ฒ๋ง ์ํํ๊ธฐ ๋๋ฌธ์ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๊ฐ์ ธ์ค๊ณ ์ธ์ฝ๋ฉํ ์ ์์ต๋๋ค.
from datasets import load_dataset
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
๐ค Transformers๋ฅผ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ labels
๋ก input_ids
๋ฅผ ์ ๋ฌํด ๊ฐ ํ ํฐ์ ๋ํ ํ๊ท ์์ ์ฐ๋ ๊ฐ์ ์์ค๋ก ๋ฐํํ ์ ์์ต๋๋ค.
ํ์ง๋ง ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ๋ฐฉ์์ ์ฌ์ฉํ๋ฉด ๊ฐ ๋ฐ๋ณต๋ง๋ค ๋ชจ๋ธ์ ์ ๋ฌํ๋ ํ ํฐ์ด ๊ฒน์นฉ๋๋ค.
์ปจํ
์คํธ๋ก ์ฒ๋ฆฌํ๋ ํ ํฐ์ ๋ํ ๋ก๊ทธ ์ฐ๋ ๊ฐ์ด ์์ค์ ํฌํจ๋๋ ๊ฒ์ ์ํ์ง ์๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ ํ ํฐ์ input_ids
๋ฅผ -100
์ผ๋ก ์ค์ ํ์ฌ ๋ฌด์ํ ์ ์์ต๋๋ค.
๋ค์์ ์คํธ๋ผ์ด๋(stride)๋ฅผ 512
๋ก ์ฌ์ฉํ ์์์
๋๋ค.
์ฆ, ๋ชจ๋ธ์ด ํ ํ ํฐ์ ์กฐ๊ฑด๋ถ ์ฐ๋ ๊ฐ์ ๊ณ์ฐํ ๋ ์ปจํ
์คํธ์ ์ต์ํ 512๊ฐ์ ํ ํฐ์ด ํฌํจ๋์ด์๋ค๋ ์๋ฏธ์
๋๋ค (ํด๋น ํ ํฐ ์์ 512๊ฐ์ ํ ํฐ์ด ์๋ ๊ฒฝ์ฐ).
import torch
from tqdm import tqdm
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)
nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
end_loc = min(begin_loc + max_length, seq_len)
trg_len = end_loc - prev_end_loc # ๋ง์ง๋ง ๋ฃจํ์ ์คํธ๋ผ์ด๋ ๊ฐ๊ณผ ๋ค๋ฅผ ์ ์์
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
outputs = model(input_ids, labels=target_ids)
# ์์ค์ ๋ชจ๋ ์ ํจํ ๋ ์ด๋ธ์ ๋ํ ํ๊ท ๊ฐ์ ๊ตฌํ๋ ๊ต์ฐจ ์ํธ๋กํผ(cross entropy)๋ก ๊ณ์ฐ๋ฉ๋๋ค.
# ๋์ด๋ธ ๋ฒ ์ด์ง์ ๋ชจ๋ธ์ ๋ด๋ถ์ ์ผ๋ก ๋ ์ด๋ธ์ ์ผ์ชฝ์ผ๋ก 1๊ฐ์ฉ ๋ฐ๊ธฐ ๋๋ฌธ์, (ํ์ผ - 1)๊ฐ ๋งํผ์ ๋ ์ด๋ธ์ ๋ํด ์์ค์ ๊ณ์ฐํฉ๋๋ค.
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break
ppl = torch.exp(torch.stack(nlls).mean())
์คํธ๋ผ์ด๋๋ฅผ ์ต๋ ์ ๋ ฅ ๊ธธ์ด์ ๋์ผํ๊ฒ ์ค์ ํ๋ฉด ์์์ ์ค๋ช ํ ์ฐจ์ ์ฑ ์ธ ๋น์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ ๋ต๊ณผ ๋์ผํฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ์คํธ๋ผ์ด๋๊ฐ ์์์๋ก ๋ชจ๋ธ์ด ๊ฐ ์์ธก์ ํ ๋ ๋ ๋ง์ ์ปจํ ์คํธ๋ฅผ ๋ณผ ์ ์๊ฒ ๋์ด ํํ๋ ์ํฐ ๊ฐ์ด ์ข์์ง๋๋ค.
์์ ๊ณ์ฐ์ ํ ํฐ์ด ๊ฒน์น์ง ์๋๋ก stride = 1024
๋ก ์ค์ ํ๋ฉด PPL์ 19.44
๋ก GPT-2 ๋
ผ๋ฌธ์์ ๋ณด๊ณ ๋ 19.93
๊ณผ ๊ฑฐ์ ๋์ผํฉ๋๋ค.
stride = 512
๋ก ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ ๋ต์ ์ฌ์ฉํ๋ฉด PPL์ 16.45
๋ก ๋จ์ด์ง๋๋ค.
์ด๋ ๋ ์ข์ ์ ์์ผ ๋ฟ๋ง ์๋๋ผ ์ํ์ค ํ๋ฅ ์ ์ค์ ์๋ ํ๊ท ๋ถํด์ ๋ ๊ฐ๊น์ด ๋ฐฉ์์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค.