# NLP-to-SQL / app.py
from transformers import AutoTokenizer, AutoModelForCausalLM
# Prompt template for the model; {question}, {schema} and {Metadata} are placeholders
Instructions = """
### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
### Database Schema
The query must run on the database whose schema is given below as JSON, where each key is a table name and each value is that table's CREATE statement:
{schema}
### Metadata for Schema
Use the detailed table and column descriptions below to understand how the tables are linked. The JSON contains an array of tables; each table has a prompt to consider when generating the query and a detailed description of each column.
{Metadata}
### Further context
- Cities, counties, and districts are all entities. Each entity has an entity ID and an entity type.
- The entity type for a county is County, and the entity type ID for every county is 15. Only the entity ID varies between counties.
- Every county has different entity types, and each entity type has different entities in it.
- Example: Ada is a county, so Ada's entity type is County, its entity type ID is 15, and its entity ID is 518. Ada has different districts, which form entity types. Fire district, water district, and abatement district are all different entity types in Ada County. The water district entity type ID is 3 and stays the same for that type across all entities. Each entity type has different entities, such as 'Boise Warm Springs Water District', which is an entity of entity type 'water district'.
- When the question asks for the budget of something, first identify the entities in the question, such as the county and the entity type, and then match them to the corresponding entities in the database.
- You can get list of all counties by running 'select EntityName from ods.Entity where EntityTypeID=15'
- Remember to use only the tables given in schema and use the metadata to understand the context
- Do not create your own columns or tables. Use only those provided in schema and in metadata
- Entities belonging to a particular county can be obtained by querying ods.EntityCounties table by providing appropriate countyid
- If you cannot answer the question based on the information available, respond with 'I don't know'
### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""
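def format_prompt(question, schema, metadata, Instructions):
    """
    Fills the {question}, {schema} and {Metadata} placeholders of the
    instruction template and returns the complete prompt for the model.
    """
    # The template already contains the schema, metadata and question sections
    # and ends with the [SQL] marker, so nothing is appended after it.
    return Instructions.format(
        question=question,
        schema=schema,
        Metadata=metadata,
    )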
def load_model():
    """
    Loads the SQL generation model and tokenizer from Hugging Face.
    """
    model_name = "defog/sqlcoder-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
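# A minimal sketch, not part of the original app: the prompt template uses
# {question}, {schema} and {Metadata} as placeholders, so it can help to check
# that none of them survive in the final prompt before it reaches the model.
# The helper name find_unfilled_placeholders is hypothetical.
import re


def find_unfilled_placeholders(prompt):
    """Return the names of any {placeholder} tokens still present in prompt."""
    # Matches {identifier} only; JSON such as {"table": "..."} is not matched.
    return re.findall(r"\{([A-Za-z_][A-Za-z0-9_]*)\}", prompt)
# Example (assumed usage): an empty list from
# find_unfilled_placeholders(format_prompt(...)) means every placeholder was filled.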
# Generate SQL from a question
def generate_sql(question, prompt_inputs, tokenizer, model):
    """
    Generates an SQL query for the question.
    Parameters:
    - question: Natural language question
    - prompt_inputs: Dict with the "schema", "metadata" and "instructions" used to build the prompt
    - tokenizer: Tokenizer instance
    - model: Pre-trained SQL generation model
    Returns:
    - Generated SQL query as a string
    """
    # Format the prompt
    prompt = format_prompt(question, prompt_inputs["schema"], prompt_inputs["metadata"], prompt_inputs["instructions"])
    # Tokenize the input and move it to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate SQL with greedy decoding
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
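# A minimal sketch, assuming the decoded text may still contain the prompt
# (which ends with the [SQL] marker) and possibly a closing [/SQL] tag.
# extract_sql is a hypothetical helper, not part of the original app.
def extract_sql(decoded_text):
    """Return only the SQL that follows the last [SQL] marker."""
    # Keep whatever comes after the last [SQL]; if it is absent, keep everything.
    sql = decoded_text.rsplit("[SQL]", 1)[-1]
    # Drop a closing [/SQL] tag and anything after it, if the model emitted one.
    sql = sql.split("[/SQL]", 1)[0]
    return sql.strip()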
import json
tokenizer, model = load_model()
prompt_inputs = {
    "schema": "",
    "metadata": "",
    "instructions": Instructions,
}
# Load the CREATE statements and the table metadata, then serialise them back
# to JSON text so the prompt shows valid JSON rather than a Python dict repr
with open('table_create.json', 'r') as file:
    prompt_inputs["schema"] = json.dumps(json.load(file), indent=2)
with open('tables_metadata.json', 'r') as file:
    prompt_inputs["metadata"] = json.dumps(json.load(file), indent=2)
question = "Get list of all available distinct entity types with their entity type id"
# Generate SQL
generated_sql = generate_sql(question, prompt_inputs, tokenizer, model)
print("\nGenerated SQL:")
print(generated_sql)
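# A minimal sketch, assuming the model follows the prompt's instruction to
# answer 'I dont know' when the question cannot be answered from the schema.
# is_refusal is a hypothetical helper, not part of the original app.
def is_refusal(text):
    """Return True if text is the 'I dont know' fallback rather than SQL."""
    # Normalise case and apostrophes (straight or curly) before comparing.
    normalized = text.strip().lower().replace("\u2019", "'").replace("'", "")
    return normalized.startswith("i dont know")
# Example (assumed usage): check is_refusal(generated_sql) before treating
# the model output as a query.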