youjunhyeok's picture
Update README.md
d97843a verified
metadata
library_name: transformers
tags:
  - llama-factory
license: apache-2.0

Info

SFT > DPO 순서가 아닌 DPO > SFT 순서로 학습시킨 모델입니다. SFT > DPO는 여기에서 확인해 주세요.

Model

Dataset

Load Model

Use the following Python code to load the model:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "youjunhyeok/llama3-8B-dpo-sft-v1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Chat

def chat(message):
    messages = [
        {"role": "system", "content": "당신은 친절하고 도움이 되는 챗봇입니다."},
        {"role": "user", "content": message},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=False,
        temperature=0.5,
        top_p=0.8,
    )
    response = outputs[0][input_ids.shape[-1]:]
    print(tokenizer.decode(response, skip_special_tokens=True))

chat('한산도 대첩에 대해 아는 대로 얘기해봐')

Output

한산도 대첩은 조선시대에 일어난 전투로, 이순신 장군이 이끄는 조선군이 일본군을 물리친 전투입니다.

BenchMark (KOR)

# alias
A = youjunhyeok/llama3-8B-dpo-sft-v1
B = DavidAhn/Llama-3-8B-slerp-262k
C = meta-llama/Meta-Llama-3-8B
D = chihoonlee10/T3Q-ko-solar-dpo-v7.0 (24.05.24 ko 리더보드 1등)
Benchmark (macro_f1) A B C D
kobest_boolq (0-shot) 84.7 33.5 38.2 34.1
kobest_boolq (5-shot) 86.1 68.8 83.8 93.1
kobest_copa (0-shot) 60.6 58.5 63.1 81.0
kobest_copa (5-shot) 67.2 61.7 69.1 91.0
kobest_hellaswag (0-shot) 40.0 43.2 42.1 55.1
kobest_hellaswag (5-shot) 42.4 45.3 44.2 55.2
kobest_sentineg (0-shot) 52.1 34.8 51.5 82.7
kobest_sentineg (5-shot) 89.4 85.8 94.7 91.4

BenchMark (ENG)

openbookqa hellaswag boolq arc_easy arc_challenge
youjunhyeok/llama3-8B-dpo-sft-v1 0.320 0.547 0.529 0.748 0.446
meta-llama/Meta-Llama-3-8B-Instruct 0.338 0.576 0.831 0.815 0.529