feat: add text to SQL retriever
Browse files
multi-agents-analysis/text_to_SQL.py
CHANGED
@@ -12,10 +12,11 @@ from sqlalchemy import (
|
|
12 |
import logging
|
13 |
import sys
|
14 |
from llama_index.core import SQLDatabase, VectorStoreIndex
|
15 |
-
from llama_index.core.query_engine import NLSQLTableQueryEngine
|
16 |
from llama_index.llms.ollama import Ollama
|
17 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
18 |
from llama_index.core import Settings
|
|
|
19 |
from llama_index.core.indices.struct_store.sql_query import (
|
20 |
SQLTableRetrieverQueryEngine,
|
21 |
)
|
@@ -70,19 +71,29 @@ for row in rows:
|
|
70 |
cursor = connection.execute(stmt)
|
71 |
|
72 |
# view current table
|
73 |
-
stmt = select(
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
).select_from(city_stats_table)
|
78 |
-
|
79 |
-
with engine.connect() as connection:
|
80 |
-
|
81 |
-
|
82 |
# Finished defining the SQL database
|
83 |
|
84 |
# --------------------------------
|
85 |
# Part 1: Text-to-SQL Query Engine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
# --------------------------------
|
87 |
# Part 2: Query-Time Retrieval of Tables for Text-to-SQL
|
88 |
|
@@ -91,31 +102,82 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
91 |
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
92 |
|
93 |
# manually set context text
|
94 |
-
city_stats_text = (
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
table_node_mapping = SQLTableNodeMapping(sql_database)
|
101 |
-
table_schema_objs = [
|
102 |
-
(SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
|
103 |
-
] # add a SQLTableSchema for each table
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
VectorStoreIndex,
|
109 |
)
|
110 |
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
)
|
114 |
-
|
115 |
-
response = query_engine.query(query_str)
|
116 |
-
print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
|
117 |
-
print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
|
118 |
-
print(
|
119 |
-
f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
|
120 |
print(
|
121 |
-
f"[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import logging
|
13 |
import sys
|
14 |
from llama_index.core import SQLDatabase, VectorStoreIndex
|
15 |
+
from llama_index.core.query_engine import NLSQLTableQueryEngine, RetrieverQueryEngine
|
16 |
from llama_index.llms.ollama import Ollama
|
17 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
18 |
from llama_index.core import Settings
|
19 |
+
from llama_index.core.retrievers import NLSQLRetriever
|
20 |
from llama_index.core.indices.struct_store.sql_query import (
|
21 |
SQLTableRetrieverQueryEngine,
|
22 |
)
|
|
|
71 |
cursor = connection.execute(stmt)
|
72 |
|
73 |
# view current table
|
74 |
+
# stmt = select(
|
75 |
+
# city_stats_table.c.city_name,
|
76 |
+
# city_stats_table.c.population,
|
77 |
+
# city_stats_table.c.country,
|
78 |
+
# ).select_from(city_stats_table)
|
79 |
+
|
80 |
+
# with engine.connect() as connection:
|
81 |
+
# results = connection.execute(stmt).fetchall()
|
82 |
+
# print(f"results: {results}")
|
83 |
# Finished defining the SQL database
|
84 |
|
85 |
# --------------------------------
|
86 |
# Part 1: Text-to-SQL Query Engine
|
87 |
+
# query_engine = NLSQLTableQueryEngine(
|
88 |
+
# sql_database=sql_database, tables=["city_stats"], llm=llm
|
89 |
+
# )
|
90 |
+
# query_str = "Which city has the highest population?"
|
91 |
+
# response = query_engine.query(query_str)
|
92 |
+
|
93 |
+
# print("\n-----------Part 1---------------\n")
|
94 |
+
# print(f"Question: {query_str}\n")
|
95 |
+
# print(f"Response: {response}")
|
96 |
+
|
97 |
# --------------------------------
|
98 |
# Part 2: Query-Time Retrieval of Tables for Text-to-SQL
|
99 |
|
|
|
102 |
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
103 |
|
104 |
# manually set context text
|
105 |
+
# city_stats_text = (
|
106 |
+
# "This table gives information regarding the population and country of a"
|
107 |
+
# " given city.\nThe user will query with codewords, where 'foo' corresponds"
|
108 |
+
# " to population and 'bar' corresponds to city."
|
109 |
+
# )
|
110 |
+
|
111 |
+
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
112 |
+
# table_schema_objs = [
|
113 |
+
# (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
|
114 |
+
# ] # add a SQLTableSchema for each table
|
115 |
+
|
116 |
+
# obj_index = ObjectIndex.from_objects(
|
117 |
+
# table_schema_objs,
|
118 |
+
# table_node_mapping,
|
119 |
+
# VectorStoreIndex,
|
120 |
+
# )
|
121 |
+
|
122 |
+
# query_engine = SQLTableRetrieverQueryEngine(
|
123 |
+
# sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)
|
124 |
+
# )
|
125 |
+
# query_str = "Which city has the highest population?"
|
126 |
+
# response = query_engine.query(query_str)
|
127 |
+
# print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
|
128 |
+
# print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
|
129 |
+
# print(
|
130 |
+
# f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
|
131 |
+
# print(
|
132 |
+
# f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")
|
133 |
+
|
134 |
+
|
135 |
+
# --------------------------------
|
136 |
+
# Part 3: Text-to-SQL Retriever
|
137 |
+
"""
|
138 |
+
So far our text-to-SQL capability is packaged in a
|
139 |
+
query engine and consists of both retrieval and synthesis.
|
140 |
+
|
141 |
+
You can use the SQL retriever on its own.
|
142 |
+
We show you some different parameters you can try,
|
143 |
+
and also show how to plug it into our RetrieverQueryEngine
|
144 |
+
to get roughly the same results.
|
145 |
+
"""
|
146 |
|
|
|
|
|
|
|
|
|
147 |
|
148 |
+
# default retrieval (return_raw=True)
|
149 |
+
nl_sql_retriever = NLSQLRetriever(
|
150 |
+
sql_database, tables=["city_stats"], return_raw=True
|
|
|
151 |
)
|
152 |
|
153 |
+
# results = nl_sql_retriever.retrieve(
|
154 |
+
# "Return the top 5 cities (along with their populations) with the highest population."
|
155 |
+
# )
|
156 |
+
|
157 |
+
# print(len(results))
|
158 |
+
# for n in results:
|
159 |
+
# print(
|
160 |
+
# f"[bold chartreuse1 on grey7]> n: {n.metadata}[/bold chartreuse1 on grey7]\n")
|
161 |
+
|
162 |
+
"""
|
163 |
+
# Plug into our RetrieverQueryEngine
|
164 |
+
|
165 |
+
We compose our SQL Retriever with our standard RetrieverQueryEngine
|
166 |
+
to synthesize a response. The result is roughly similar
|
167 |
+
to our packaged Text-to-SQL query engines.
|
168 |
+
"""
|
169 |
+
|
170 |
+
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)
|
171 |
+
|
172 |
+
response = query_engine.query(
|
173 |
+
"Return the top 5 cities (along with their populations) with the highest population."
|
174 |
)
|
175 |
+
|
|
|
|
|
|
|
|
|
|
|
176 |
print(
|
177 |
+
f"[bold chartreuse1 on grey7]> Response: {str(response)}[/bold chartreuse1 on grey7]\n")
|
178 |
+
"""
|
179 |
+
> Response: Tokyo - 13,960,000
|
180 |
+
Seoul - 9,776,000
|
181 |
+
Toronto - 2,930,000
|
182 |
+
Chicago - 2,679,000
|
183 |
+
"""
|