"""Answer natural-language questions about a SQLite database.

The script classifies each question's intent with a zero-shot model, fuzzy-matches
table and column names, and assembles a simple SQL query from the result.
"""

import re
import sqlite3

import numpy as np
import spacy
from thefuzz import process
from transformers import pipeline
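
# The spaCy models loaded below are separate downloads; install them first if they
# are missing (the Hugging Face model is fetched automatically on first use):
#   python -m spacy download en_core_web_sm
#   python -m spacy download en_core_web_md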

# Zero-shot classifier for mapping questions to intents; the small spaCy model is
# used for tokenization and the medium model supplies word vectors.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
nlp = spacy.load("en_core_web_sm")
nlp_vectors = spacy.load("en_core_web_md")

# Natural-language phrases mapped to SQL comparison operators.
operator_mappings = {
    "greater than": ">",
    "less than": "<",
    "equal to": "=",
    "not equal to": "!=",
    "starts with": "LIKE",
    "ends with": "LIKE",
    "contains": "LIKE",
    "above": ">",
    "below": "<",
    "more than": ">",
    "<": "<",
    ">": ">"
}


def connect_to_db(db_path):
    """Open a connection to the SQLite database at db_path."""
    conn = sqlite3.connect(db_path)
    return conn


def fetch_schema(conn):
    """Return a dict mapping each table name to a list of its column descriptions."""
    cursor = conn.cursor()
    query = """
        SELECT name
        FROM sqlite_master
        WHERE type='table';
    """
    cursor.execute(query)
    tables = cursor.fetchall()

    schema = {}
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()
        schema[table_name] = [
            {"name": col[1], "type": col[2], "not_null": col[3], "default": col[4], "pk": col[5]}
            for col in columns
        ]

    return schema
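
# The returned structure looks like the following (table and column names depend
# on your database; "products" here is only an illustration):
#   {"products": [{"name": "id", "type": "INTEGER", "not_null": 0, "default": None, "pk": 1}, ...]}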


def find_ai_synonym(token_text, table_schema):
    """Return the best-matching column from table_schema based on vector similarity."""
    token_vec = nlp_vectors(token_text)[0].vector
    best_col = None
    best_score = 0.0

    for col in table_schema:
        col_vec = nlp_vectors(col)[0].vector

        # Cosine similarity; skip pairs where either word has no vector so we
        # never divide by zero.
        norm_product = np.linalg.norm(token_vec) * np.linalg.norm(col_vec)
        if norm_product == 0:
            continue
        score = token_vec.dot(col_vec) / norm_product
        if score > best_score:
            best_score = score
            best_col = col

    # Only accept reasonably confident matches.
    if best_score > 0.65:
        return best_col
    return None
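
# For example, find_ai_synonym("cost", ["price", "name"]) may return "price",
# provided the en_core_web_md similarity between "cost" and "price" clears the
# 0.65 threshold.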


def identify_table(question, schema_tables):
    """Fuzzy-match the whole question against the known table names."""
    table, score = process.extractOne(question, schema_tables)
    if score > 80:
        return table
    return None


def identify_columns(question, columns_for_table):
    """Fuzzy-match question tokens against a table's columns.

    Not referenced elsewhere in this script yet.
    """
    matched_cols = []
    tokens = question.lower().split()
    for token in tokens:
        col, score = process.extractOne(token, columns_for_table)
        if score > 80:
            matched_cols.append(col)
    return matched_cols


def find_closest_column(token, table_schema):
    """Return the closest column name for a token if the match is strong.

    Not referenced elsewhere in this script yet.
    """
    best_match, score = process.extractOne(token, table_schema)
    if score > 90:
        return best_match
    return None


def extract_conditions(question, schema, table):
    """Build the body of a SQL WHERE clause from the question, or an empty string."""
    table_schema = [col["name"].lower() for col in schema.get(table, [])]

    use_and = " and " in question.lower()
    use_or = " or " in question.lower()
    last_column = None

    # Split the question into clause-sized chunks around "and"/"or".
    condition_parts = re.split(r'\band\b|\bor\b', question, flags=re.IGNORECASE)
    print(condition_parts)  # debug output

    conditions = []

    for part in condition_parts:
        part = part.strip()

        doc = nlp(part.lower())
        tokens = [token.text for token in doc]

        # Drop the table name itself so it is not mistaken for a column.
        tokens = [t for t in tokens if t != table.lower()]

        part_conditions = []
        current_part_column = None

        print(tokens)  # debug output

        # Resolve the column this part refers to via vector similarity.
        for token in tokens:
            possible_col = find_ai_synonym(token, table_schema)
            if possible_col:
                current_part_column = possible_col
                last_column = possible_col

        # Check the longest operator phrases first so "not equal to" is not also
        # matched as "equal to", and stop after the first match in this part.
        for phrase in sorted(operator_mappings, key=len, reverse=True):
            if phrase in part.lower():
                sql_operator = operator_mappings[phrase]

                value_index = part.lower().find(phrase) + len(phrase)
                value = part[value_index:].strip().split(" ")[0]
                value = value.replace("'", "").replace('"', "").strip()

                if sql_operator == "LIKE":
                    if phrase == "starts with":
                        value = f"'{value}%'"
                    elif phrase == "ends with":
                        value = f"'%{value}'"
                    elif phrase == "contains":
                        value = f"'%{value}%'"
                elif not value.replace(".", "", 1).isdigit():
                    # Quote non-numeric values so string comparisons stay valid SQL.
                    value = f"'{value}'"

                # Fall back to the last column seen in an earlier part, e.g. for
                # "price above 10 and below 20".
                column_to_use = current_part_column or last_column
                if column_to_use:
                    part_conditions.append(f"{column_to_use} {sql_operator} {value}")
                break

        if part_conditions:
            conditions.append(" AND ".join(part_conditions))

    if use_and:
        return " AND ".join(conditions)
    elif use_or:
        return " OR ".join(conditions)
    else:
        # No explicit connective: default to AND.
        return " AND ".join(conditions)
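
# Example (hypothetical table and column names): for the question
#   "show products where price greater than 100 and name starts with A"
# run against a table with "price" and "name" columns, this would produce
# something like: price > 100 AND name LIKE 'A%'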


def interpret_question(question, schema):
    """Classify the question's intent and identify the table it refers to."""
    intents = {
        "describe_table": "Provide information about the columns and structure of a table.",
        "list_table_data": "Fetch and display all data stored in a table.",
        "count_records": "Count the number of records in a table.",
        "fetch_column": "Fetch a specific column's data from a table."
    }

    # Classify against the human-readable descriptions (the zero-shot model handles
    # full sentences better than snake_case labels), then map the winning label
    # back to its intent key.
    labels = list(intents.values())
    result = classifier(question, labels)
    best_label = result["labels"][0]
    predicted_intent = next(key for key, desc in intents.items() if desc == best_label)

    table = identify_table(question, list(schema.keys()))

    # A comparison phrase in the question means the user wants filtered rows.
    condition_keywords = list(operator_mappings.keys())
    if any(keyword in question.lower() for keyword in condition_keywords):
        predicted_intent = "list_table_data"

    return {"intent": predicted_intent, "table": table}


def handle_intent(intent_data, schema, conn, question):
    """Run the action for the predicted intent and return a printable result."""
    intent = intent_data["intent"]
    table = intent_data["table"]

    if not table:
        return "I couldn't identify which table you're referring to."

    if intent == "describe_table":
        table_schema = schema[table]
        description = [f"Table '{table}' has the following columns:"]
        for col in table_schema:
            col_details = f"- {col['name']} ({col['type']})"
            if col['not_null']:
                col_details += " [NOT NULL]"
            if col['default'] is not None:
                col_details += f" [DEFAULT: {col['default']}]"
            if col['pk']:
                col_details += " [PRIMARY KEY]"
            description.append(col_details)
        return "\n".join(description)

    elif intent == "list_table_data":
        condition = extract_conditions(question, schema, table)
        cursor = conn.cursor()
        query = f"SELECT * FROM {table}"
        if condition:
            query += f" WHERE {condition};"
        else:
            query += ";"

        print(query)  # debug output
        cursor.execute(query)
        return cursor.fetchall()

    elif intent == "count_records":
        cursor = conn.cursor()
        cursor.execute(f"SELECT COUNT(*) FROM {table};")
        return cursor.fetchone()

    elif intent == "fetch_column":
        return "Fetching specific column data is not yet implemented."

    else:
        return "I couldn't understand your question."


def answer_question(question, conn, schema):
    """Interpret the question and dispatch it to the matching handler."""
    intent_data = interpret_question(question, schema)
    print(intent_data)  # debug output
    return handle_intent(intent_data, schema, conn, question)


if __name__ == "__main__":
    db_path = "./ecommerce.db"
    conn = connect_to_db(db_path)
    schema = fetch_schema(conn)

    print("Schema:", schema)

    # Simple REPL: keep answering questions until the user types "exit" or "quit".
    while True:
        question = input("\nAsk a question about the database: ")
        if question.lower() in ["exit", "quit"]:
            break

        answer = answer_question(question, conn, schema)
        print("Answer:", answer)
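
# Example session (hypothetical -- assumes ecommerce.db contains an "orders" table
# with a "total" column; the actual tables and columns depend on your database):
#
#   Ask a question about the database: show orders where total greater than 100
#   SELECT * FROM orders WHERE total > 100;
#   Answer: [...]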