File size: 6,975 Bytes
f860f0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import sqlite3
import spacy
import re
from thefuzz import process
import numpy as np
from transformers import pipeline
# Load intent classification model
# Zero-shot classifier built on BART-MNLI; downloads model weights on first run.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Small English pipeline — used only for tokenisation in extract_conditions.
nlp = spacy.load("en_core_web_sm")
# Medium English pipeline — ships word vectors, used for cosine similarity in find_best_match.
nlp_vectors = spacy.load("en_core_web_md")
# Define operator mappings
# Maps natural-language comparison phrases to SQL operators.
# NOTE: "starts with" / "ends with" / "contains" all map to LIKE; the value
# is wildcarded accordingly in extract_conditions.
# Bug fix: the original dict listed "less than" twice — duplicate dict keys
# are silently collapsed by Python, so the second entry was dead code.
operator_mappings = {
    "greater than": ">",
    "less than": "<",
    "equal to": "=",
    "not equal to": "!=",
    "starts with": "LIKE",
    "ends with": "LIKE",
    "contains": "LIKE",
    "above": ">",
    "below": "<",
    "more than": ">",
    "<": "<",
    ">": ">",
}
# Connect to SQLite database
def connect_to_db(db_path):
    """Open and return a SQLite connection to *db_path*, or None on failure."""
    try:
        return sqlite3.connect(db_path)
    except sqlite3.Error as e:
        # Report the failure instead of raising; callers check for None.
        print(f"Error connecting to database: {e}")
        return None
# Fetch database schema
def fetch_schema(conn):
    """Return {table_name: [column-info dicts]} for every table in *conn*.

    Each column dict carries: name, type, not_null, default, pk — taken
    straight from SQLite's PRAGMA table_info output.
    """
    cursor = conn.cursor()
    cursor.execute(
        """
        SELECT name
        FROM sqlite_master
        WHERE type='table';
        """
    )
    schema = {}
    for (table_name,) in cursor.fetchall():
        cursor.execute(f"PRAGMA table_info({table_name});")
        schema[table_name] = [
            {
                "name": col[1],
                "type": col[2],
                "not_null": col[3],
                "default": col[4],
                "pk": col[5],
            }
            for col in cursor.fetchall()
        ]
    return schema
# Match token to schema columns using vector similarity and fuzzy matching
def find_best_match(token_text, table_schema):
    """Return the column in *table_schema* that best matches *token_text*.

    Tries cosine similarity over spaCy word vectors first (threshold 0.65),
    then falls back to fuzzy string matching (threshold 80). Returns None
    when nothing clears either threshold or the schema is empty.
    """
    # Bug fix: guard the empty-schema case — process.extractOne returns
    # None for an empty choice list, which crashed the tuple unpacking below.
    if not table_schema:
        return None
    token_vec = nlp_vectors(token_text).vector
    token_norm = np.linalg.norm(token_vec)
    best_col = None
    best_score = 0.0
    # Bug fix: out-of-vocabulary tokens have an all-zero vector in spaCy,
    # so the original cosine computation divided by zero (NaN scores).
    if token_norm > 0:
        for col in table_schema:
            col_vec = nlp_vectors(col).vector
            col_norm = np.linalg.norm(col_vec)
            if col_norm == 0:
                continue  # column name itself is OOV; skip similarity test
            score = token_vec.dot(col_vec) / (token_norm * col_norm)
            if score > best_score:
                best_score = score
                best_col = col
        if best_score > 0.65:
            return best_col
    # Fallback to fuzzy matching if vector similarity fails
    best_fuzzy_match, fuzzy_score = process.extractOne(token_text, table_schema)
    if fuzzy_score > 80:
        return best_fuzzy_match
    return None
# Extract conditions from user query
def extract_conditions(question, schema, table):
    """Build a SQL WHERE-clause fragment from a natural-language *question*.

    The question is split on 'and'/'or'; each fragment is mapped to a
    column (via find_best_match) and an operator (via operator_mappings).
    Returns the joined condition string, or None when nothing was extracted.

    SECURITY NOTE(review): the extracted value is interpolated into SQL
    verbatim by the callers — consider parameterized queries.
    """
    table_schema = [col["name"].lower() for col in schema.get(table, [])]
    # Detect whether the user used 'AND' / 'OR'
    lowered = question.lower()
    use_and = " and " in lowered
    use_or = " or " in lowered
    condition_parts = re.split(r"\band\b|\bor\b", question, flags=re.IGNORECASE)
    conditions = []
    for part in condition_parts:
        part = part.strip()
        # Bug fix: compare phrases against the lowercased fragment — the
        # original used the raw text, so "Greater Than" never matched even
        # though the value offset was computed from part.lower().
        part_lower = part.lower()
        tokens = [token.text.lower() for token in nlp(part)]
        current_col = None
        for token in tokens:
            possible_col = find_best_match(token, table_schema)
            if possible_col:
                current_col = possible_col
                break
        if not current_col:
            continue
        for phrase, sql_operator in operator_mappings.items():
            idx = part_lower.find(phrase)
            if idx == -1:
                continue
            # Bug fix: guard against a phrase with no trailing value
            # (e.g. "price above") — .split()[0] raised IndexError.
            value_tokens = part[idx + len(phrase):].strip().split()
            if not value_tokens:
                continue
            value = value_tokens[0]
            if sql_operator == "LIKE":
                # Wildcard the value according to the matched phrase.
                if "starts with" in phrase:
                    value = f"'{value}%'"
                elif "ends with" in phrase:
                    value = f"'%{value}'"
                elif "contains" in phrase:
                    value = f"'%{value}%'"
            conditions.append(f"{current_col} {sql_operator} {value}")
            break
    if not conditions:
        return None
    if use_and:
        return " AND ".join(conditions)
    if use_or:
        return " OR ".join(conditions)
    return " AND ".join(conditions)
# Main interpretation and execution
def interpret_question(question, schema):
    """Classify the question's intent and fuzzy-match the target table.

    Returns {"intent": <label>, "table": <name or None>}; the table is
    None when no schema table matches with a fuzzy score above 80.
    """
    intents = {
        "describe_table": "Provide information about the columns and structure of a table.",
        "list_table_data": "Fetch and display all data stored in a table.",
        "count_records": "Count the number of records in a table.",
        "fetch_column": "Fetch a specific column's data from a table.",
        "fetch_all_data": "Fetch all records from a table without filters.",
        "filter_data_with_conditions": "Fetch records based on specific conditions."
    }
    candidate_labels = list(intents.keys())
    classification = classifier(question, candidate_labels, multi_label=True)
    # Pick the highest-scoring label as the predicted intent.
    best_index = np.argmax(classification["scores"])
    predicted_intent = candidate_labels[best_index]
    # Extract table name using schema and fuzzy matching
    matched_table, match_score = process.extractOne(question, schema.keys())
    table = matched_table if match_score > 80 else None
    return {"intent": predicted_intent, "table": table}
def handle_intent(intent_data, schema, conn, question):
    """Dispatch on the classified intent and run the corresponding SQL.

    *intent_data* is the dict produced by interpret_question. Returns
    fetched rows, a schema description, or a user-facing message string.
    Table names are interpolated directly but come from the trusted schema
    keys, not raw user input.
    """
    intent = intent_data["intent"]
    table = intent_data["table"]
    if not table:
        return "I couldn't identify which table you're referring to."
    if intent == "describe_table":
        return schema.get(table, "No such table found.")
    elif intent in ["list_table_data", "fetch_all_data"]:
        conditions = extract_conditions(question, schema, table) if intent == "list_table_data" else None
        query = f"SELECT * FROM {table}"
        if conditions:
            query += f" WHERE {conditions}"
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    elif intent == "count_records":
        query = f"SELECT COUNT(*) FROM {table}"
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchone()
    elif intent == "fetch_column":
        # NOTE(review): extract_conditions returns a condition expression
        # (e.g. "price > 100"), not a bare column name, so this SELECTs a
        # boolean expression — TODO: use a dedicated column extractor.
        column = extract_conditions(question, schema, table)
        if not column:
            return "I couldn't identify which column you're referring to."
        query = f"SELECT {column} FROM {table}"
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    elif intent == "filter_data_with_conditions":
        conditions = extract_conditions(question, schema, table)
        # Bug fix: the original interpolated None into the query when no
        # condition was extracted, producing invalid SQL ("... WHERE None").
        if not conditions:
            return "I couldn't extract any filter conditions from your question."
        query = f"SELECT * FROM {table} WHERE {conditions}"
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    return "Unsupported intent."
# Entry point
def answer_question(question, conn, schema):
    """Interpret *question* and execute the matching query against *conn*."""
    parsed = interpret_question(question, schema)
    return handle_intent(parsed, schema, conn, question)
if __name__ == "__main__":
    db_path = "./ecommerce.db"
    conn = connect_to_db(db_path)
    if conn:
        try:
            schema = fetch_schema(conn)
            # Simple REPL: keep answering until the user types exit/quit.
            while True:
                question = input("\nAsk a question about the database: ")
                if question.lower() in ["exit", "quit"]:
                    break
                print(answer_question(question, conn, schema))
        finally:
            # Bug fix: the original never closed the connection.
            conn.close()
|