Manoj Kumar commited on
Commit
e6f4fec
·
1 Parent(s): 867cb42

Mark POhase 1

Browse files
Files changed (8) hide show
  1. .DS_Store +0 -0
  2. Mark-1/db_creation.py +91 -0
  3. Mark-1/phase1.py +274 -0
  4. db.py +1 -1
  5. ecommerce.db +0 -0
  6. requirements.txt +4 -1
  7. wikiPreTrained.py +121 -0
  8. wikiSQL.py +199 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Mark-1/db_creation.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import random
3
+ from faker import Faker
4
+
5
+ # Initialize Faker for generating random data
6
+ fake = Faker()
7
+
8
+ # Define custom schema
9
+ custom_schema = {
10
+ "products": {
11
+ "columns": ["product_id INTEGER PRIMARY KEY", "name TEXT", "price REAL", "category_id INTEGER"],
12
+ "relations": ["category_id -> categories.id"],
13
+ },
14
+ "categories": {
15
+ "columns": ["id INTEGER PRIMARY KEY", "category_name TEXT"],
16
+ "relations": None,
17
+ },
18
+ "orders": {
19
+ "columns": ["order_id INTEGER PRIMARY KEY", "user_id INTEGER", "product_id INTEGER", "order_date TEXT"],
20
+ "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
21
+ },
22
+ "users": {
23
+ "columns": [
24
+ "user_id INTEGER PRIMARY KEY",
25
+ "first_name TEXT",
26
+ "last_name TEXT",
27
+ "email TEXT UNIQUE",
28
+ "phone_number TEXT",
29
+ "address TEXT"
30
+ ],
31
+ "relations": None,
32
+ }
33
+ }
34
+
35
+ # Connect to SQLite database
36
+ conn = sqlite3.connect("ecommerce.db")
37
+ cursor = conn.cursor()
38
+
39
+ # Function to create tables based on schema
40
+ def create_tables():
41
+ for table_name, table_data in custom_schema.items():
42
+ columns = ", ".join(table_data["columns"])
43
+ table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns}"
44
+
45
+ if table_data["relations"]:
46
+ for relation in table_data["relations"]:
47
+ col_name, ref_table = relation.split(" -> ")
48
+ ref_col = ref_table.split(".")[1]
49
+ ref_table_name = ref_table.split(".")[0]
50
+ table_sql += f", FOREIGN KEY({col_name}) REFERENCES {ref_table_name}({ref_col})"
51
+
52
+ table_sql += ");"
53
+ cursor.execute(table_sql)
54
+
55
+ # Function to populate categories table
56
+ def insert_categories():
57
+ categories = [(i, fake.word().capitalize() + " " + fake.word().capitalize()) for i in range(1, 1001)]
58
+ cursor.executemany("INSERT INTO categories (id, category_name) VALUES (?, ?)", categories)
59
+ return categories
60
+
61
+ # Function to populate products table
62
+ def insert_products(categories):
63
+ products = [(i, fake.company() + " " + fake.word().capitalize(), round(random.uniform(10, 1000), 2), random.choice(categories)[0]) for i in range(1, 1001)]
64
+ cursor.executemany("INSERT INTO products (product_id, name, price, category_id) VALUES (?, ?, ?, ?)", products)
65
+ return products
66
+
67
+ # Function to populate users table
68
+ def insert_users():
69
+ users = [(i, fake.first_name(), fake.last_name(), fake.email(), fake.phone_number(), fake.address()) for i in range(1, 1001)]
70
+ cursor.executemany("INSERT OR IGNORE INTO users (user_id, first_name, last_name, email, phone_number, address) VALUES (?, ?, ?, ?, ?, ?)", users)
71
+ return users
72
+
73
+ # Function to populate orders table
74
+ def insert_orders(users, products):
75
+ orders = [(i, random.choice(users)[0], random.choice(products)[0], fake.date_this_year().strftime("%Y-%m-%d")) for i in range(1, 1001)]
76
+ cursor.executemany("INSERT INTO orders (order_id, user_id, product_id, order_date) VALUES (?, ?, ?, ?)", orders)
77
+
78
+ # Create tables
79
+ create_tables()
80
+
81
+ # Insert data into tables
82
+ categories = insert_categories()
83
+ products = insert_products(categories)
84
+ users = insert_users()
85
+ insert_orders(users, products)
86
+
87
+ # Commit and close connection
88
+ conn.commit()
89
+ conn.close()
90
+
91
+ print("1000 rows inserted into each table successfully!")
Mark-1/phase1.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import spacy
3
+ import re
4
+ from thefuzz import process
5
+ import numpy as np
6
+ from transformers import pipeline
7
+
8
+ # Load intent classification model
9
+ # Use Hugging Face's zero-shot pipeline for flexibility
10
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
11
+ nlp = spacy.load("en_core_web_sm")
12
+ nlp_vectors = spacy.load("en_core_web_md")
13
+
14
+
15
+ # Define operator mappings
16
+ operator_mappings = {
17
+ "greater than": ">",
18
+ "less than": "<",
19
+ "equal to": "=",
20
+ "not equal to": "!=",
21
+ "starts with": "LIKE",
22
+ "ends with": "LIKE",
23
+ "contains": "LIKE",
24
+ "above": ">",
25
+ "below": "<",
26
+ "more than": ">",
27
+ "less than": "<",
28
+ "<": "<",
29
+ ">": ">"
30
+ }
31
+
32
+ # Connect to SQLite database
33
+ def connect_to_db(db_path):
34
+ conn = sqlite3.connect(db_path)
35
+ return conn
36
+
37
+ # Fetch database schema
38
+ def fetch_schema(conn):
39
+ cursor = conn.cursor()
40
+ query = """
41
+ SELECT name
42
+ FROM sqlite_master
43
+ WHERE type='table';
44
+ """
45
+ cursor.execute(query)
46
+ tables = cursor.fetchall()
47
+
48
+ schema = {}
49
+ for table in tables:
50
+ table_name = table[0]
51
+ cursor.execute(f"PRAGMA table_info({table_name});")
52
+ columns = cursor.fetchall()
53
+ schema[table_name] = [{"name": col[1], "type": col[2], "not_null": col[3], "default": col[4], "pk": col[5]} for col in columns]
54
+
55
+ return schema
56
+
57
+ def find_ai_synonym(token_text, table_schema):
58
+ """Return the best-matching column from table_schema based on vector similarity."""
59
+ token_vec = nlp_vectors(token_text)[0].vector
60
+ best_col = None
61
+ best_score = 0.0
62
+
63
+ for col in table_schema:
64
+ col_vec = nlp_vectors(col)[0].vector
65
+ # Cosine similarity
66
+ score = token_vec.dot(col_vec) / (np.linalg.norm(token_vec) * np.linalg.norm(col_vec))
67
+ if score > best_score:
68
+ best_score = score
69
+ best_col = col
70
+
71
+ # Apply threshold
72
+ if best_score > 0.65:
73
+ return best_col
74
+ return None
75
+
76
+ def identify_table(question, schema_tables):
77
+ # schema_tables = ["products", "users", "orders", ...]
78
+ table, score = process.extractOne(question, schema_tables)
79
+
80
+ if score > 80: # a comfortable threshold
81
+ return table
82
+ return None
83
+
84
+ def identify_columns(question, columns_for_table):
85
+ # columns_for_table = ["id", "price", "stock", "name", ...]
86
+ # For each token in question, fuzzy match to columns
87
+ matched_cols = []
88
+ tokens = question.lower().split()
89
+ for token in tokens:
90
+ col, score = process.extractOne(token, columns_for_table)
91
+ if score > 80:
92
+ matched_cols.append(col)
93
+ return matched_cols
94
+
95
+ def find_closest_column(token, table_schema):
96
+ # table_schema is a list of column names, e.g. ["price", "stock", "name"]
97
+ # This returns (best_match, score)
98
+ best_match, score = process.extractOne(token, table_schema)
99
+ # You can tune this threshold as needed (e.g. 70, 80, etc.)
100
+ if score > 90:
101
+ return best_match
102
+ return None
103
+
104
+ # Condition extraction with NLP
105
+ def extract_conditions(question, schema, table):
106
+ table_schema = [col["name"].lower() for col in schema.get(table, [])]
107
+
108
+ # Detect whether the user used 'AND' / 'OR'
109
+ # (case-insensitive, hence .lower() checks)
110
+ use_and = " and " in question.lower()
111
+ use_or = " or " in question.lower()
112
+ last_column = None
113
+
114
+ # Split on 'and' or 'or' to handle multiple conditions
115
+ condition_parts = re.split(r'\band\b|\bor\b', question, flags=re.IGNORECASE)
116
+
117
+ print(condition_parts)
118
+
119
+ conditions = []
120
+
121
+ for part in condition_parts:
122
+ part = part.strip()
123
+
124
+ # Use spaCy to tokenize each part
125
+ doc = nlp(part.lower())
126
+ tokens = [token.text for token in doc]
127
+
128
+ # Skip the recognized_table token if it appears in tokens
129
+ # so it won't be matched as a column
130
+ tokens = [t for t in tokens if t != table.lower()]
131
+
132
+ part_conditions = []
133
+ current_part_column = None
134
+
135
+ print(tokens)
136
+
137
+ for i, token in enumerate(tokens):
138
+ # Try synonyms/fuzzy, etc. to find a column
139
+ possible_col = find_ai_synonym(token, table_schema)
140
+ if possible_col:
141
+ current_part_column = possible_col
142
+ last_column = possible_col # update last_column
143
+
144
+ # Check for any matching operator phrase in this part
145
+ for phrase, sql_operator in operator_mappings.items():
146
+ if phrase in part.lower():
147
+ # Extract the value after the phrase
148
+ value_index = part.lower().find(phrase) + len(phrase)
149
+ value = part[value_index:].strip().split(" ")[0]
150
+ value = value.replace("'", "").replace('"', "").strip()
151
+
152
+ # Special handling for LIKE operators
153
+ if sql_operator == "LIKE":
154
+ if "starts with" in phrase:
155
+ value = f"'{value}%'"
156
+ elif "ends with" in phrase:
157
+ value = f"'%{value}'"
158
+ elif "contains" in phrase:
159
+ value = f"'%{value}%'"
160
+
161
+ # If we did not find a new column, fallback to last_column
162
+ column_to_use = current_part_column or last_column
163
+ if column_to_use:
164
+ # Add this condition to the list for this part
165
+ part_conditions.append(f"{column_to_use} {sql_operator} {value}")
166
+
167
+ # If multiple conditions are found in this part, join them with AND
168
+ # (e.g., "price > 100 AND stock < 50" within the same part)
169
+ if part_conditions:
170
+ conditions.append(" AND ".join(part_conditions))
171
+
172
+ # Finally, combine each part with AND or OR, depending on the user query
173
+ if use_and:
174
+ return " AND ".join(conditions)
175
+ elif use_or:
176
+ return " OR ".join(conditions)
177
+ else:
178
+ # If there's only one part or no explicit 'and'/'or', default to AND
179
+ return " AND ".join(conditions)
180
+
181
+ # Interpret user question using intent recognition
182
+ def interpret_question(question, schema):
183
+ # Define potential intents
184
+ intents = {
185
+ "describe_table": "Provide information about the columns and structure of a table.",
186
+ "list_table_data": "Fetch and display all data stored in a table.",
187
+ "count_records": "Count the number of records in a table.",
188
+ "fetch_column": "Fetch a specific column's data from a table."
189
+ }
190
+
191
+ # Use classifier to predict intent
192
+ labels = list(intents.keys())
193
+ result = classifier(question, labels)
194
+
195
+ predicted_intent = result["labels"][0]
196
+ table = identify_table(question, list(schema.keys()))
197
+
198
+ # Rule-based fallback for conditional queries
199
+ condition_keywords = list(operator_mappings.keys())
200
+ if any(keyword in question.lower() for keyword in condition_keywords):
201
+ predicted_intent = "list_table_data"
202
+
203
+ return {"intent": predicted_intent, "table": table}
204
+
205
+ # Handle different intents
206
+ def handle_intent(intent_data, schema, conn, question):
207
+ intent = intent_data["intent"]
208
+ table = intent_data["table"]
209
+
210
+ if not table:
211
+ return "I couldn't identify which table you're referring to."
212
+
213
+ if intent == "describe_table":
214
+ # Describe table structure
215
+ table_schema = schema[table]
216
+ description = [f"Table '{table}' has the following columns:"]
217
+ for col in table_schema:
218
+ col_details = f"- {col['name']} ({col['type']})"
219
+ if col['not_null']:
220
+ col_details += " [NOT NULL]"
221
+ if col['default'] is not None:
222
+ col_details += f" [DEFAULT: {col['default']}]"
223
+ if col['pk']:
224
+ col_details += " [PRIMARY KEY]"
225
+ description.append(col_details)
226
+ return "\n".join(description)
227
+
228
+ elif intent == "list_table_data":
229
+ # Check for conditions
230
+ condition = extract_conditions(question, schema, table)
231
+ cursor = conn.cursor()
232
+ query = f"SELECT * FROM {table}"
233
+ if condition:
234
+ query += f" WHERE {condition};"
235
+ else:
236
+ query += ";"
237
+
238
+ print(query)
239
+ cursor.execute(query)
240
+ return cursor.fetchall()
241
+
242
+ elif intent == "count_records":
243
+ # Count records in the table
244
+ cursor = conn.cursor()
245
+ cursor.execute(f"SELECT COUNT(*) FROM {table};")
246
+ return cursor.fetchone()
247
+
248
+ elif intent == "fetch_column":
249
+ return "Fetching specific column data is not yet implemented."
250
+
251
+ else:
252
+ return "I couldn't understand your question."
253
+
254
+ # Main function
255
+ def answer_question(question, conn, schema):
256
+ intent_data = interpret_question(question, schema)
257
+ print(intent_data)
258
+ return handle_intent(intent_data, schema, conn, question)
259
+
260
+ # Example Usage
261
+ if __name__ == "__main__":
262
+ db_path = "./ecommerce.db" # Replace with your SQLite database path
263
+ conn = connect_to_db(db_path)
264
+ schema = fetch_schema(conn)
265
+
266
+ print("Schema:", schema)
267
+
268
+ while True:
269
+ question = input("\nAsk a question about the database: ")
270
+ if question.lower() in ["exit", "quit"]:
271
+ break
272
+
273
+ answer = answer_question(question, conn, schema)
274
+ print("Answer:", answer)
db.py CHANGED
@@ -34,7 +34,7 @@ def generate_context(schema):
34
  schema_context = generate_context(schema)
35
 
36
  # Step 2: Load the T5-base-text-to-sql model
37
- model_name = "mrm8488/t5-base-finetuned-wikiSQL" # A model fine-tuned for SQL generation
38
  tokenizer = AutoTokenizer.from_pretrained(model_name)
39
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
40
 
 
34
  schema_context = generate_context(schema)
35
 
36
  # Step 2: Load the T5-base-text-to-sql model
37
+ model_name = "suriya7/t5-base-text-to-sql" # A model fine-tuned for SQL generation
38
  tokenizer = AutoTokenizer.from_pretrained(model_name)
39
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
40
 
ecommerce.db ADDED
Binary file (258 kB). View file
 
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  transformers
2
  torch
3
  accelerate>=0.26.0
4
- tiktoken
 
 
 
 
1
  transformers
2
  torch
3
  accelerate>=0.26.0
4
+ tiktoken
5
+ datasets
6
+ sentencepiece
7
+ tqdm
wikiPreTrained.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
2
+ import torch
3
+ import re
4
+
5
+ # Load the trained model and tokenizer
6
+ model = T5ForConditionalGeneration.from_pretrained("./t5_sql_finetuned")
7
+ tokenizer = T5Tokenizer.from_pretrained("./t5_sql_finetuned")
8
+
9
+ # Define a simple function to check if the question is schema-related or SQL-related
10
+ def is_schema_question(question: str):
11
+ schema_keywords = ["columns", "tables", "structure", "schema", "relations", "fields"]
12
+ return any(keyword in question.lower() for keyword in schema_keywords)
13
+
14
+ # Helper function to extract table name from the question
15
+ def extract_table_name(question: str):
16
+ # Regex pattern to find table names, assuming table names are capitalized or match a known pattern
17
+ table_name_match = re.search(r'for (\w+)|in (\w+)|from (\w+)', question)
18
+
19
+ if table_name_match:
20
+ # Return the matched table name (first capturing group)
21
+ return table_name_match.group(1) or table_name_match.group(2) or table_name_match.group(3)
22
+
23
+ # If no table name is detected, return None
24
+ return None
25
+
26
+
27
+ # Define a function to handle SQL generation
28
+ def generate_sql(question: str, schema: dict, model, tokenizer, device):
29
+ # Preprocess the question for SQL generation (e.g., reformat)
30
+ # Example question: "What is the price of the product with ID 123?"
31
+
32
+ # Here we use the model to generate SQL query
33
+ inputs = tokenizer(question, return_tensors="pt")
34
+ input_ids = inputs.input_ids.to(device)
35
+
36
+ with torch.no_grad():
37
+ generated_ids = model.generate(input_ids, max_length=128)
38
+
39
+ # Decode the SQL query generated by the model
40
+ sql_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
41
+
42
+ return sql_query
43
+
44
+ # Define a function to handle schema-related questions
45
+ def handle_schema_question(question: str, schema: dict):
46
+ # Here you handle questions about the schema (tables, columns, relations)
47
+ # Example schema-related question: "What columns does the products table have?"
48
+
49
+ question = question.lower()
50
+
51
+ # Check if the question asks about columns
52
+ if "columns" in question or "fields" in question:
53
+ table_name = extract_table_name(question)
54
+ if table_name:
55
+ if table_name in schema:
56
+ return schema[table_name]["columns"]
57
+ else:
58
+ return f"Table '{table_name}' not found in the schema."
59
+
60
+ # Check if the question asks about relations
61
+ elif "relations" in question or "relationships" in question:
62
+ table_name = extract_table_name(question)
63
+ if table_name:
64
+ if table_name in schema:
65
+ return schema[table_name]["relations"]
66
+ else:
67
+ return f"Table '{table_name}' not found in the schema."
68
+
69
+ # Additional cases can be handled here (e.g., "Which tables are in the schema?")
70
+ elif "tables" in question:
71
+ return list(schema.keys())
72
+
73
+ # If the question is too vague or doesn't match the expected patterns
74
+ return "Sorry, I couldn't understand your schema question. Could you rephrase?"
75
+
76
+
77
+ # Example schema for your custom use case
78
+ custom_schema = {
79
+ "products": {
80
+ "columns": ["product_id", "name", "price", "category_id"],
81
+ "relations": "category_id -> categories.id",
82
+ },
83
+ "categories": {
84
+ "columns": ["id", "category_name"],
85
+ "relations": None,
86
+ },
87
+ "orders": {
88
+ "columns": ["order_id", "user_id", "product_id", "order_date"],
89
+ "relations": ["product_id -> products.product_id", "user_id -> users.user_id"],
90
+ },
91
+ "users": {
92
+ "columns": ["user_id", "first_name", "last_name", "email", "phone_number", "address"],
93
+ "relations": None,
94
+ }
95
+ }
96
+
97
+ def answer_question(question: str, schema: dict, model, tokenizer, device):
98
+ # First, check if the question is about the schema or SQL
99
+ if is_schema_question(question):
100
+ # Handle schema-related questions
101
+ response = handle_schema_question(question, schema)
102
+ return f"Schema Information: {response}"
103
+ else:
104
+ # Generate an SQL query for data-related questions
105
+ sql_query = generate_sql(question, schema, model, tokenizer, device)
106
+ return f"Generated SQL Query: {sql_query}"
107
+
108
+ # Example input questions
109
+ question_1 = "What columns does the products table have?"
110
+ question_2 = "What is the price of the product with product_id 123?"
111
+
112
+ # Assuming you have loaded your model and tokenizer as `model` and `tokenizer`
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+
115
+ # Handle schema question
116
+ response_1 = answer_question(question_1, custom_schema, model, tokenizer, device)
117
+ print(response_1) # This should give you the columns of the products table
118
+
119
+ # Handle SQL query question
120
+ response_2 = answer_question(question_2, custom_schema, model, tokenizer, device)
121
+ print(response_2) # This should generate an SQL query for fetching the price
wikiSQL.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import torch
4
+ from datasets import Dataset
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForSeq2SeqLM,
8
+ Seq2SeqTrainer,
9
+ Seq2SeqTrainingArguments,
10
+ )
11
+ from torch.utils.data import DataLoader
12
+ from sklearn.model_selection import train_test_split
13
+ from tqdm import tqdm
14
+
15
+
16
+ def load_table_schemas(tables_file):
17
+ """
18
+ Load table schemas from the tables.jsonl file.
19
+
20
+ Args:
21
+ tables_file: Path to the tables.jsonl file.
22
+
23
+ Returns:
24
+ A dictionary mapping table IDs to their column names.
25
+ """
26
+ table_schemas = {}
27
+ with open(tables_file, 'r') as f:
28
+ for line in f:
29
+ table_data = json.loads(line)
30
+ table_id = table_data["id"]
31
+ table_columns = table_data["header"]
32
+ table_schemas[table_id] = table_columns
33
+ return table_schemas
34
+
35
+
36
+ # Step 1: Load and Preprocess WikiSQL Data
37
+ def load_wikisql(data_dir):
38
+ """
39
+ Load WikiSQL data and prepare it for training.
40
+ Args:
41
+ data_dir: Path to the WikiSQL dataset directory.
42
+ Returns:
43
+ List of examples with input and target text.
44
+ """
45
+ def parse_file(file_path):
46
+ with open(file_path, 'r') as f:
47
+ return [json.loads(line) for line in f]
48
+
49
+ tables_data = parse_file(os.path.join(data_dir, "train.tables.jsonl"))
50
+ train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
51
+ dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))
52
+
53
+ print("====>", train_data[0])
54
+ tables_file = "./data/train.tables.jsonl"
55
+ table_schemas = load_table_schemas(tables_file)
56
+
57
+ dev_tables = './data/dev.tables.jsonl'
58
+ dev_tables_schema = load_table_schemas(dev_tables)
59
+
60
+ def format_data(data, type):
61
+ formatted = []
62
+ for item in data:
63
+ table_id = item["table_id"]
64
+ table_columns = table_schemas[table_id] if type == 'train' else dev_tables_schema[table_id]
65
+ question = item["question"]
66
+ sql = item["sql"]
67
+ sql_query = sql_to_text(sql, table_columns)
68
+ print("SQL Query", sql_query)
69
+ formatted.append({"input": f"Question: {question}", "target": sql_query})
70
+ return formatted
71
+
72
+ return format_data(train_data, "train"), format_data(dev_data, "dev")
73
+
74
+
75
+ def sql_to_text(sql, table_columns):
76
+ """
77
+ Convert SQL dictionary from WikiSQL to text representation.
78
+
79
+ Args:
80
+ sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
81
+ table_columns: List of column names corresponding to the table.
82
+
83
+ Returns:
84
+ SQL query as a string.
85
+ """
86
+ # Aggregation functions mapping
87
+ agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
88
+ operators = ["=", ">", "<"]
89
+
90
+ # Get selected column
91
+ sel_column = table_columns[sql["sel"]]
92
+ agg_func = agg_functions[sql["agg"]]
93
+ select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"
94
+
95
+ # Get conditions
96
+ if sql["conds"]:
97
+ conditions = []
98
+ for cond in sql["conds"]:
99
+ col_idx, operator, value = cond
100
+ col_name = table_columns[col_idx]
101
+ conditions.append(f"{col_name} {operators[operator]} '{value}'")
102
+ where_clause = " WHERE " + " AND ".join(conditions)
103
+ else:
104
+ where_clause = ""
105
+
106
+ # Combine clauses into a full query
107
+ return select_clause + where_clause
108
+
109
+ # Step 2: Tokenize the Data
110
+ def tokenize_data(data, tokenizer, max_length=128):
111
+ """
112
+ Tokenize the input and target text.
113
+ Args:
114
+ data: List of examples with "input" and "target".
115
+ tokenizer: Pretrained tokenizer.
116
+ max_length: Maximum sequence length for the model.
117
+ Returns:
118
+ Tokenized dataset.
119
+ """
120
+ inputs = [item["input"] for item in data]
121
+ targets = [item["target"] for item in data]
122
+
123
+ tokenized = tokenizer(
124
+ inputs,
125
+ max_length=max_length,
126
+ padding="max_length",
127
+ truncation=True,
128
+ return_tensors="pt",
129
+ )
130
+ labels = tokenizer(
131
+ targets,
132
+ max_length=max_length,
133
+ padding="max_length",
134
+ truncation=True,
135
+ return_tensors="pt",
136
+ )
137
+
138
+ tokenized["labels"] = labels["input_ids"]
139
+ return tokenized
140
+
141
+
142
+ # Step 3: Load Model and Tokenizer
143
+ model_name = "t5-small" # Use "t5-small", "t5-base", or "t5-large"
144
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
145
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
146
+
147
+ # Step 4: Prepare Training and Validation Data
148
+ data_dir = "data" # Path to the WikiSQL dataset
149
+ train_data, dev_data = load_wikisql(data_dir)
150
+
151
+ # Tokenize Data
152
+ train_dataset = tokenize_data(train_data, tokenizer)
153
+ dev_dataset = tokenize_data(dev_data, tokenizer)
154
+
155
+ # # Convert to Hugging Face Dataset format
156
+ train_dataset = Dataset.from_dict(train_dataset)
157
+ dev_dataset = Dataset.from_dict(dev_dataset)
158
+
159
+ # # # Step 5: Define Training Arguments
160
+ # training_args = Seq2SeqTrainingArguments(
161
+ # output_dir="./t5_sql_finetuned",
162
+ # evaluation_strategy="steps",
163
+ # save_steps=1000,
164
+ # eval_steps=100,
165
+ # logging_steps=100,
166
+ # per_device_train_batch_size=16,
167
+ # per_device_eval_batch_size=16,
168
+ # num_train_epochs=3,
169
+ # save_total_limit=2,
170
+ # learning_rate=5e-5,
171
+ # predict_with_generate=True,
172
+ # fp16=torch.cuda.is_available(), # Enable mixed precision for faster training
173
+ # logging_dir="./logs",
174
+ # )
175
+
176
+ # # # Step 6: Define Trainer
177
+ # trainer = Seq2SeqTrainer(
178
+ # model=model,
179
+ # args=training_args,
180
+ # train_dataset=train_dataset,
181
+ # eval_dataset=dev_dataset,
182
+ # tokenizer=tokenizer,
183
+ # )
184
+
185
+ # # # Step 7: Train the Model
186
+ # trainer.train()
187
+
188
+ # # # Step 8: Save the Model
189
+ # trainer.save_model("./t5_sql_finetuned")
190
+ # tokenizer.save_pretrained("./t5_sql_finetuned")
191
+
192
+ # # Step 9: Test the Model
193
+ test_question = "Find all orders with product_id greater than 5."
194
+ input_text = f"Question: {test_question}"
195
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
196
+
197
+ outputs = model.generate(**inputs, max_length=128)
198
+ generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
199
+ print("Generated SQL:", generated_sql)