Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
|
|
3 |
import os, torch
|
4 |
from datasets import load_dataset
|
5 |
from huggingface_hub import HfApi, login
|
6 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
7 |
|
8 |
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
9 |
|
@@ -34,10 +34,14 @@ def fine_tune_model(base_model_id, dataset):
|
|
34 |
# return fine_tuned_model_id
|
35 |
# Load the dataset
|
36 |
dataset = load_dataset("gretelai/synthetic_text_to_sql")
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# Load pre-trained model and tokenizer
|
39 |
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
40 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
41 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
42 |
|
43 |
# Preprocess the dataset
|
|
|
3 |
import os, torch
|
4 |
from datasets import load_dataset
|
5 |
from huggingface_hub import HfApi, login
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
|
7 |
|
8 |
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
9 |
|
|
|
34 |
# return fine_tuned_model_id
|
35 |
# Load the dataset
|
36 |
dataset = load_dataset("gretelai/synthetic_text_to_sql")
|
37 |
+
|
38 |
+
bnb_config = BitsAndBytesConfig(
|
39 |
+
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
|
40 |
+
)
|
41 |
|
42 |
# Load pre-trained model and tokenizer
|
43 |
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
44 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=bnb_config)
|
45 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
46 |
|
47 |
# Preprocess the dataset
|