
# GPT-2 Medium SFT and DPO on Anthropic-hh Dataset

This repository contains a GPT-2 Medium model that was first instruction-tuned (SFT) on the Anthropic-hh dataset and then further aligned on the same dataset with Direct Preference Optimization (DPO).

## Model Information

- Model Name: RaushanTurganbay/GPT2_sft_and_dpo_tuned
- Base Model: GPT-2 Medium (355M parameters, F32 safetensors)
- Training Data: Anthropic-hh dataset
- Fine-Tuning Approach: Supervised fine-tuning (SFT) followed by Direct Preference Optimization (DPO); see the sketch below
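
The training pipeline itself is not shipped with this repository. As a rough illustration of the second stage only, here is a minimal DPO sketch using TRL's `DPOTrainer`; the SFT checkpoint path, the beta value, and the exact argument names are assumptions, and the TRL API varies across versions:

```python
# Minimal DPO sketch with TRL (illustrative only; not the exact training script).
# Assumptions: TRL's DPOTrainer/DPOConfig API, a local SFT checkpoint path,
# and beta=0.1 -- none of these are confirmed by this model card.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

sft_checkpoint = "path/to/gpt2-medium-sft"  # hypothetical SFT checkpoint
model = AutoModelForCausalLM.from_pretrained(sft_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(sft_checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Anthropic-hh ships "chosen"/"rejected" transcripts; DPOTrainer expects
# "prompt"/"chosen"/"rejected" columns, so a real run needs a mapping step here.
train_dataset = load_dataset("Anthropic/hh-rlhf", split="train")

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="gpt2_dpo", beta=0.1),
    train_dataset=train_dataset,
    processing_class=tokenizer,  # named `tokenizer` in older TRL versions
)
trainer.train()
```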

## How to Use

```python
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tokenizer_dpo = GPT2Tokenizer.from_pretrained("RaushanTurganbay/GPT2_sft_and_dpo_tuned")
model_dpo = GPT2LMHeadModel.from_pretrained("RaushanTurganbay/GPT2_sft_and_dpo_tuned")


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation once any of the given token sequences is produced."""

    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops or []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            # Compare the last len(stop) generated tokens against the stop sequence.
            stop = stop.to(input_ids.device)
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False


def stopping_criteria(tokenizer, stop_words):
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
        for stop_word in stop_words
    ]
    return StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])


# Generate responses; stop as soon as the model starts a new "Human:" turn.
stopping = stopping_criteria(tokenizer_dpo, ["\n\nHuman:"])
prompt = "\n\nHuman: {your_instruction}\n\nAssistant:"
inputs_dpo = tokenizer_dpo(prompt, return_tensors="pt")
outputs_dpo = model_dpo.generate(**inputs_dpo, stopping_criteria=stopping, max_length=150)

print("Model Response:", tokenizer_dpo.batch_decode(outputs_dpo, skip_special_tokens=True))
```