File size: 7,936 Bytes
ac66ae2
ad99c34
7465957
cbbb9fd
50ba747
 
 
083fde1
ada7179
251d88f
74c640a
c8534fb
e92ef1c
 
df44c11
38576e5
df44c11
467c88a
df44c11
50ba747
ed270e5
df44c11
38576e5
f0acdd7
 
 
022150f
 
df44c11
38576e5
df44c11
 
 
76019db
74c640a
 
 
df44c11
6dd3828
01e1b5d
 
50ba747
01e1b5d
50ba747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825f16a
50ba747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ada7179
 
 
 
251d88f
 
01e1b5d
 
 
 
 
 
 
 
 
6dd3828
2f9af72
01e1b5d
 
 
 
 
 
 
 
 
 
 
 
 
92146e5
ada7179
74c640a
cbbb9fd
 
5eaca58
ada7179
cbbb9fd
ada7179
39546c6
0fb434b
ada7179
 
2b03f9f
74c640a
ada7179
74c640a
2b03f9f
e69ea59
e566aba
df44c11
227477e
e92ef1c
74c640a
065dd39
a184b8b
 
2371111
835fa92
13e776a
 
94ca6da
 
13e776a
73899fd
083fde1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, pipeline
from trl import SFTTrainer, setup_chat_format

# Fine-tune on an NVIDIA A10G Large (sleep after 1 hour).

# Hub profile (user/organisation) that fine-tuned models are published under.
hf_profile = "bstraehle"

# Labels of the two actions offered by the UI radio button.
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"

# Default prompts; the "{schema}" placeholder in the system prompt is filled
# in with the table schema for each request.
system_prompt = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SCHEMA. "
    "Do not generate any additional text. "
    "SCHEMA: {schema}"
)
user_prompt = (
    "What is the total trade value and average price for each trader "
    "and stock in the trade_history table?"
)
schema = (
    "CREATE TABLE trade_history ("
    "id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), "
    "quantity INT, trade_time TIMESTAMP);"
)

# Other base models considered: "ibm-granite/granite-8b-code-instruct",
# "meta-llama/Meta-Llama-3-8B-Instruct".
base_model_id = "codellama/CodeLlama-7b-hf"
dataset = "b-mc2/sql-create-context"

def prompt_model(model_id, system_prompt, user_prompt, schema):
    """Ask ``model_id`` to translate ``user_prompt`` into SQL and return the reply.

    Args:
        model_id: Hugging Face Hub id (or local path) of the model to load.
        system_prompt: System-message template containing a ``{schema}``
            placeholder.
        user_prompt: The natural-language question to translate.
        schema: Table definition(s) substituted into ``system_prompt``.

    Returns:
        The text of the assistant message generated by the model (also
        printed to stdout).
    """
    # NOTE(review): a new pipeline, and therefore a full model load, is built
    # on every call.
    pipe = pipeline("text-generation",
                    model=model_id,
                    model_kwargs={"torch_dtype": torch.bfloat16},
                    device_map="auto",
                    max_new_tokens=1000)
    # End the chat with the user turn and let the pipeline append the
    # assistant generation prompt itself. A trailing empty assistant turn is
    # handled differently across transformers versions. Newer releases continue
    # it as a prefix; older ones render it as a finished empty reply followed
    # by a second assistant header.
    messages = [
        {"role": "system", "content": system_prompt.format(schema=schema)},
        {"role": "user", "content": user_prompt},
    ]
    output = pipe(messages)
    # With chat input, "generated_text" holds the conversation with the
    # model's reply appended as the final message.
    result = output[0]["generated_text"][-1]["content"]
    print(result)
    return result

def fine_tune_model(base_model_id, dataset):
    """Prepare the data and run the training step for ``base_model_id``.

    Args:
        base_model_id: Hub id of the pre-trained model to fine-tune.
        dataset: Hub id of the dataset used to build the train/test splits.

    Returns:
        The literal placeholder string ``"fine_tuned_model_id"``; no real
        repository id is computed yet.
    """
    # Builds train_dataset.json / test_dataset.json on local disk ...
    download_dataset(dataset)
    # ... which the training step then reads back.
    train_model(base_model_id)
    # NOTE(review): the download_model / upload_model steps are disabled, so
    # there is no fine-tuned repository id to report; a placeholder is
    # returned to the UI instead.
    placeholder_id = "fine_tuned_model_id"
    return placeholder_id

def train_model(model_id):
    """Set up a QLoRA supervised fine-tuning run for ``model_id``.

    Loads ``train_dataset.json`` from the working directory. Loads
    ``model_id`` quantized to 4-bit NF4 and switches it to the ChatML chat
    format. Builds an ``SFTTrainer`` with a LoRA adapter on all linear layers.
    The model and trainer are released, and the CUDA cache emptied, before
    returning, including when any setup step raises.

    NOTE(review): ``trainer.train()`` and ``trainer.save_model()`` are
    commented out, so no training, saving or hub push currently happens.

    Args:
        model_id: Hugging Face Hub id of the base model to fine-tune.
    """
    import gc  # function-scope import keeps this block self-contained

    # Bound before the ``try`` so the ``finally`` clause can always drop them,
    # even when loading fails part-way through.
    model = trainer = None
    try:
        print("111")
        dataset = load_dataset("json", data_files="train_dataset.json", split="train")

        # 4-bit NF4 weights with double quantization; computation in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        )

        print("222")
        # Load model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            #attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.padding_side = 'right' # to prevent warnings

        print("333")
        # Set the chat template to OpenAI ChatML; remove this if you start
        # from an already fine-tuned model.
        model, tokenizer = setup_chat_format(model, tokenizer)

        peft_config = LoraConfig(
            lora_alpha=128,
            lora_dropout=0.05,
            r=256,
            bias="none",
            target_modules="all-linear",
            task_type="CAUSAL_LM",
        )

        print("444")
        args = TrainingArguments(
            output_dir="code-llama-7b-text-to-sql", # directory to save and repository id
            num_train_epochs=3,                     # number of training epochs
            per_device_train_batch_size=3,          # batch size per device during training
            gradient_accumulation_steps=2,          # steps to accumulate gradients over before each optimizer update
            gradient_checkpointing=True,            # use gradient checkpointing to save memory
            optim="adamw_torch_fused",              # use fused adamw optimizer
            logging_steps=10,                       # log every 10 steps
            save_strategy="epoch",                  # save checkpoint every epoch
            learning_rate=2e-4,                     # learning rate, based on QLoRA paper
            bf16=True,                              # use bfloat16 precision
            tf32=True,                              # use tf32 precision
            max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
            warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
            lr_scheduler_type="constant",           # use constant learning rate scheduler
            push_to_hub=True,                       # push model to hub
            report_to="tensorboard",                # report metrics to tensorboard
        )

        max_seq_length = 3072 # max sequence length for model and packing of the dataset

        print("555")
        trainer = SFTTrainer(
            model=model,
            args=args,
            train_dataset=dataset,
            peft_config=peft_config,
            max_seq_length=max_seq_length,
            tokenizer=tokenizer,
            packing=True,
            dataset_kwargs={
                "add_special_tokens": False,  # We template with special tokens
                "append_concat_token": False, # No need to add additional separator token
            }
        )

        # Training is currently disabled. When enabled, the model is saved to
        # the output directory and, because push_to_hub=True, to the Hub.
        #trainer.train()
        #trainer.save_model()
    finally:
        # ``del`` only removes these local names. The trainer, PEFT wrapper and
        # model may reference each other in cycles, so run the collector
        # before asking CUDA to release its now-unused cached blocks.
        del model
        del trainer
        gc.collect()
        torch.cuda.empty_cache()

def download_model(base_model_id):
    """Fetch ``base_model_id`` from the Hub and save its weights locally.

    The model is written to a local directory whose path is the model id
    itself (e.g. ``codellama/CodeLlama-7b-hf``). The tokenizer is not saved
    there; it is returned to the caller instead.

    Returns:
        The tokenizer loaded for ``base_model_id``.
    """
    tok = AutoTokenizer.from_pretrained(base_model_id)
    AutoModelForCausalLM.from_pretrained(base_model_id).save_pretrained(base_model_id)
    return tok

def create_conversation(sample):
    """Turn one dataset row into an OpenAI-style ``messages`` record.

    ``sample`` must provide ``context`` (the table schema, substituted into
    the module-level ``system_prompt`` template), ``question`` and ``answer``.

    Returns:
        ``{"messages": [system, user, assistant]}`` where each entry is a
        ``{"role": ..., "content": ...}`` dict.
    """
    turns = (
        ("system", system_prompt.format(schema=sample["context"])),
        ("user", sample["question"]),
        ("assistant", sample["answer"]),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
    
def download_dataset(dataset, num_samples=12500, test_size=2500 / 12500, seed=None):
    """Download ``dataset``, convert it to chat messages and save train/test splits.

    A random subset of the ``train`` split is mapped through
    ``create_conversation``. It is then split and written to
    ``train_dataset.json`` and ``test_dataset.json`` in the working directory.

    Args:
        dataset: Hub id of a dataset whose rows have ``context``, ``question``
            and ``answer`` columns.
        num_samples: Maximum number of rows to keep. Defaults to 12,500 and is
            capped at the dataset size.
        test_size: Passed to ``train_test_split``. A float is the held-out
            fraction, an int the number of test rows. The default is 20 %
            (2,500 of 12,500).
        seed: Optional shuffle seed, for a reproducible subset.
    """
    ds = load_dataset(dataset, split="train")
    # The dataset id comes from a free-text UI field, and ``select`` raises
    # IndexError for indices past the end, so never ask for more rows than exist.
    num_samples = min(num_samples, len(ds))
    ds = ds.shuffle(seed=seed).select(range(num_samples))

    # Convert dataset to OAI messages, dropping the original columns.
    ds = ds.map(create_conversation, remove_columns=ds.column_names, batched=False)
    splits = ds.train_test_split(test_size=test_size)

    # Print one converted example for a quick sanity check; clamp the index so
    # small datasets do not fail here.
    preview_idx = min(345, len(splits["train"]) - 1)
    print(splits["train"][preview_idx]["messages"])

    # save datasets to disk
    splits["train"].to_json("train_dataset.json", orient="records")
    splits["test"].to_json("test_dataset.json", orient="records")

def upload_model(base_model_id, tokenizer):
    """Publish the locally saved model folder and its tokenizer to the Hub.

    The target repository is ``base_model_id`` re-homed under ``hf_profile``
    (see ``replace_hf_profile``). The folder uploaded is the local directory
    named ``base_model_id``.

    Args:
        base_model_id: Hub id of the base model; also the local folder path.
        tokenizer: Tokenizer to push alongside the model files.

    Returns:
        The repository id the model was uploaded to.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set.
    """
    fine_tuned_model_id = replace_hf_profile(base_model_id)
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    # exist_ok=True: re-running the upload must not fail with a 409 Conflict
    # just because an earlier run already created the repository.
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(folder_path=base_model_id, repo_id=fine_tuned_model_id)
    tokenizer.push_to_hub(fine_tuned_model_id)
    return fine_tuned_model_id

def replace_hf_profile(base_model_id, profile=None):
    """Return ``base_model_id`` re-homed under a Hub profile.

    The organisation prefix is dropped and replaced by ``profile``, e.g.
    ``"codellama/CodeLlama-7b-hf"`` -> ``"bstraehle/CodeLlama-7b-hf"``. An id
    without a slash is used whole.

    Surrounding whitespace and trailing slashes are ignored. Both are easy to
    pick up when the id is typed into a text box, and would otherwise produce
    an empty model name such as ``"bstraehle/"``.

    Args:
        base_model_id: Hub id of the base model.
        profile: Target user/organisation; defaults to the module-level
            ``hf_profile``.
    """
    if profile is None:
        profile = hf_profile
    model_name = base_model_id.strip().rstrip("/").rpartition("/")[2]
    return f"{profile}/{model_name}"

def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
    """Gradio callback: run the action selected in the UI and return its text.

    Args:
        action: ``action_1`` to fine-tune ``base_model_id`` on ``dataset``, or
            ``action_2`` to prompt the fine-tuned copy under ``hf_profile``.
        base_model_id, dataset, system_prompt, user_prompt, schema: Values of
            the corresponding UI text boxes.

    Returns:
        The string shown in the "Completion" output box.

    Raises:
        gr.Error: if ``action`` is not one of the supported actions.
    """
    # Uncomment to disable the public Space for anyone without credentials:
    #raise gr.Error("Please clone and bring your own credentials.")
    if action == action_1:
        return fine_tune_model(base_model_id, dataset)
    if action == action_2:
        fine_tuned_model_id = replace_hf_profile(base_model_id)
        return prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
    # Previously this fell through to ``return result`` with ``result``
    # unbound (UnboundLocalError); report the problem in the UI instead.
    raise gr.Error(f"Unknown action: {action!r}")

# Web UI: one radio button selects the action; the text boxes supply the
# arguments of ``process`` in order. The Completion box starts with the
# contents of the optional OUTPUT environment variable. Using ``.get`` keeps
# the app from crashing at startup with KeyError when OUTPUT is unset.
demo = gr.Interface(fn=process,
                    inputs=[gr.Radio([action_1, action_2], label="Action", value=action_1),
                            gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
                            gr.Textbox(label="Dataset", value=dataset, lines=1),
                            gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
                            gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
                            gr.Textbox(label="Schema", value=schema, lines=2)],
                    outputs=[gr.Textbox(label="Completion", value=os.environ.get("OUTPUT", ""))])
demo.launch()