Fine-Tuned Google T5 Model for Text to SQL Translation
A fine-tuned version of the Google T5 model, trained to translate natural-language questions into SQL statements.
Model Details
- Architecture: Google T5 Base (Text-to-Text Transfer Transformer)
- Task: Text to SQL Translation
- Fine-Tuning Datasets:
Training Parameters
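from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",          # checkpoints and logs are written here
    evaluation_strategy="epoch",     # evaluate after every epoch (newer transformers releases name this eval_strategy)
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,              # keep at most 3 checkpoints on disk
    num_train_epochs=3,
    predict_with_generate=True,      # run generate() during evaluation so generation metrics can be computed
    fp16=True,                       # mixed-precision training; requires a GPU
    push_to_hub=False,
)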
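These arguments fix the length and shape of the optimizer schedule. The sketch below computes the total number of optimizer steps and the learning rate at a given step. It rests on several assumptions that this card does not state: a single device, no gradient accumulation, the last partial batch kept, and the `Trainer` default of linear decay with zero warmup. The training-set size `n_examples` is a hypothetical placeholder because the card does not report it.

```python
import math

# Values copied from the Seq2SeqTrainingArguments above.
LEARNING_RATE = 2e-5
TRAIN_BATCH_SIZE = 8  # per_device_train_batch_size, assuming a single device
NUM_EPOCHS = 3


def total_optimizer_steps(n_examples: int,
                          batch_size: int = TRAIN_BATCH_SIZE,
                          epochs: int = NUM_EPOCHS) -> int:
    """Optimizer steps for the whole run, assuming no gradient accumulation
    and that the last partial batch of each epoch is kept."""
    steps_per_epoch = math.ceil(n_examples / batch_size)
    return steps_per_epoch * epochs


def linear_lr(step: int, total_steps: int, base_lr: float = LEARNING_RATE) -> float:
    """Learning rate under linear decay with zero warmup (assumed default):
    starts at base_lr and reaches 0 at total_steps."""
    remaining = max(0, total_steps - step)
    return base_lr * (remaining / max(1, total_steps))


if __name__ == "__main__":
    n_examples = 10_000  # hypothetical dataset size, not taken from this card
    total = total_optimizer_steps(n_examples)
    print(f"total optimizer steps: {total}")
    for step in (0, total // 2, total):
        print(f"step {step:>5}: lr = {linear_lr(step, total):.2e}")
```

With 10,000 hypothetical examples, a batch size of 8 gives 1,250 steps per epoch, or 3,750 steps over 3 epochs. The learning rate falls from 2e-5 at the start to 1e-5 at the midpoint.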
Usage
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model (T5Tokenizer needs the sentencepiece package)
model_path = 'juanfra218/text2sql'
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)

# Function to generate SQL queries
def generate_sql(prompt, schema):
    # Model input: task prefix, natural-language question, then the schema
    input_text = "translate English to SQL: " + prompt + " " + schema
    # Inputs longer than 512 tokens are truncated, so very long schemas may be cut off
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    # skip_special_tokens drops padding and end-of-sequence tokens from the output
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    # Check for 'quit' before asking for a schema
    if prompt.strip().lower() == 'quit':
        break
    schema = input("Insert schema: ")
    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
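The card does not say which schema format the model expects. The `context` column in `test_results.csv` suggests that some table description is given as context. The sketch below assumes that format is one or more `CREATE TABLE` statements. It joins them into the single-line `schema` string that `generate_sql()` appends after the question. It also rebuilds the exact input string, which is useful to inspect before generation. `format_schema` and `build_input_text` are hypothetical helper names.

```python
from typing import Iterable

# Task prefix copied from generate_sql() above.
TASK_PREFIX = "translate English to SQL: "


def format_schema(create_statements: Iterable[str]) -> str:
    """Join schema statements into one line, collapsing internal whitespace.

    Assumption: the schema is supplied as CREATE TABLE statements; the card
    does not specify the expected schema format.
    """
    return " ".join(" ".join(stmt.split()) for stmt in create_statements)


def build_input_text(question: str, create_statements: Iterable[str]) -> str:
    """Reproduce the string generate_sql() passes to the tokenizer:
    prefix + question + " " + schema."""
    return TASK_PREFIX + question + " " + format_schema(create_statements)


if __name__ == "__main__":
    statements = [
        "CREATE TABLE employees (\n  name VARCHAR,\n  age INTEGER\n)",
    ]
    print(build_input_text("How many employees are older than 30?", statements))
```

With the model loaded as in the usage code, `generate_sql(question, format_schema(statements))` passes the model the same string that `build_input_text(question, statements)` returns. Collapsing whitespace lets multi-line schemas copied from a file be used unchanged.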
Files
- optimizer.pt: State of the optimizer.
- training_args.bin: Training arguments and hyperparameters.
- tokenizer.json: Tokenizer vocabulary and settings.
- spiece.model: SentencePiece model file.
- special_tokens_map.json: Special tokens mapping.
- tokenizer_config.json: Tokenizer configuration settings.
- model.safetensors: Trained model weights.
- generation_config.json: Configuration for text generation.
- config.json: Model architecture configuration.
- test_results.csv: Results on the testing set; columns: prompt, context, true_answer, predicted_answer, exact_match.
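The columns of `test_results.csv` make it possible to recompute an exact-match rate. The card does not say how its `exact_match` column was produced. The sketch below therefore compares `predicted_answer` with `true_answer` after an assumed normalization: whitespace is collapsed and a trailing semicolon is dropped. Its result may not equal the self-reported score. `normalize_sql` and `exact_match_rate` are hypothetical names.

```python
import csv
from typing import Iterable


def normalize_sql(sql: str) -> str:
    """Collapse whitespace and drop a trailing semicolon.

    This normalization is an assumption; the card does not describe how
    its exact_match column was computed.
    """
    return " ".join(sql.split()).rstrip(";").strip()


def exact_match_rate(csv_lines: Iterable[str]) -> float:
    """Fraction of rows whose predicted_answer equals true_answer after
    normalize_sql(). Expects the column names listed for test_results.csv.

    csv_lines can be an open file or any iterable of CSV text lines.
    """
    total = matches = 0
    for row in csv.DictReader(csv_lines):
        total += 1
        if normalize_sql(row["predicted_answer"]) == normalize_sql(row["true_answer"]):
            matches += 1
    return matches / total if total else 0.0
```

To score the downloaded file, open it with `open("test_results.csv", newline="", encoding="utf-8")` and pass the file handle to `exact_match_rate`.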
Base model: google-t5/t5-base
Evaluation results
- exact_match (self-reported): 0.433
- BLEU (self-reported): 0.669
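The card does not say how BLEU was computed: the tokenization, smoothing, and library are all unstated. The sketch below follows the standard corpus-level definition with one reference per query, uniform weights over 1- to 4-grams, whitespace tokenization, and no smoothing. It shows what the metric rewards: clipped n-gram overlap with the reference SQL, scaled by a brevity penalty. It is not guaranteed to reproduce the 0.669 above.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count the n-grams of length n in a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses: List[str], references: List[str], max_n: int = 4) -> float:
    """Corpus BLEU with one reference per hypothesis, uniform weights,
    whitespace tokenization and no smoothing. Returns a score in [0, 1]."""
    matches = [0] * max_n  # clipped n-gram matches per order
    totals = [0] * max_n   # hypothesis n-gram counts per order
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        hyp_tokens, ref_tokens = hyp.split(), ref.split()
        hyp_len += len(hyp_tokens)
        ref_len += len(ref_tokens)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp_tokens, n)
            ref_counts = ngrams(ref_tokens, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0  # any zero precision makes the geometric mean zero
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Brevity penalty: penalize output that is shorter than the references.
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_precision)
```

Because SQL queries are short, a single wrong 4-gram can lower the score noticeably. This is one reason BLEU and exact match can diverge as much as the two numbers above do.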