|
import sqlite3 |
|
import spacy |
|
import re |
|
from thefuzz import process |
|
import numpy as np |
|
from transformers import pipeline |
|
|
|
|
|
# Zero-shot classifier used to map a free-form question onto one of the
# supported intents.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Small English pipeline: used for tokenizing questions.
nlp = spacy.load("en_core_web_sm")

# Medium English pipeline: used for its word vectors when comparing question
# tokens against column names.
nlp_vectors = spacy.load("en_core_web_md")
|
|
|
|
|
# Natural-language comparison phrases mapped to their SQL operators.
#
# Order matters for consumers that scan this dict and stop at the first
# phrase found as a substring: a phrase that contains another phrase
# ("not equal to" contains "equal to") must be listed first, otherwise the
# shorter phrase wins and the negation is lost.
operator_mappings = {
    "not equal to": "!=",
    "greater than": ">",
    "less than": "<",
    "equal to": "=",
    "starts with": "LIKE",
    "ends with": "LIKE",
    "contains": "LIKE",
    "above": ">",
    "below": "<",
    "more than": ">",
    "<": "<",
    ">": ">",
}
|
|
|
|
|
def connect_to_db(db_path):
    """Open a SQLite connection to ``db_path``.

    Returns the connection on success. If ``sqlite3`` raises an error while
    connecting, the error is printed and ``None`` is returned instead.
    """
    try:
        connection = sqlite3.connect(db_path)
    except sqlite3.Error as e:
        print(f"Error connecting to database: {e}")
        connection = None
    return connection
|
|
|
|
|
def fetch_schema(conn):
    """Describe every table in the connected SQLite database.

    Args:
        conn: An open ``sqlite3.Connection``.

    Returns:
        A dict mapping each table name listed in ``sqlite_master`` to a list
        of column dicts with the keys ``name``, ``type``, ``not_null``,
        ``default`` and ``pk`` (taken from ``PRAGMA table_info``).
    """
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        schema = {}
        for table_name in table_names:
            # Quote the identifier so names containing spaces, keywords or
            # quotes do not break (or alter) the PRAGMA statement.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            # table_info rows are (cid, name, type, notnull, dflt_value, pk).
            schema[table_name] = [
                {
                    "name": col[1],
                    "type": col[2],
                    "not_null": col[3],
                    "default": col[4],
                    "pk": col[5],
                }
                for col in cursor.fetchall()
            ]
        return schema
    finally:
        cursor.close()
|
|
|
|
|
def find_best_match(token_text, table_schema):
    """Return the column in ``table_schema`` that best matches ``token_text``.

    Matching is attempted in two stages:

    1. Cosine similarity between the ``nlp_vectors`` vector of the token and
       of each column name; the best column is accepted if its similarity
       exceeds 0.65.
    2. Otherwise, fuzzy string matching via ``process.extractOne``; the best
       candidate is accepted if its score exceeds 80.

    Args:
        token_text: The word taken from the user's question.
        table_schema: Iterable of candidate column names.

    Returns:
        The matching column name, or ``None`` when nothing matches well
        enough (including when either input is empty).
    """
    columns = list(table_schema)
    # Nothing to compare against; also avoids unpacking the None that
    # extractOne returns for an empty choice list.
    if not columns or not token_text:
        return None

    token_vec = nlp_vectors(token_text).vector
    token_norm = np.linalg.norm(token_vec)

    best_col = None
    best_score = 0.0

    # A zero norm (e.g. a word without a vector) would make the cosine a
    # division by zero, so such vectors are skipped.
    if token_norm > 0:
        for col in columns:
            col_vec = nlp_vectors(col).vector
            col_norm = np.linalg.norm(col_vec)
            if col_norm == 0:
                continue
            score = float(token_vec.dot(col_vec) / (token_norm * col_norm))
            if score > best_score:
                best_score = score
                best_col = col

    if best_score > 0.65:
        return best_col

    # Fall back to fuzzy string matching when the vectors are inconclusive.
    fuzzy_result = process.extractOne(token_text, columns)
    if fuzzy_result is not None and fuzzy_result[1] > 80:
        return fuzzy_result[0]

    return None
|
|
|
|
|
def _sql_literal(raw_value):
    """Render a value from the question as a SQL literal (bare number or quoted text)."""
    if re.fullmatch(r"[+-]?\d+(\.\d+)?", raw_value):
        return raw_value
    return "'" + raw_value.replace("'", "''") + "'"


def _parse_condition(part, column_names, phrases):
    """Turn one clause of the question into ``column op value``, or None if it has none."""
    column = None
    for token in nlp(part):
        column = find_best_match(token.text.lower(), column_names)
        if column:
            break
    if not column:
        return None

    for phrase in phrases:
        # Case-insensitive search directly on ``part`` so the match offsets
        # always index into the original (un-lowered) text.
        found = re.search(re.escape(phrase), part, flags=re.IGNORECASE)
        if not found:
            continue

        words = part[found.end():].split()
        if not words:
            return None  # operator phrase with nothing after it
        # Drop trailing sentence punctuation and any quotes the user typed.
        value = words[0].rstrip("?!.,;").strip("'\"")
        if not value:
            return None

        sql_operator = operator_mappings[phrase]
        if sql_operator == "LIKE":
            escaped = value.replace("'", "''")
            if phrase == "starts with":
                pattern = f"{escaped}%"
            elif phrase == "ends with":
                pattern = f"%{escaped}"
            else:
                pattern = f"%{escaped}%"
            return f"{column} LIKE '{pattern}'"
        return f"{column} {sql_operator} {_sql_literal(value)}"

    return None


def extract_conditions(question, schema, table):
    """Build a SQL WHERE-clause fragment from a natural-language question.

    The question is split on the words "and" / "or". For each piece, the
    first token that ``find_best_match`` maps to a column of ``table`` is
    combined with the first operator phrase from ``operator_mappings`` found
    in that piece and the word following the phrase. Conditions are joined
    with the connector that preceded them in the question.

    Args:
        question: The user's question.
        schema: Mapping of table name to a list of column dicts that each
            have a ``"name"`` key.
        table: Name of the table whose columns may be referenced.

    Returns:
        A string such as ``"price > 5 AND name LIKE '%foo%'"``, or ``None``
        when no condition could be extracted.
    """
    column_names = [col["name"].lower() for col in schema.get(table, [])]
    if not column_names or not question:
        return None

    # The capture group makes re.split keep the connectors, so the result
    # alternates: [piece, connector, piece, connector, piece, ...].
    pieces = re.split(r"\b(and|or)\b", question, flags=re.IGNORECASE)

    # Longest phrases first so "not equal to" is tried before "equal to".
    phrases = sorted(operator_mappings, key=len, reverse=True)

    clauses = []
    connector = "AND"
    for index, piece in enumerate(pieces):
        if index % 2 == 1:
            connector = piece.upper()
            continue
        condition = _parse_condition(piece.strip(), column_names, phrases)
        if condition:
            clauses.append((connector, condition))

    if not clauses:
        return None

    sql = clauses[0][1]
    for joiner, condition in clauses[1:]:
        sql += f" {joiner} {condition}"
    return sql
|
|
|
|
|
|
|
def interpret_question(question, schema):
    """Classify the intent of ``question`` and guess which table it refers to.

    Args:
        question: The user's question.
        schema: Mapping whose keys are the known table names.

    Returns:
        A dict ``{"intent": <intent name>, "table": <table name or None>}``.
        ``table`` is ``None`` when no table name fuzzy-matches the question
        with a score above 80 (or when ``schema`` is empty).
    """
    # Only the keys are sent to the classifier as candidate labels; the
    # values document what each intent means.
    intents = {
        "describe_table": "Provide information about the columns and structure of a table.",
        "list_table_data": "Fetch and display all data stored in a table.",
        "count_records": "Count the number of records in a table.",
        "fetch_column": "Fetch a specific column's data from a table.",
        "fetch_all_data": "Fetch all records from a table without filters.",
        "filter_data_with_conditions": "Fetch records based on specific conditions.",
    }

    candidate_labels = list(intents)
    result = classifier(question, candidate_labels, multi_label=True)

    # The pipeline returns its own "labels" list, re-ordered to line up with
    # "scores". The winning index must therefore be looked up in
    # result["labels"], not in the candidate list that was passed in.
    best_index = int(np.argmax(result["scores"]))
    predicted_intent = result["labels"][best_index]

    # extractOne yields None when there is nothing to choose from.
    best_table = process.extractOne(question, list(schema)) if schema else None
    table = None
    if best_table is not None and best_table[1] > 80:
        table = best_table[0]

    return {"intent": predicted_intent, "table": table}
|
|
|
def _quote_identifier(name):
    """Wrap a table or column name in double quotes, escaping embedded quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _run_query(conn, query, fetch_one=False):
    """Execute ``query`` and return its rows, or an error message if SQLite rejects it."""
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchone() if fetch_one else cursor.fetchall()
    except sqlite3.Error as exc:
        return f"Query failed: {exc}"
    finally:
        cursor.close()


def _find_column(question, schema, table):
    """Return the first column of ``table`` mentioned in ``question``, or None."""
    column_names = [col["name"].lower() for col in schema.get(table, [])]
    if not column_names:
        return None
    for token in nlp(question):
        word = token.text.lower()
        # The table's own name is not a column reference.
        if word == str(table).lower():
            continue
        column = find_best_match(word, column_names)
        if column:
            return column
    return None


def handle_intent(intent_data, schema, conn, question):
    """Execute the action for a classified question and return its result.

    Args:
        intent_data: Dict with the keys ``"intent"`` and ``"table"``.
        schema: Mapping of table name to its list of column dicts.
        conn: Open ``sqlite3.Connection`` the queries run against.
        question: The original question, used to extract columns/conditions.

    Returns:
        Depending on the intent: the table's schema entry, a list of rows,
        a single row (for counts), or a human-readable message string when
        the request cannot be fulfilled or the query fails.
    """
    intent = intent_data["intent"]
    table = intent_data["table"]

    if not table:
        return "I couldn't identify which table you're referring to."

    quoted_table = _quote_identifier(table)

    if intent == "describe_table":
        return schema.get(table, "No such table found.")

    if intent in ("list_table_data", "fetch_all_data"):
        query = f"SELECT * FROM {quoted_table}"
        # Only "list_table_data" honours filters found in the question.
        if intent == "list_table_data":
            conditions = extract_conditions(question, schema, table)
            if conditions:
                query += f" WHERE {conditions}"
        return _run_query(conn, query)

    if intent == "count_records":
        return _run_query(conn, f"SELECT COUNT(*) FROM {quoted_table}", fetch_one=True)

    if intent == "fetch_column":
        # A column name is needed here, not a WHERE fragment, so the column
        # is looked up directly instead of going through extract_conditions.
        column = _find_column(question, schema, table)
        if not column:
            return "I couldn't identify which column you're referring to."
        return _run_query(
            conn, f"SELECT {_quote_identifier(column)} FROM {quoted_table}"
        )

    if intent == "filter_data_with_conditions":
        conditions = extract_conditions(question, schema, table)
        # Without a parsed condition the WHERE clause would be invalid SQL.
        if not conditions:
            return "I couldn't identify any filter conditions in your question."
        return _run_query(conn, f"SELECT * FROM {quoted_table} WHERE {conditions}")

    return "Unsupported intent."
|
|
|
|
|
|
|
def answer_question(question, conn, schema):
    """Answer ``question`` against the database.

    The question is first passed to ``interpret_question`` and the resulting
    intent data is then handed to ``handle_intent``, whose result is returned
    unchanged.
    """
    return handle_intent(
        interpret_question(question, schema),
        schema,
        conn,
        question,
    )
|
|
|
if __name__ == "__main__":
    # Interactive loop: read questions until the user types "exit"/"quit"
    # or closes the input stream.
    db_path = "./ecommerce.db"
    conn = connect_to_db(db_path)
    if conn:
        try:
            schema = fetch_schema(conn)
            while True:
                try:
                    question = input("\nAsk a question about the database: ")
                except (EOFError, KeyboardInterrupt):
                    # Ctrl-D / Ctrl-C ends the session without a traceback.
                    break
                normalized = question.strip().lower()
                if normalized in ("exit", "quit"):
                    break
                if not normalized:
                    continue  # nothing to answer
                print(answer_question(question, conn, schema))
        finally:
            # Release the database handle however the loop ends.
            conn.close()
|
|