Draichi commited on
Commit
10aa4d2
·
unverified ·
1 Parent(s): da086af

feat: add text to SQL retriever

Browse files
Files changed (1) hide show
  1. multi-agents-analysis/text_to_SQL.py +94 -32
multi-agents-analysis/text_to_SQL.py CHANGED
@@ -12,10 +12,11 @@ from sqlalchemy import (
12
  import logging
13
  import sys
14
  from llama_index.core import SQLDatabase, VectorStoreIndex
15
- from llama_index.core.query_engine import NLSQLTableQueryEngine
16
  from llama_index.llms.ollama import Ollama
17
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
18
  from llama_index.core import Settings
 
19
  from llama_index.core.indices.struct_store.sql_query import (
20
  SQLTableRetrieverQueryEngine,
21
  )
@@ -70,19 +71,29 @@ for row in rows:
70
  cursor = connection.execute(stmt)
71
 
72
  # view current table
73
- stmt = select(
74
- city_stats_table.c.city_name,
75
- city_stats_table.c.population,
76
- city_stats_table.c.country,
77
- ).select_from(city_stats_table)
78
-
79
- with engine.connect() as connection:
80
- results = connection.execute(stmt).fetchall()
81
- print(f"results: {results}")
82
  # Finish Define SQL Database
83
 
84
  # --------------------------------
85
  # Part 1: Text-to-SQL Query Engine
 
 
 
 
 
 
 
 
 
 
86
  # --------------------------------
87
  # Part 2: Query-Time Retrieval of Tables for Text-to-SQL
88
 
@@ -91,31 +102,82 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
91
  logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
92
 
93
  # manually set context text
94
- city_stats_text = (
95
- "This table gives information regarding the population and country of a"
96
- " given city.\nThe user will query with codewords, where 'foo' corresponds"
97
- " to population and 'bar' corresponds to city."
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- table_node_mapping = SQLTableNodeMapping(sql_database)
101
- table_schema_objs = [
102
- (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
103
- ] # add a SQLTableSchema for each table
104
 
105
- obj_index = ObjectIndex.from_objects(
106
- table_schema_objs,
107
- table_node_mapping,
108
- VectorStoreIndex,
109
  )
110
 
111
- query_engine = SQLTableRetrieverQueryEngine(
112
- sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  )
114
- query_str = "Which city has the highest population?"
115
- response = query_engine.query(query_str)
116
- print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
117
- print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
118
- print(
119
- f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
120
  print(
121
- f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")
 
 
 
 
 
 
 
12
  import logging
13
  import sys
14
  from llama_index.core import SQLDatabase, VectorStoreIndex
15
+ from llama_index.core.query_engine import NLSQLTableQueryEngine, RetrieverQueryEngine
16
  from llama_index.llms.ollama import Ollama
17
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
18
  from llama_index.core import Settings
19
+ from llama_index.core.retrievers import NLSQLRetriever
20
  from llama_index.core.indices.struct_store.sql_query import (
21
  SQLTableRetrieverQueryEngine,
22
  )
 
71
  cursor = connection.execute(stmt)
72
 
73
  # view current table
74
+ # stmt = select(
75
+ # city_stats_table.c.city_name,
76
+ # city_stats_table.c.population,
77
+ # city_stats_table.c.country,
78
+ # ).select_from(city_stats_table)
79
+
80
+ # with engine.connect() as connection:
81
+ # results = connection.execute(stmt).fetchall()
82
+ # print(f"results: {results}")
83
  # Finish Define SQL Database
84
 
85
  # --------------------------------
86
  # Part 1: Text-to-SQL Query Engine
87
+ # query_engine = NLSQLTableQueryEngine(
88
+ # sql_database=sql_database, tables=["city_stats"], llm=llm
89
+ # )
90
+ # query_str = "Which city has the highest population?"
91
+ # response = query_engine.query(query_str)
92
+
93
+ # print("\n-----------Part 1---------------\n")
94
+ # print(f"Question: {query_str}\n")
95
+ # print(f"Response: {response}")
96
+
97
  # --------------------------------
98
  # Part 2: Query-Time Retrieval of Tables for Text-to-SQL
99
 
 
102
  logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
103
 
104
  # manually set context text
105
+ # city_stats_text = (
106
+ # "This table gives information regarding the population and country of a"
107
+ # " given city.\nThe user will query with codewords, where 'foo' corresponds"
108
+ # " to population and 'bar' corresponds to city."
109
+ # )
110
+
111
+ # table_node_mapping = SQLTableNodeMapping(sql_database)
112
+ # table_schema_objs = [
113
+ # (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
114
+ # ] # add a SQLTableSchema for each table
115
+
116
+ # obj_index = ObjectIndex.from_objects(
117
+ # table_schema_objs,
118
+ # table_node_mapping,
119
+ # VectorStoreIndex,
120
+ # )
121
+
122
+ # query_engine = SQLTableRetrieverQueryEngine(
123
+ # sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)
124
+ # )
125
+ # query_str = "Which city has the highest population?"
126
+ # response = query_engine.query(query_str)
127
+ # print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
128
+ # print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
129
+ # print(
130
+ # f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
131
+ # print(
132
+ # f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")
133
+
134
+
135
+ # --------------------------------
136
+ # Part 3: Text-to-SQL Retriever
137
+ """
138
+ So far our text-to-SQL capability is packaged in a
139
+ query engine and consists of both retrieval and synthesis.
140
+
141
+ You can use the SQL retriever on its own.
142
+ We show you some different parameters you can try,
143
+ and also show how to plug it into our RetrieverQueryEngine
144
+ to get roughly the same results.
145
+ """
146
 
 
 
 
 
147
 
148
+ # default retrieval (return_raw=True)
149
+ nl_sql_retriever = NLSQLRetriever(
150
+ sql_database, tables=["city_stats"], return_raw=True
 
151
  )
152
 
153
+ # results = nl_sql_retriever.retrieve(
154
+ # "Return the top 5 cities (along with their populations) with the highest population."
155
+ # )
156
+
157
+ # print(len(results))
158
+ # for n in results:
159
+ # print(
160
+ # f"[bold chartreuse1 on grey7]> n: {n.metadata}[/bold chartreuse1 on grey7]\n")
161
+
162
+ """
163
+ # Plug into our RetrieverQueryEngine
164
+
165
+ We compose our SQL Retriever with our standard RetrieverQueryEngine
166
+ to synthesize a response. The result is roughly similar
167
+ to our packaged Text-to-SQL query engines.
168
+ """
169
+
170
+ query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)
171
+
172
+ response = query_engine.query(
173
+ "Return the top 5 cities (along with their populations) with the highest population."
174
  )
175
+
 
 
 
 
 
176
  print(
177
+ f"[bold chartreuse1 on grey7]> Response: {str(response)}[/bold chartreuse1 on grey7]\n")
178
+ """"
179
+ > Response: Tokyo - 13,960,000
180
+ Seoul - 9,776,000
181
+ Toronto - 2,930,000
182
+ Chicago - 2,679,000
183
+ """