|
import gradio as gr |
|
import os |
|
from datasets import load_dataset |
|
from huggingface_hub import HfApi, login |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub namespace (account name) used as the owner prefix when building the
# "<profile>/<model_name>" repository id for uploads.
profile = "bstraehle"
|
|
|
def download_model(model_id):
    """Load the tokenizer and causal LM for ``model_id`` and save the model.

    The model is written with ``save_pretrained``, using ``model_id`` itself
    as the save location. The tokenizer is not saved here; it is returned to
    the caller instead.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        The tokenizer loaded for ``model_id``.
    """
    # The tokenizer is loaded before the model, same order as always.
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Load the model and immediately persist it under the same identifier.
    AutoModelForCausalLM.from_pretrained(model_id).save_pretrained(model_id)

    return loaded_tokenizer
|
|
|
def download_dataset(dataset):
    """Load ``dataset`` through ``load_dataset`` and return an empty string.

    The loaded dataset object is not kept or returned; only the call itself
    (and whatever it does internally) takes place.

    Args:
        dataset: Identifier passed straight to ``load_dataset``.

    Returns:
        Always the empty string.
    """
    load_dataset(dataset)
    return ""
|
|
|
def fine_tune_model():
    """Placeholder for the fine-tuning step.

    Returns:
        Always the empty string; no work is performed yet.
    """
    placeholder_result = ""
    return placeholder_result
|
|
|
def upload_model(model_id, tokenizer):
    """Upload a locally saved model folder and its tokenizer to the Hub.

    The target repository id is ``"{profile}/{model_name}"``, where
    ``model_name`` is the part of ``model_id`` after its last ``/`` (or all
    of ``model_id`` when it contains no ``/``).

    Args:
        model_id: Model identifier. It is also used as the local folder path
            whose contents are uploaded.
        tokenizer: Object providing ``push_to_hub(repo_id)``; it is pushed to
            the same repository after the folder upload.

    Returns:
        The repository id the model was uploaded to.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    model_name = model_id.rsplit("/", 1)[-1]
    model_repo_name = f"{profile}/{model_name}"

    login(token=os.environ["HF_TOKEN"])

    api = HfApi()
    # exist_ok=True keeps repeated uploads of the same model from failing at
    # repo creation when the repository was already created by an earlier run.
    api.create_repo(repo_id=model_repo_name, exist_ok=True)
    api.upload_folder(
        folder_path=model_id,
        repo_id=model_repo_name,
    )

    tokenizer.push_to_hub(model_repo_name)

    return model_repo_name
|
|
|
def process(action, system_prompt, user_prompt, model_id, dataset):
    """Handle the selected UI action and return the text to display.

    Every supported action currently produces a placeholder string equal to
    its own label. The remaining parameters are accepted but not used yet.

    Args:
        action: Label of the selected action.
        system_prompt: System prompt text (currently unused).
        user_prompt: User prompt text (currently unused).
        model_id: Model identifier (currently unused).
        dataset: Dataset identifier (currently unused).

    Returns:
        The placeholder result string for ``action``.

    Raises:
        ValueError: If ``action`` is not a supported label (for example
            ``None``).
    """
    supported_actions = (
        "Prompt base model",
        # Label offered by the "Action" radio input; previously unmatched,
        # which left the result unassigned and raised UnboundLocalError.
        "Fine-tune base model",
        # Older spelling this function used to check; kept so callers that
        # pass it keep getting the same result.
        "Fine-tuned base model",
        "Prompt fine-tuned model",
    )

    if action not in supported_actions:
        raise ValueError(f"Unsupported action: {action!r}")

    return action
|
|
|
# Default values pre-filled into the UI text boxes.
system_prompt = "You are a text to SQL query translator. Users will ask you a question in English and you will generate a SQL query."
user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = "gretelai/synthetic_text_to_sql"

# Five inputs, in the same order as the parameters of `process`. All of them
# must sit inside the single `inputs` list: a stray "]" after the Radio
# previously closed the list early and made this statement a syntax error.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio(
            ["Prompt base model", "Fine-tune base model", "Prompt fine-tuned model"],
            label="Action",
        ),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=1),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=1),
        gr.Textbox(label="Model ID", value=model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
    ],
    outputs=[gr.Textbox(label="Completion")],
)

demo.launch()