|
import json |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
# Toy relational schema given to the model as prompt context.
# Maps each table name to the ordered list of its column names.
db_schema = {
    "products": ["product_id", "name", "price", "description", "type"],
    "orders": ["order_id", "product_id", "quantity", "order_date"],
    "customers": ["customer_id", "name", "email", "phone_number"],
}
|
|
|
|
|
# Checkpoint used for NL-to-SQL generation.
# NOTE(review): gpt-neox-20b is ~40 GB of fp16 weights and needs multi-GPU /
# large-GPU memory even with device_map="auto" — confirm this size is intended.
model_name = "EleutherAI/gpt-neox-20b"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" lets accelerate place/shard layers across available devices;
# torch.float16 halves memory versus the default fp32 weights.
# NOTE(review): this download/load runs at import time — every importer pays it.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
|
|
def generate_sql_query(context, question):
    """
    Generate an SQL query for a natural-language question using the module-level LM.

    Args:
        context (str): Description of the database schema or table relationships.
        question (str): User's natural language query.

    Returns:
        str: Generated SQL query — the text following the final "Query:" marker
        in the decoded model output, stripped of surrounding whitespace.
    """
    prompt = f"""
Context: {context}

Question: {question}

Write an SQL query to address the question based on the context.
Query:
"""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)

    # Use max_new_tokens rather than max_length: max_length=512 counted the
    # prompt tokens too, so any prompt near the 1024-token truncation limit
    # left no budget to generate (or raised an error). Passing attention_mask
    # explicitly avoids pad-token ambiguity in batched/padded decoding.
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=256,
        num_beams=5,
        early_stopping=True,
    )
    query = tokenizer.decode(output[0], skip_special_tokens=True)

    # Decoded text echoes the whole prompt; keep only what follows "Query:".
    sql_query = query.split("Query:")[-1].strip()
    return sql_query
|
|
|
|
|
def _main():
    """Demo entry point: generate SQL for one sample question over db_schema."""
    # Serialize the schema so the model sees table/column names in the prompt.
    schema_description = json.dumps(db_schema, indent=4)

    user_question = 'Show all products that cost more than $50'

    sql_query = generate_sql_query(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")


# Guard the demo so importing this module does not trigger a beam-search
# generation run; previously these statements executed unconditionally.
if __name__ == "__main__":
    _main()
|
|