Gemma LLM Model Fine-Tuning for Technical Summarization Chat Bot
The Gemma LLM model is being fine-tuned specifically for use in a technical summarization chatbot. This chatbot will leverage the model's ability to understand and summarize complex technical content, making it easier for users to engage with technical materials. The fine-tuning process is aimed at improving the model's performance in accurately capturing the essential points from dense, technical information, and providing concise, user-friendly summaries. The end goal is to enhance user experience in environments where quick, reliable technical insights are required.
Table of Contents
Dataset
The dataset used for this project is sourced from the Hugging Face repository, specifically from the wisenut-nlp-team/llama_ko_smr collection. This dataset contains various types of summarization data, including document summaries, book summaries, research paper summaries, TV content script summaries, Korean dialogue summaries, and technical/scientific summaries. Each entry in the dataset consists of the instruction, main text, and its corresponding summary.
Instead of limiting the training to just the technical and scientific summarization data, I opted to use the entire dataset to expose the model to a wider variety of content types. This decision was made to ensure the model is well-rounded and can handle diverse types of summarization tasks, improving its overall performance across different domains.
Here is an example of the dataset:
{
"instruction": "์ด ๊ธ์ ์ฃผ์ ๋ด์ฉ์ ์งง๊ฒ ์ค๋ช
ํด ์ฃผ์ค ์ ์์ต๋๊น?",
"input": "๋ถํ ์ฐ๊ทน์ ๋ํ ๋์ ํ๊ตฌ๋ ํด๋ฐฉ๊ณต๊ฐ์ ๋ถ์ผ๋ก ์ฌ๋ผ์ ธ ๊ฐ ์๋ง์ ์ฐ๊ทน์ธ๋ค์ ํ์ ์ ์ฐพ์๋ณด๊ณ ์ ํ๋ ๋จ์ํ ํธ๊ธฐ์ฌ์์ ์์๋์๋ค. ํด๋ฐฉ๊ณต๊ฐ์์ ํ๋ํ๋ ์ฐ๊ทน์ธ์ ๋๋ค์๊ฐ ๋ฉโค์๋ถ์ ๊ณผ์ ์ ๊ฑฐ์ณ ๋ถํ ์ฐ๊ทน๊ณ์ ์๋ฆฌ๋ฅผ ์ก์๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ ์์๋ ๊ทน์๊ฐ ์ก์, ํจ์ธ๋, ๋ฐ์ํธ, ์กฐ์์ถ, ์ฐ์ถ๊ฐ ์ด์ํฅ, ์์์ผ, ์ ๊ณ ์ก, ๋ฌด๋๋ฏธ์ ๊ฐ ๊น์ผ์, ๊ฐํธ, ๋ฐฐ์ฐ ํฉ์ฒ , ๊น์ ์, ๋ฌธ์๋ด, ๋ง๋ด๊ฐ ์ ๋ถ์ถ ๋ฑ ๊ธฐ๋ผ์ฑ ๊ฐ์ ๋ฉค๋ฒ๋ค์ด ํฌํจ๋์ด ์์๋ค. ๊ทธ ์ซ์๋ก๋ง ๋ณธ๋ค๋ฉด ์ผ์ ๊ฐ์ ๊ธฐ ์์ธ์ ์ฐ๊ทน๊ณ๊ฐ ํต์ผ๋ก ํ์์ผ๋ก ์ฎ๊ฒจ๊ฐ ์
์ด์๋ค. ๊ทธ๋ ์ง๋ง ์ด์ ๋ถํ ์ฐ๊ทน์์ ๋ ์ด์ ๊ทธ๋ค์ ์กด์ฌ๋ฅผ ํ์ธํ๊ธฐ ์ด๋ ค์ด ์ํฉ์ด๋ค. ๊ทธ๋ค์ ๋จ์์๋ ๋ถ์์๋ ์๊ณ์์ ์์ํ ์ฌ๋ผ์ ธ๋ฒ๋ฆฐ โ์์ด๋ฒ๋ฆฐ ์ธ๋โ ๊ทธ ์์ฒด์ด๋ค. ๊ทธ๋ค์ ํ์ ์ ์ฐพ๋ ๊ฒ์ ์ฐจ๋ผ๋ฆฌ ๊ณ ๊ณ ํ์ ๊ณผ์ ๊ฐ ๋์๋ค. ๊ทธ๋ค์ด ์ญ์ฌ์ ์ ํธ์ผ๋ก ์ฌ๋ผ์ง ๊ทธ ์๋ฆฌ์ ์ค๋์ ๋ถํ ์ฐ๊ทน์ด ์ฑ์ฑ์ฒ๋ผ ์์ฉ์ ์๋ํ๊ณ ์๋ค. ์ค๋๋ ์ ๋ถํ ์ฐ๊ทน์ ๋ชจ๋๊ฐ ์ฃผ์ฒด์ฌ์ค์ฃผ์์ ์
๊ฐํ์ฌ ๋ง๋ค์ด์ง๋ ์ด๋ฅธ๋ฐ โ<์ฑํฉ๋น>์ ํ๋ช
์ฐ๊ทนโ ์ผ์์ด๋ค. 1978๋
๊ตญ๋ฆฝ์ฐ๊ทน๋จ์ <์ฑํฉ๋น> ๊ณต์ฐ์ ์ฑ๊ณผ๋ฅผ ๋ณธ๋ณด๊ธฐ๋ก ์ผ์ ๋ชจ๋ ์ฐ๊ทน์ด โ๋ฐ๋ผ ๋ฐฐ์ฐ๊ธฐโ๋ฅผ ํ๊ณ ์๊ธฐ ๋๋ฌธ์ด๋ค. ๋ถํ์ ์ฐ๊ทน๊ณผ ํฌ๊ณก์ ์ ์ ์์ ๋ด๋ ค ์๋ ๋จ์ฑ์ (ๅฎ่ฒ็) ๋ฌธํํ๋ก ์์ ๊ฐํ ์๋ค. ํ๋ช
์ฐ๊ทน <์ฑํฉ๋น>(1978)์ ๋ณธ๋ณด๊ธฐ๋ ํ๋ช
๊ฐ๊ทน <ํผ๋ฐ๋ค>(1971)์ด๋ฉฐ, ๊ทธ ๊ทผ์ ์๋ 1960๋
๋๋ถํฐ ์์๋ ๊น์ ์ผ ์ฃผ๋์ ๋ฌธํ์์ ํ๋ช
์ด ๊ฐ๋ก๋์ฌ ์๋ค. ๋ถํ ์ฐ๊ทน์ ์ฐฝ์๊ณผ ํฅ์ , ๊ทธ ๋ชจ๋ ๊ณผ์ ์์ ๊น์ ์ผ์ ๊ทธ๋ฆผ์์ ๋ง๋ฅ๋จ๋ฆฌ์ง ์์ ์ ์๋ค. ์ต๊ทผ์ ๋ฐฉ๋ฌธํ ์กฐ์ ์์ ์ํ์ดฌ์์ ์ ์๋ โ๋ฌธํ์ฑํ๋ช
์ฌ์ ๊ดโ(๊น์ ์ผ๊ด)์๋ 1960๋
๋ ์ค๋ฐ๋ถํฐ 2000๋
๋๊น์ง 40๋
๋์ ๊น์ ์ผ์ ๋ฌธํ์์ ๋ถ๋ฌธ ์ง๋๊ฐ 11,890๊ฑด์ด๋ฉฐ, ๊ทธ ์ค ๋ฌธํ์์ ๊ธฐ๊ด์ ์ง์ ๋ฐฉ๋ฌธํ์ฌ ์ง๋ํ ์ด๋ฅธ๋ฐ โํ์ง์ง๋โ๊ฐ 1,770๊ฑด์ด๋ผ๋ ์๋ดํ์ด ์์๋ค. ๋ถํ ์ฐ๊ทน์ด ๊น์ ์ผ๊ณผ ์ฃผ์ฒด์ฌ์์ด๋ผ๋ ํค์๋๋ฅผ ๋ ๋ ์กด์ฌํ ์ ์๋ค๋ ๊ฒ์ ๋จ์ ์ผ๋ก ๋งํด ์ค๋ค.",
"output": "ํด๋ฐฉ๊ณต๊ฐ์์ ํ๋ํ ๋๋ค์์ ์ฐ๊ทน์ธ์ ๋ฉยท์๋ถ์ ํตํด ๋ถํ ์ฐ๊ทน๊ณ์ ์๋ฆฌ ์ก์์ง๋ง ์ด์ ๋ถํ ์ฐ๊ทน์์ ๊ทธ๋ค์ ๋ณด๊ธฐ ์ด๋ ต๋ค. ๊ทธ ์๋ฆฌ์ ์ง๊ธ ๋ถํ ์ฐ๊ทน์ด ์๋ฆฌ ์ก์๋ค. 1978๋
๊ตญ๋ฆฝ์ฐ๊ทน๋จ <์ฑํฉ๋น> ๊ณต์ฐ์ ํ ๋๋ก ๋ถํ์ ๋ชจ๋ ์ฐ๊ทน์ โ<์ฑํฉ๋น>์ ํ๋ช
์ฐ๊ทนโ ์ผ์์ด๋ค. ๋ถํ ์ฐ๊ทน๊ณผ ํฌ๊ณก์ ๋จ์ฑ์ ๋ฌธํํ๋ก์ ๋ฌถ์ฌ์๊ณ , ๊ทธ ์์์ ๊น์ ์ผ ์ฃผ๋ ๋ฌธํ์์ ํ๋ช
์ด ์๊ณ , ๋ถํ ์ฐ๊ทน์ ์ฐฝ์๊ณผ ํฅ์ ๋ฑ ๊น์ ์ผ ํ์ ์ด ์๋ค. ๊น์ ์ผ์ ๋ฌธํ์์ ๋ถ๋ฌธ ์ง๋ ๊ธฐ๋ก์ ๋ถํ ์ฐ๊ทน์ด ๊น์ ์ผ๊ณผ ์ฃผ์ฒด์ฌ์์ ๋ ๋ ์ ์๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค."
}
Model
This model is built on the gemma-2-2b-it base and fine-tuned using advanced techniques such as BitsAndBytes for memory optimization, LoRA for efficient adaptation, and the SFTTrainer framework. You can find the fine-tuned version of this model on Hugging Face at this link.
Highlight
LoRA Configuration for Model Efficiency: The model is fine-tuned using Low-Rank Adaptation (LoRA) with specific configurations like r=6, lora_alpha=8, and a dropout of 0.05. This allows for efficient adaptation of the model without modifying all layers.
Quantization for Memory Optimization: The BitsAndBytesConfig is set to load the model in 4-bit precision, using nf4 quantization. This reduces memory usage, making it possible to fine-tune the model on larger datasets.
Fine-Tuning Parameters: Fine-tuning is set up using SFTTrainer, with a batch size of 1, gradient_accumulation_steps=4, and max_steps=3000. The training uses 8-bit AdamW optimizer (paged_adamw_8bit) for better performance in a memory-constrained environment.
Inference Example Code
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments
BASE_MODEL = "google/gemma-2b-it"
FINETUNE_MODEL = "./gemma-2b-it-sum-ko-science"
finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
pipe = pipeline("text-generation", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)
pipe_finetuned = pipeline("text-generation", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)
doc=None
doc = r"๊ทธ๋ ๊ฒ ๋ฑ์ฅํ ๊ฒ์ด ์์์๊ณ๋ค. ์์๊ฐ 1์ด ๋์ ์์ง์ด๋ ํ์์ธ โ๊ณ ์ ์ง๋์โ๋ฅผ ์ด์ฉํด ์ ํํ 1์ด๋ฅผ ์ธก์ ํ๋ค. ์์ ์์ ์๋ ์ ์๋ค์ ํน์ ์๋์ง ์ํ๋ก ์๋ค. ์ด ์ํ์์ ๋ค๋ฅธ ์ํ๋ก ๋ณํํ๋ ค๋ฉด ์๋์ง๋ฅผ ๋ ์ํ์ ์ฐจ์ด๋งํผ ํก์ํ๊ฑฐ๋ ๋ฐฉ์ถํด์ผ ํ๋ค. ์ ์๊ฐ ์๋์ง๋ฅผ ์ป๊ธฐ ์ํด(๋ค๋ฅธ ์๋์ง ์ํ๋ก ๋ณํ๊ธฐ ์ํด) ์ ์๊ธฐํ๋ฅผ ํก์ํ ๋ ์ง๋์ด ๋ฐ์ํ๋๋ฐ, ์ด๊ฒ์ด ๋ฐ๋ก ๊ณ ์ ์ง๋์๋ค."
#doc = r"์ฒ๋
๋ง๋
์ง๋๋ ๋ณํ์ง ์๋ ๊ณณ์ด ์์๊น. ๊ณผํ์๋ค์ ์ฒ๋
๋ง๋
์ ๋์ด ์์ต ๋
์ด ์ง๋๋ 1์ด์ ์ค์ฐจ๋ ์์ด ์ผ์ ํ๊ฒ ํ๋ฅด๋ ์๊ณ๋ฅผ ๊ฐ๋ฐํ๊ณ ์๋ค. ์ง๊ตฌ๊ฐ ํ ๋ฐํด ์์ ํ๋ ์๊ฐ์ 1์ผ์ด๋ผ๊ณ ํ๋ค. ์ด๊ฒ์ ์ชผ๊ฐ ์๊ฐ๊ณผ ๋ถ, ์ด๋ฅผ ์ ํ๋ค. ํ์ง๋ง ์ง๊ตฌ ์์ ์๋๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ๋ฏ๋ก ์๊ฐ์ ์ค์ฐจ๊ฐ ์๊ฒผ๋ค. ์๋ก์ด ์๊ฐ์ ์ ์๊ฐ ํ์ํด์ง ์ด์ ๋ค."
messages = [
{
"role": "user",
"content": "๋ค์ ๊ธ์ ์์ฝํด์ฃผ์ธ์:\n\n{}".format(doc)
}
]
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipe_finetuned(
prompt,
do_sample=True,
temperature=0.2,
top_k=50,
top_p=0.95,
add_special_tokens=True
)
print(outputs[0]["generated_text"][len(prompt):])
- Downloads last month
- 2