๐ค PEFT๋ก ์ด๋ํฐ ๊ฐ์ ธ์ค๊ธฐ
Parameter-Efficient Fine Tuning (PEFT) ๋ฐฉ๋ฒ์ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ฏธ์ธ ์กฐ์ ์ค ๊ณ ์ ์ํค๊ณ , ๊ทธ ์์ ํ๋ จํ ์ ์๋ ๋งค์ฐ ์ ์ ์์ ๋งค๊ฐ๋ณ์(์ด๋ํฐ)๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ด๋ํฐ๋ ์์ ๋ณ ์ ๋ณด๋ฅผ ํ์ตํ๋๋ก ํ๋ จ๋ฉ๋๋ค. ์ด ์ ๊ทผ ๋ฐฉ์์ ์์ ํ ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ํ์ ํ๋ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ๋ฉด์, ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ด๊ณ ๋น๊ต์ ์ ์ ์ปดํจํ ๋ฆฌ์์ค๋ฅผ ์ฌ์ฉํฉ๋๋ค.
๋ํ PEFT๋ก ํ๋ จ๋ ์ด๋ํฐ๋ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฒด ๋ชจ๋ธ๋ณด๋ค ํจ์ฌ ์๊ธฐ ๋๋ฌธ์ ๊ณต์ , ์ ์ฅ ๋ฐ ๊ฐ์ ธ์ค๊ธฐ๊ฐ ํธ๋ฆฌํฉ๋๋ค.
๐ค PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด ๋ฌธ์๋ฅผ ํ์ธํ์ธ์.
์ค์
๐ค PEFT๋ฅผ ์ค์นํ์ฌ ์์ํ์ธ์:
pip install peft
์๋ก์ด ๊ธฐ๋ฅ์ ์ฌ์ฉํด๋ณด๊ณ ์ถ๋ค๋ฉด, ๋ค์ ์์ค์์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํ๋ ๊ฒ์ด ์ข์ต๋๋ค:
pip install git+https://github.com/huggingface/peft.git
์ง์๋๋ PEFT ๋ชจ๋ธ
๐ค Transformers๋ ๊ธฐ๋ณธ์ ์ผ๋ก ์ผ๋ถ PEFT ๋ฐฉ๋ฒ์ ์ง์ํ๋ฉฐ, ๋ก์ปฌ์ด๋ Hub์ ์ ์ฅ๋ ์ด๋ํฐ ๊ฐ์ค์น๋ฅผ ๊ฐ์ ธ์ค๊ณ ๋ช ์ค์ ์ฝ๋๋ง์ผ๋ก ์ฝ๊ฒ ์คํํ๊ฑฐ๋ ํ๋ จํ ์ ์์ต๋๋ค. ๋ค์ ๋ฐฉ๋ฒ์ ์ง์ํฉ๋๋ค:
๐ค PEFT์ ๊ด๋ จ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ(์: ํ๋กฌํํธ ํ๋ จ ๋๋ ํ๋กฌํํธ ํ๋) ๋๋ ์ผ๋ฐ์ ์ธ ๐ค PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.
PEFT ์ด๋ํฐ ๊ฐ์ ธ์ค๊ธฐ
๐ค Transformers์์ PEFT ์ด๋ํฐ ๋ชจ๋ธ์ ๊ฐ์ ธ์ค๊ณ ์ฌ์ฉํ๋ ค๋ฉด Hub ์ ์ฅ์๋ ๋ก์ปฌ ๋๋ ํฐ๋ฆฌ์ adapter_config.json
ํ์ผ๊ณผ ์ด๋ํฐ ๊ฐ์ค์น๊ฐ ํฌํจ๋์ด ์๋์ง ํ์ธํ์ญ์์ค. ๊ทธ๋ฐ ๋ค์ AutoModelFor
ํด๋์ค๋ฅผ ์ฌ์ฉํ์ฌ PEFT ์ด๋ํฐ ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ธ๊ณผ ๊ด๊ณ ์ธ์ด ๋ชจ๋ธ์ฉ PEFT ์ด๋ํฐ ๋ชจ๋ธ์ ๊ฐ์ ธ์ค๋ ค๋ฉด ๋ค์ ๋จ๊ณ๋ฅผ ๋ฐ๋ฅด์ญ์์ค:
- PEFT ๋ชจ๋ธ ID๋ฅผ ์ง์ ํ์ญ์์ค.
- AutoModelForCausalLM ํด๋์ค์ ์ ๋ฌํ์ญ์์ค.
from transformers import AutoModelForCausalLM, AutoTokenizer
peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id)
AutoModelFor
ํด๋์ค๋ ๊ธฐ๋ณธ ๋ชจ๋ธ ํด๋์ค(์: OPTForCausalLM
๋๋ LlamaForCausalLM
) ์ค ํ๋๋ฅผ ์ฌ์ฉํ์ฌ PEFT ์ด๋ํฐ๋ฅผ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.
load_adapter
๋ฉ์๋๋ฅผ ํธ์ถํ์ฌ PEFT ์ด๋ํฐ๋ฅผ ๊ฐ์ ธ์ฌ ์๋ ์์ต๋๋ค.
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)
8๋นํธ ๋๋ 4๋นํธ๋ก ๊ฐ์ ธ์ค๊ธฐ
bitsandbytes
ํตํฉ์ 8๋นํธ์ 4๋นํธ ์ ๋ฐ๋ ๋ฐ์ดํฐ ์ ํ์ ์ง์ํ๋ฏ๋ก ํฐ ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ๋ ์ ์ฉํ๋ฉด์ ๋ฉ๋ชจ๋ฆฌ๋ ์ ์ฝํฉ๋๋ค. ๋ชจ๋ธ์ ํ๋์จ์ด์ ํจ๊ณผ์ ์ผ๋ก ๋ถ๋ฐฐํ๋ ค๋ฉด from_pretrained()์ load_in_8bit
๋๋ load_in_4bit
๋งค๊ฐ๋ณ์๋ฅผ ์ถ๊ฐํ๊ณ device_map="auto"
๋ฅผ ์ค์ ํ์ธ์:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
์ ์ด๋ํฐ ์ถ๊ฐ
์ ์ด๋ํฐ๊ฐ ํ์ฌ ์ด๋ํฐ์ ๋์ผํ ์ ํ์ธ ๊ฒฝ์ฐ์ ํํด ๊ธฐ์กด ์ด๋ํฐ๊ฐ ์๋ ๋ชจ๋ธ์ ์ ์ด๋ํฐ๋ฅผ ์ถ๊ฐํ๋ ค๋ฉด ~peft.PeftModel.add_adapter
๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ๋ชจ๋ธ์ ๊ธฐ์กด LoRA ์ด๋ํฐ๊ฐ ์ฐ๊ฒฐ๋์ด ์๋ ๊ฒฝ์ฐ:
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig
model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)
lora_config = LoraConfig(
target_modules=["q_proj", "k_proj"],
init_lora_weights=False
)
model.add_adapter(lora_config, adapter_name="adapter_1")
์ ์ด๋ํฐ๋ฅผ ์ถ๊ฐํ๋ ค๋ฉด:
# attach new adapter with same config
model.add_adapter(lora_config, adapter_name="adapter_2")
์ด์ ~peft.PeftModel.set_adapter
๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ํฐ๋ฅผ ์ฌ์ฉํ ์ด๋ํฐ๋ก ์ค์ ํ ์ ์์ต๋๋ค:
# use adapter_1
model.set_adapter("adapter_1")
output = model.generate(**inputs)
print(tokenizer.decode(output_disabled[0], skip_special_tokens=True))
# use adapter_2
model.set_adapter("adapter_2")
output_enabled = model.generate(**inputs)
print(tokenizer.decode(output_enabled[0], skip_special_tokens=True))
์ด๋ํฐ ํ์ฑํ ๋ฐ ๋นํ์ฑํ
๋ชจ๋ธ์ ์ด๋ํฐ๋ฅผ ์ถ๊ฐํ ํ ์ด๋ํฐ ๋ชจ๋์ ํ์ฑํ ๋๋ ๋นํ์ฑํํ ์ ์์ต๋๋ค. ์ด๋ํฐ ๋ชจ๋์ ํ์ฑํํ๋ ค๋ฉด:
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig
model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)
# to initiate with random weights
peft_config.init_lora_weights = False
model.add_adapter(peft_config)
model.enable_adapters()
output = model.generate(**inputs)
์ด๋ํฐ ๋ชจ๋์ ๋นํ์ฑํํ๋ ค๋ฉด:
model.disable_adapters() output = model.generate(**inputs)
PEFT ์ด๋ํฐ ํ๋ จ
PEFT ์ด๋ํฐ๋ Trainer ํด๋์ค์์ ์ง์๋๋ฏ๋ก ํน์ ์ฌ์ฉ ์ฌ๋ก์ ๋ง๊ฒ ์ด๋ํฐ๋ฅผ ํ๋ จํ ์ ์์ต๋๋ค. ๋ช ์ค์ ์ฝ๋๋ฅผ ์ถ๊ฐํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด LoRA ์ด๋ํฐ๋ฅผ ํ๋ จํ๋ ค๋ฉด:
Trainer๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฒ์ด ์ต์ํ์ง ์๋ค๋ฉด ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ ํํ ๋ฆฌ์ผ์ ํ์ธํ์ธ์.
- ์์
์ ํ ๋ฐ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ง์ ํ์ฌ ์ด๋ํฐ ๊ตฌ์ฑ์ ์ ์ํฉ๋๋ค. ํ์ดํผํ๋ผ๋ฏธํฐ์ ๋ํ ์์ธํ ๋ด์ฉ์
~peft.LoraConfig
๋ฅผ ์ฐธ์กฐํ์ธ์.
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
)
- ๋ชจ๋ธ์ ์ด๋ํฐ๋ฅผ ์ถ๊ฐํฉ๋๋ค.
model.add_adapter(peft_config)
- ์ด์ ๋ชจ๋ธ์ Trainer์ ์ ๋ฌํ ์ ์์ต๋๋ค!
trainer = Trainer(model=model, ...) trainer.train()
ํ๋ จํ ์ด๋ํฐ๋ฅผ ์ ์ฅํ๊ณ ๋ค์ ๊ฐ์ ธ์ค๋ ค๋ฉด:
model.save_pretrained(save_dir) model = AutoModelForCausalLM.from_pretrained(save_dir)