Manoj Kumar
commited on
Commit
·
e6f4fec
1
Parent(s):
867cb42
Mark POhase 1
Browse files- .DS_Store +0 -0
- Mark-1/db_creation.py +91 -0
- Mark-1/phase1.py +274 -0
- db.py +1 -1
- ecommerce.db +0 -0
- requirements.txt +4 -1
- wikiPreTrained.py +121 -0
- wikiSQL.py +199 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Mark-1/db_creation.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
import random
|
3 |
+
from faker import Faker
|
4 |
+
|
5 |
+
# Initialize Faker for generating random data
|
6 |
+
fake = Faker()
|
7 |
+
|
8 |
+
# Define custom schema
|
9 |
+
custom_schema = {
|
10 |
+
"products": {
|
11 |
+
"columns": ["product_id INTEGER PRIMARY KEY", "name TEXT", "price REAL", "category_id INTEGER"],
|
12 |
+
"relations": ["category_id -> categories.id"],
|
13 |
+
},
|
14 |
+
"categories": {
|
15 |
+
"columns": ["id INTEGER PRIMARY KEY", "category_name TEXT"],
|
16 |
+
"relations": None,
|
17 |
+
},
|
18 |
+
"orders": {
|
19 |
+
"columns": ["order_id INTEGER PRIMARY KEY", "user_id INTEGER", "product_id INTEGER", "order_date TEXT"],
|
20 |
+
"relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
|
21 |
+
},
|
22 |
+
"users": {
|
23 |
+
"columns": [
|
24 |
+
"user_id INTEGER PRIMARY KEY",
|
25 |
+
"first_name TEXT",
|
26 |
+
"last_name TEXT",
|
27 |
+
"email TEXT UNIQUE",
|
28 |
+
"phone_number TEXT",
|
29 |
+
"address TEXT"
|
30 |
+
],
|
31 |
+
"relations": None,
|
32 |
+
}
|
33 |
+
}
|
34 |
+
|
35 |
+
# Connect to SQLite database
|
36 |
+
conn = sqlite3.connect("ecommerce.db")
|
37 |
+
cursor = conn.cursor()
|
38 |
+
|
39 |
+
# Function to create tables based on schema
|
40 |
+
def create_tables():
|
41 |
+
for table_name, table_data in custom_schema.items():
|
42 |
+
columns = ", ".join(table_data["columns"])
|
43 |
+
table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns}"
|
44 |
+
|
45 |
+
if table_data["relations"]:
|
46 |
+
for relation in table_data["relations"]:
|
47 |
+
col_name, ref_table = relation.split(" -> ")
|
48 |
+
ref_col = ref_table.split(".")[1]
|
49 |
+
ref_table_name = ref_table.split(".")[0]
|
50 |
+
table_sql += f", FOREIGN KEY({col_name}) REFERENCES {ref_table_name}({ref_col})"
|
51 |
+
|
52 |
+
table_sql += ");"
|
53 |
+
cursor.execute(table_sql)
|
54 |
+
|
55 |
+
# Function to populate categories table
|
56 |
+
def insert_categories():
|
57 |
+
categories = [(i, fake.word().capitalize() + " " + fake.word().capitalize()) for i in range(1, 1001)]
|
58 |
+
cursor.executemany("INSERT INTO categories (id, category_name) VALUES (?, ?)", categories)
|
59 |
+
return categories
|
60 |
+
|
61 |
+
# Function to populate products table
|
62 |
+
def insert_products(categories):
|
63 |
+
products = [(i, fake.company() + " " + fake.word().capitalize(), round(random.uniform(10, 1000), 2), random.choice(categories)[0]) for i in range(1, 1001)]
|
64 |
+
cursor.executemany("INSERT INTO products (product_id, name, price, category_id) VALUES (?, ?, ?, ?)", products)
|
65 |
+
return products
|
66 |
+
|
67 |
+
# Function to populate users table
|
68 |
+
def insert_users():
|
69 |
+
users = [(i, fake.first_name(), fake.last_name(), fake.email(), fake.phone_number(), fake.address()) for i in range(1, 1001)]
|
70 |
+
cursor.executemany("INSERT OR IGNORE INTO users (user_id, first_name, last_name, email, phone_number, address) VALUES (?, ?, ?, ?, ?, ?)", users)
|
71 |
+
return users
|
72 |
+
|
73 |
+
# Function to populate orders table
|
74 |
+
def insert_orders(users, products):
|
75 |
+
orders = [(i, random.choice(users)[0], random.choice(products)[0], fake.date_this_year().strftime("%Y-%m-%d")) for i in range(1, 1001)]
|
76 |
+
cursor.executemany("INSERT INTO orders (order_id, user_id, product_id, order_date) VALUES (?, ?, ?, ?)", orders)
|
77 |
+
|
78 |
+
# Create tables
|
79 |
+
create_tables()
|
80 |
+
|
81 |
+
# Insert data into tables
|
82 |
+
categories = insert_categories()
|
83 |
+
products = insert_products(categories)
|
84 |
+
users = insert_users()
|
85 |
+
insert_orders(users, products)
|
86 |
+
|
87 |
+
# Commit and close connection
|
88 |
+
conn.commit()
|
89 |
+
conn.close()
|
90 |
+
|
91 |
+
print("1000 rows inserted into each table successfully!")
|
Mark-1/phase1.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
import spacy
|
3 |
+
import re
|
4 |
+
from thefuzz import process
|
5 |
+
import numpy as np
|
6 |
+
from transformers import pipeline
|
7 |
+
|
8 |
+
# Load intent classification model
|
9 |
+
# Use Hugging Face's zero-shot pipeline for flexibility
|
10 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
11 |
+
nlp = spacy.load("en_core_web_sm")
|
12 |
+
nlp_vectors = spacy.load("en_core_web_md")
|
13 |
+
|
14 |
+
|
15 |
+
# Define operator mappings
|
16 |
+
operator_mappings = {
|
17 |
+
"greater than": ">",
|
18 |
+
"less than": "<",
|
19 |
+
"equal to": "=",
|
20 |
+
"not equal to": "!=",
|
21 |
+
"starts with": "LIKE",
|
22 |
+
"ends with": "LIKE",
|
23 |
+
"contains": "LIKE",
|
24 |
+
"above": ">",
|
25 |
+
"below": "<",
|
26 |
+
"more than": ">",
|
27 |
+
"less than": "<",
|
28 |
+
"<": "<",
|
29 |
+
">": ">"
|
30 |
+
}
|
31 |
+
|
32 |
+
# Connect to SQLite database
|
33 |
+
def connect_to_db(db_path):
|
34 |
+
conn = sqlite3.connect(db_path)
|
35 |
+
return conn
|
36 |
+
|
37 |
+
# Fetch database schema
|
38 |
+
def fetch_schema(conn):
|
39 |
+
cursor = conn.cursor()
|
40 |
+
query = """
|
41 |
+
SELECT name
|
42 |
+
FROM sqlite_master
|
43 |
+
WHERE type='table';
|
44 |
+
"""
|
45 |
+
cursor.execute(query)
|
46 |
+
tables = cursor.fetchall()
|
47 |
+
|
48 |
+
schema = {}
|
49 |
+
for table in tables:
|
50 |
+
table_name = table[0]
|
51 |
+
cursor.execute(f"PRAGMA table_info({table_name});")
|
52 |
+
columns = cursor.fetchall()
|
53 |
+
schema[table_name] = [{"name": col[1], "type": col[2], "not_null": col[3], "default": col[4], "pk": col[5]} for col in columns]
|
54 |
+
|
55 |
+
return schema
|
56 |
+
|
57 |
+
def find_ai_synonym(token_text, table_schema):
|
58 |
+
"""Return the best-matching column from table_schema based on vector similarity."""
|
59 |
+
token_vec = nlp_vectors(token_text)[0].vector
|
60 |
+
best_col = None
|
61 |
+
best_score = 0.0
|
62 |
+
|
63 |
+
for col in table_schema:
|
64 |
+
col_vec = nlp_vectors(col)[0].vector
|
65 |
+
# Cosine similarity
|
66 |
+
score = token_vec.dot(col_vec) / (np.linalg.norm(token_vec) * np.linalg.norm(col_vec))
|
67 |
+
if score > best_score:
|
68 |
+
best_score = score
|
69 |
+
best_col = col
|
70 |
+
|
71 |
+
# Apply threshold
|
72 |
+
if best_score > 0.65:
|
73 |
+
return best_col
|
74 |
+
return None
|
75 |
+
|
76 |
+
def identify_table(question, schema_tables):
|
77 |
+
# schema_tables = ["products", "users", "orders", ...]
|
78 |
+
table, score = process.extractOne(question, schema_tables)
|
79 |
+
|
80 |
+
if score > 80: # a comfortable threshold
|
81 |
+
return table
|
82 |
+
return None
|
83 |
+
|
84 |
+
def identify_columns(question, columns_for_table):
|
85 |
+
# columns_for_table = ["id", "price", "stock", "name", ...]
|
86 |
+
# For each token in question, fuzzy match to columns
|
87 |
+
matched_cols = []
|
88 |
+
tokens = question.lower().split()
|
89 |
+
for token in tokens:
|
90 |
+
col, score = process.extractOne(token, columns_for_table)
|
91 |
+
if score > 80:
|
92 |
+
matched_cols.append(col)
|
93 |
+
return matched_cols
|
94 |
+
|
95 |
+
def find_closest_column(token, table_schema):
|
96 |
+
# table_schema is a list of column names, e.g. ["price", "stock", "name"]
|
97 |
+
# This returns (best_match, score)
|
98 |
+
best_match, score = process.extractOne(token, table_schema)
|
99 |
+
# You can tune this threshold as needed (e.g. 70, 80, etc.)
|
100 |
+
if score > 90:
|
101 |
+
return best_match
|
102 |
+
return None
|
103 |
+
|
104 |
+
# Condition extraction with NLP
|
105 |
+
def extract_conditions(question, schema, table):
|
106 |
+
table_schema = [col["name"].lower() for col in schema.get(table, [])]
|
107 |
+
|
108 |
+
# Detect whether the user used 'AND' / 'OR'
|
109 |
+
# (case-insensitive, hence .lower() checks)
|
110 |
+
use_and = " and " in question.lower()
|
111 |
+
use_or = " or " in question.lower()
|
112 |
+
last_column = None
|
113 |
+
|
114 |
+
# Split on 'and' or 'or' to handle multiple conditions
|
115 |
+
condition_parts = re.split(r'\band\b|\bor\b', question, flags=re.IGNORECASE)
|
116 |
+
|
117 |
+
print(condition_parts)
|
118 |
+
|
119 |
+
conditions = []
|
120 |
+
|
121 |
+
for part in condition_parts:
|
122 |
+
part = part.strip()
|
123 |
+
|
124 |
+
# Use spaCy to tokenize each part
|
125 |
+
doc = nlp(part.lower())
|
126 |
+
tokens = [token.text for token in doc]
|
127 |
+
|
128 |
+
# Skip the recognized_table token if it appears in tokens
|
129 |
+
# so it won't be matched as a column
|
130 |
+
tokens = [t for t in tokens if t != table.lower()]
|
131 |
+
|
132 |
+
part_conditions = []
|
133 |
+
current_part_column = None
|
134 |
+
|
135 |
+
print(tokens)
|
136 |
+
|
137 |
+
for i, token in enumerate(tokens):
|
138 |
+
# Try synonyms/fuzzy, etc. to find a column
|
139 |
+
possible_col = find_ai_synonym(token, table_schema)
|
140 |
+
if possible_col:
|
141 |
+
current_part_column = possible_col
|
142 |
+
last_column = possible_col # update last_column
|
143 |
+
|
144 |
+
# Check for any matching operator phrase in this part
|
145 |
+
for phrase, sql_operator in operator_mappings.items():
|
146 |
+
if phrase in part.lower():
|
147 |
+
# Extract the value after the phrase
|
148 |
+
value_index = part.lower().find(phrase) + len(phrase)
|
149 |
+
value = part[value_index:].strip().split(" ")[0]
|
150 |
+
value = value.replace("'", "").replace('"', "").strip()
|
151 |
+
|
152 |
+
# Special handling for LIKE operators
|
153 |
+
if sql_operator == "LIKE":
|
154 |
+
if "starts with" in phrase:
|
155 |
+
value = f"'{value}%'"
|
156 |
+
elif "ends with" in phrase:
|
157 |
+
value = f"'%{value}'"
|
158 |
+
elif "contains" in phrase:
|
159 |
+
value = f"'%{value}%'"
|
160 |
+
|
161 |
+
# If we did not find a new column, fallback to last_column
|
162 |
+
column_to_use = current_part_column or last_column
|
163 |
+
if column_to_use:
|
164 |
+
# Add this condition to the list for this part
|
165 |
+
part_conditions.append(f"{column_to_use} {sql_operator} {value}")
|
166 |
+
|
167 |
+
# If multiple conditions are found in this part, join them with AND
|
168 |
+
# (e.g., "price > 100 AND stock < 50" within the same part)
|
169 |
+
if part_conditions:
|
170 |
+
conditions.append(" AND ".join(part_conditions))
|
171 |
+
|
172 |
+
# Finally, combine each part with AND or OR, depending on the user query
|
173 |
+
if use_and:
|
174 |
+
return " AND ".join(conditions)
|
175 |
+
elif use_or:
|
176 |
+
return " OR ".join(conditions)
|
177 |
+
else:
|
178 |
+
# If there's only one part or no explicit 'and'/'or', default to AND
|
179 |
+
return " AND ".join(conditions)
|
180 |
+
|
181 |
+
# Interpret user question using intent recognition
|
182 |
+
def interpret_question(question, schema):
|
183 |
+
# Define potential intents
|
184 |
+
intents = {
|
185 |
+
"describe_table": "Provide information about the columns and structure of a table.",
|
186 |
+
"list_table_data": "Fetch and display all data stored in a table.",
|
187 |
+
"count_records": "Count the number of records in a table.",
|
188 |
+
"fetch_column": "Fetch a specific column's data from a table."
|
189 |
+
}
|
190 |
+
|
191 |
+
# Use classifier to predict intent
|
192 |
+
labels = list(intents.keys())
|
193 |
+
result = classifier(question, labels)
|
194 |
+
|
195 |
+
predicted_intent = result["labels"][0]
|
196 |
+
table = identify_table(question, list(schema.keys()))
|
197 |
+
|
198 |
+
# Rule-based fallback for conditional queries
|
199 |
+
condition_keywords = list(operator_mappings.keys())
|
200 |
+
if any(keyword in question.lower() for keyword in condition_keywords):
|
201 |
+
predicted_intent = "list_table_data"
|
202 |
+
|
203 |
+
return {"intent": predicted_intent, "table": table}
|
204 |
+
|
205 |
+
# Handle different intents
|
206 |
+
def handle_intent(intent_data, schema, conn, question):
|
207 |
+
intent = intent_data["intent"]
|
208 |
+
table = intent_data["table"]
|
209 |
+
|
210 |
+
if not table:
|
211 |
+
return "I couldn't identify which table you're referring to."
|
212 |
+
|
213 |
+
if intent == "describe_table":
|
214 |
+
# Describe table structure
|
215 |
+
table_schema = schema[table]
|
216 |
+
description = [f"Table '{table}' has the following columns:"]
|
217 |
+
for col in table_schema:
|
218 |
+
col_details = f"- {col['name']} ({col['type']})"
|
219 |
+
if col['not_null']:
|
220 |
+
col_details += " [NOT NULL]"
|
221 |
+
if col['default'] is not None:
|
222 |
+
col_details += f" [DEFAULT: {col['default']}]"
|
223 |
+
if col['pk']:
|
224 |
+
col_details += " [PRIMARY KEY]"
|
225 |
+
description.append(col_details)
|
226 |
+
return "\n".join(description)
|
227 |
+
|
228 |
+
elif intent == "list_table_data":
|
229 |
+
# Check for conditions
|
230 |
+
condition = extract_conditions(question, schema, table)
|
231 |
+
cursor = conn.cursor()
|
232 |
+
query = f"SELECT * FROM {table}"
|
233 |
+
if condition:
|
234 |
+
query += f" WHERE {condition};"
|
235 |
+
else:
|
236 |
+
query += ";"
|
237 |
+
|
238 |
+
print(query)
|
239 |
+
cursor.execute(query)
|
240 |
+
return cursor.fetchall()
|
241 |
+
|
242 |
+
elif intent == "count_records":
|
243 |
+
# Count records in the table
|
244 |
+
cursor = conn.cursor()
|
245 |
+
cursor.execute(f"SELECT COUNT(*) FROM {table};")
|
246 |
+
return cursor.fetchone()
|
247 |
+
|
248 |
+
elif intent == "fetch_column":
|
249 |
+
return "Fetching specific column data is not yet implemented."
|
250 |
+
|
251 |
+
else:
|
252 |
+
return "I couldn't understand your question."
|
253 |
+
|
254 |
+
# Main function
|
255 |
+
def answer_question(question, conn, schema):
|
256 |
+
intent_data = interpret_question(question, schema)
|
257 |
+
print(intent_data)
|
258 |
+
return handle_intent(intent_data, schema, conn, question)
|
259 |
+
|
260 |
+
# Example Usage
|
261 |
+
if __name__ == "__main__":
|
262 |
+
db_path = "./ecommerce.db" # Replace with your SQLite database path
|
263 |
+
conn = connect_to_db(db_path)
|
264 |
+
schema = fetch_schema(conn)
|
265 |
+
|
266 |
+
print("Schema:", schema)
|
267 |
+
|
268 |
+
while True:
|
269 |
+
question = input("\nAsk a question about the database: ")
|
270 |
+
if question.lower() in ["exit", "quit"]:
|
271 |
+
break
|
272 |
+
|
273 |
+
answer = answer_question(question, conn, schema)
|
274 |
+
print("Answer:", answer)
|
db.py
CHANGED
@@ -34,7 +34,7 @@ def generate_context(schema):
|
|
34 |
schema_context = generate_context(schema)
|
35 |
|
36 |
# Step 2: Load the T5-base-text-to-sql model
|
37 |
-
model_name = "
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
39 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
40 |
|
|
|
34 |
schema_context = generate_context(schema)
|
35 |
|
36 |
# Step 2: Load the T5-base-text-to-sql model
|
37 |
+
model_name = "suriya7/t5-base-text-to-sql" # A model fine-tuned for SQL generation
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
39 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
40 |
|
ecommerce.db
ADDED
Binary file (258 kB). View file
|
|
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
transformers
|
2 |
torch
|
3 |
accelerate>=0.26.0
|
4 |
-
tiktoken
|
|
|
|
|
|
|
|
1 |
transformers
|
2 |
torch
|
3 |
accelerate>=0.26.0
|
4 |
+
tiktoken
|
5 |
+
datasets
|
6 |
+
sentencepiece
|
7 |
+
tqdm
|
wikiPreTrained.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
2 |
+
import torch
|
3 |
+
import re
|
4 |
+
|
5 |
+
# Load the trained model and tokenizer
|
6 |
+
model = T5ForConditionalGeneration.from_pretrained("./t5_sql_finetuned")
|
7 |
+
tokenizer = T5Tokenizer.from_pretrained("./t5_sql_finetuned")
|
8 |
+
|
9 |
+
# Define a simple function to check if the question is schema-related or SQL-related
|
10 |
+
def is_schema_question(question: str):
|
11 |
+
schema_keywords = ["columns", "tables", "structure", "schema", "relations", "fields"]
|
12 |
+
return any(keyword in question.lower() for keyword in schema_keywords)
|
13 |
+
|
14 |
+
# Helper function to extract table name from the question
|
15 |
+
def extract_table_name(question: str):
|
16 |
+
# Regex pattern to find table names, assuming table names are capitalized or match a known pattern
|
17 |
+
table_name_match = re.search(r'for (\w+)|in (\w+)|from (\w+)', question)
|
18 |
+
|
19 |
+
if table_name_match:
|
20 |
+
# Return the matched table name (first capturing group)
|
21 |
+
return table_name_match.group(1) or table_name_match.group(2) or table_name_match.group(3)
|
22 |
+
|
23 |
+
# If no table name is detected, return None
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
# Define a function to handle SQL generation
|
28 |
+
def generate_sql(question: str, schema: dict, model, tokenizer, device):
|
29 |
+
# Preprocess the question for SQL generation (e.g., reformat)
|
30 |
+
# Example question: "What is the price of the product with ID 123?"
|
31 |
+
|
32 |
+
# Here we use the model to generate SQL query
|
33 |
+
inputs = tokenizer(question, return_tensors="pt")
|
34 |
+
input_ids = inputs.input_ids.to(device)
|
35 |
+
|
36 |
+
with torch.no_grad():
|
37 |
+
generated_ids = model.generate(input_ids, max_length=128)
|
38 |
+
|
39 |
+
# Decode the SQL query generated by the model
|
40 |
+
sql_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
41 |
+
|
42 |
+
return sql_query
|
43 |
+
|
44 |
+
# Define a function to handle schema-related questions
|
45 |
+
def handle_schema_question(question: str, schema: dict):
|
46 |
+
# Here you handle questions about the schema (tables, columns, relations)
|
47 |
+
# Example schema-related question: "What columns does the products table have?"
|
48 |
+
|
49 |
+
question = question.lower()
|
50 |
+
|
51 |
+
# Check if the question asks about columns
|
52 |
+
if "columns" in question or "fields" in question:
|
53 |
+
table_name = extract_table_name(question)
|
54 |
+
if table_name:
|
55 |
+
if table_name in schema:
|
56 |
+
return schema[table_name]["columns"]
|
57 |
+
else:
|
58 |
+
return f"Table '{table_name}' not found in the schema."
|
59 |
+
|
60 |
+
# Check if the question asks about relations
|
61 |
+
elif "relations" in question or "relationships" in question:
|
62 |
+
table_name = extract_table_name(question)
|
63 |
+
if table_name:
|
64 |
+
if table_name in schema:
|
65 |
+
return schema[table_name]["relations"]
|
66 |
+
else:
|
67 |
+
return f"Table '{table_name}' not found in the schema."
|
68 |
+
|
69 |
+
# Additional cases can be handled here (e.g., "Which tables are in the schema?")
|
70 |
+
elif "tables" in question:
|
71 |
+
return list(schema.keys())
|
72 |
+
|
73 |
+
# If the question is too vague or doesn't match the expected patterns
|
74 |
+
return "Sorry, I couldn't understand your schema question. Could you rephrase?"
|
75 |
+
|
76 |
+
|
77 |
+
# Example schema for your custom use case
|
78 |
+
custom_schema = {
|
79 |
+
"products": {
|
80 |
+
"columns": ["product_id", "name", "price", "category_id"],
|
81 |
+
"relations": "category_id -> categories.id",
|
82 |
+
},
|
83 |
+
"categories": {
|
84 |
+
"columns": ["id", "category_name"],
|
85 |
+
"relations": None,
|
86 |
+
},
|
87 |
+
"orders": {
|
88 |
+
"columns": ["order_id", "user_id", "product_id", "order_date"],
|
89 |
+
"relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
|
90 |
+
},
|
91 |
+
"users": {
|
92 |
+
"columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
|
93 |
+
"relations": None,
|
94 |
+
}
|
95 |
+
}
|
96 |
+
|
97 |
+
def answer_question(question: str, schema: dict, model, tokenizer, device):
|
98 |
+
# First, check if the question is about the schema or SQL
|
99 |
+
if is_schema_question(question):
|
100 |
+
# Handle schema-related questions
|
101 |
+
response = handle_schema_question(question, schema)
|
102 |
+
return f"Schema Information: {response}"
|
103 |
+
else:
|
104 |
+
# Generate an SQL query for data-related questions
|
105 |
+
sql_query = generate_sql(question, schema, model, tokenizer, device)
|
106 |
+
return f"Generated SQL Query: {sql_query}"
|
107 |
+
|
108 |
+
# Example input questions
|
109 |
+
question_1 = "What columns does the products table have?"
|
110 |
+
question_2 = "What is the price of the product with product_id 123?"
|
111 |
+
|
112 |
+
# Assuming you have loaded your model and tokenizer as `model` and `tokenizer`
|
113 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
114 |
+
|
115 |
+
# Handle schema question
|
116 |
+
response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
|
117 |
+
print(response_1) # This should give you the columns of the products table
|
118 |
+
|
119 |
+
# Handle SQL query question
|
120 |
+
response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
|
121 |
+
print(response_2) # This should generate an SQL query for fetching the price
|
wikiSQL.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from datasets import Dataset
|
5 |
+
from transformers import (
|
6 |
+
AutoTokenizer,
|
7 |
+
AutoModelForSeq2SeqLM,
|
8 |
+
Seq2SeqTrainer,
|
9 |
+
Seq2SeqTrainingArguments,
|
10 |
+
)
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from sklearn.model_selection import train_test_split
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def load_table_schemas(tables_file):
|
17 |
+
"""
|
18 |
+
Load table schemas from the tables.jsonl file.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
tables_file: Path to the tables.jsonl file.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
A dictionary mapping table IDs to their column names.
|
25 |
+
"""
|
26 |
+
table_schemas = {}
|
27 |
+
with open(tables_file, 'r') as f:
|
28 |
+
for line in f:
|
29 |
+
table_data = json.loads(line)
|
30 |
+
table_id = table_data["id"]
|
31 |
+
table_columns = table_data["header"]
|
32 |
+
table_schemas[table_id] = table_columns
|
33 |
+
return table_schemas
|
34 |
+
|
35 |
+
|
36 |
+
# Step 1: Load and Preprocess WikiSQL Data
|
37 |
+
def load_wikisql(data_dir):
|
38 |
+
"""
|
39 |
+
Load WikiSQL data and prepare it for training.
|
40 |
+
Args:
|
41 |
+
data_dir: Path to the WikiSQL dataset directory.
|
42 |
+
Returns:
|
43 |
+
List of examples with input and target text.
|
44 |
+
"""
|
45 |
+
def parse_file(file_path):
|
46 |
+
with open(file_path, 'r') as f:
|
47 |
+
return [json.loads(line) for line in f]
|
48 |
+
|
49 |
+
tables_data = parse_file(os.path.join(data_dir, "train.tables.jsonl"))
|
50 |
+
train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
|
51 |
+
dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))
|
52 |
+
|
53 |
+
print("====>", train_data[0])
|
54 |
+
tables_file = "./data/train.tables.jsonl"
|
55 |
+
table_schemas = load_table_schemas(tables_file)
|
56 |
+
|
57 |
+
dev_tables = './data/dev.tables.jsonl'
|
58 |
+
dev_tables_schema = load_table_schemas(dev_tables)
|
59 |
+
|
60 |
+
def format_data(data, type):
|
61 |
+
formatted = []
|
62 |
+
for item in data:
|
63 |
+
table_id = item["table_id"]
|
64 |
+
table_columns = table_schemas[table_id] if type == 'train' else dev_tables_schema[table_id]
|
65 |
+
question = item["question"]
|
66 |
+
sql = item["sql"]
|
67 |
+
sql_query = sql_to_text(sql, table_columns)
|
68 |
+
print("SQL Query", sql_query)
|
69 |
+
formatted.append({"input": f"Question: {question}", "target": sql_query})
|
70 |
+
return formatted
|
71 |
+
|
72 |
+
return format_data(train_data, "train"), format_data(dev_data, "dev")
|
73 |
+
|
74 |
+
|
75 |
+
def sql_to_text(sql, table_columns):
|
76 |
+
"""
|
77 |
+
Convert SQL dictionary from WikiSQL to text representation.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
|
81 |
+
table_columns: List of column names corresponding to the table.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
SQL query as a string.
|
85 |
+
"""
|
86 |
+
# Aggregation functions mapping
|
87 |
+
agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
|
88 |
+
operators = ["=", ">", "<"]
|
89 |
+
|
90 |
+
# Get selected column
|
91 |
+
sel_column = table_columns[sql["sel"]]
|
92 |
+
agg_func = agg_functions[sql["agg"]]
|
93 |
+
select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"
|
94 |
+
|
95 |
+
# Get conditions
|
96 |
+
if sql["conds"]:
|
97 |
+
conditions = []
|
98 |
+
for cond in sql["conds"]:
|
99 |
+
col_idx, operator, value = cond
|
100 |
+
col_name = table_columns[col_idx]
|
101 |
+
conditions.append(f"{col_name} {operators[operator]} '{value}'")
|
102 |
+
where_clause = " WHERE " + " AND ".join(conditions)
|
103 |
+
else:
|
104 |
+
where_clause = ""
|
105 |
+
|
106 |
+
# Combine clauses into a full query
|
107 |
+
return select_clause + where_clause
|
108 |
+
|
109 |
+
# Step 2: Tokenize the Data
|
110 |
+
def tokenize_data(data, tokenizer, max_length=128):
|
111 |
+
"""
|
112 |
+
Tokenize the input and target text.
|
113 |
+
Args:
|
114 |
+
data: List of examples with "input" and "target".
|
115 |
+
tokenizer: Pretrained tokenizer.
|
116 |
+
max_length: Maximum sequence length for the model.
|
117 |
+
Returns:
|
118 |
+
Tokenized dataset.
|
119 |
+
"""
|
120 |
+
inputs = [item["input"] for item in data]
|
121 |
+
targets = [item["target"] for item in data]
|
122 |
+
|
123 |
+
tokenized = tokenizer(
|
124 |
+
inputs,
|
125 |
+
max_length=max_length,
|
126 |
+
padding="max_length",
|
127 |
+
truncation=True,
|
128 |
+
return_tensors="pt",
|
129 |
+
)
|
130 |
+
labels = tokenizer(
|
131 |
+
targets,
|
132 |
+
max_length=max_length,
|
133 |
+
padding="max_length",
|
134 |
+
truncation=True,
|
135 |
+
return_tensors="pt",
|
136 |
+
)
|
137 |
+
|
138 |
+
tokenized["labels"] = labels["input_ids"]
|
139 |
+
return tokenized
|
140 |
+
|
141 |
+
|
142 |
+
# Step 3: Load Model and Tokenizer
|
143 |
+
model_name = "t5-small" # Use "t5-small", "t5-base", or "t5-large"
|
144 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
145 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
146 |
+
|
147 |
+
# Step 4: Prepare Training and Validation Data
|
148 |
+
data_dir = "data" # Path to the WikiSQL dataset
|
149 |
+
train_data, dev_data = load_wikisql(data_dir)
|
150 |
+
|
151 |
+
# Tokenize Data
|
152 |
+
train_dataset = tokenize_data(train_data, tokenizer)
|
153 |
+
dev_dataset = tokenize_data(dev_data, tokenizer)
|
154 |
+
|
155 |
+
# # Convert to Hugging Face Dataset format
|
156 |
+
train_dataset = Dataset.from_dict(train_dataset)
|
157 |
+
dev_dataset = Dataset.from_dict(dev_dataset)
|
158 |
+
|
159 |
+
# # # Step 5: Define Training Arguments
|
160 |
+
# training_args = Seq2SeqTrainingArguments(
|
161 |
+
# output_dir="./t5_sql_finetuned",
|
162 |
+
# evaluation_strategy="steps",
|
163 |
+
# save_steps=1000,
|
164 |
+
# eval_steps=100,
|
165 |
+
# logging_steps=100,
|
166 |
+
# per_device_train_batch_size=16,
|
167 |
+
# per_device_eval_batch_size=16,
|
168 |
+
# num_train_epochs=3,
|
169 |
+
# save_total_limit=2,
|
170 |
+
# learning_rate=5e-5,
|
171 |
+
# predict_with_generate=True,
|
172 |
+
# fp16=torch.cuda.is_available(), # Enable mixed precision for faster training
|
173 |
+
# logging_dir="./logs",
|
174 |
+
# )
|
175 |
+
|
176 |
+
# # # Step 6: Define Trainer
|
177 |
+
# trainer = Seq2SeqTrainer(
|
178 |
+
# model=model,
|
179 |
+
# args=training_args,
|
180 |
+
# train_dataset=train_dataset,
|
181 |
+
# eval_dataset=dev_dataset,
|
182 |
+
# tokenizer=tokenizer,
|
183 |
+
# )
|
184 |
+
|
185 |
+
# # # Step 7: Train the Model
|
186 |
+
# trainer.train()
|
187 |
+
|
188 |
+
# # # Step 8: Save the Model
|
189 |
+
# trainer.save_model("./t5_sql_finetuned")
|
190 |
+
# tokenizer.save_pretrained("./t5_sql_finetuned")
|
191 |
+
|
192 |
+
# # Step 9: Test the Model
|
193 |
+
test_question = "Find all orders with product_id greater than 5."
|
194 |
+
input_text = f"Question: {test_question}"
|
195 |
+
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
|
196 |
+
|
197 |
+
outputs = model.generate(**inputs, max_length=128)
|
198 |
+
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
199 |
+
print("Generated SQL:", generated_sql)
|