feat(text-to-SQL.py): add part 2
Browse files
multi-agents-analysis/text_to_SQL.py
CHANGED
@@ -9,14 +9,28 @@ from sqlalchemy import (
|
|
9 |
insert,
|
10 |
text
|
11 |
)
|
12 |
-
|
|
|
|
|
13 |
from llama_index.core.query_engine import NLSQLTableQueryEngine
|
14 |
from llama_index.llms.ollama import Ollama
|
15 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
16 |
from llama_index.core import Settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
llm = Ollama(model="phi3", request_timeout=360.0)
|
19 |
|
|
|
|
|
20 |
Settings.embed_model = HuggingFaceEmbedding(
|
21 |
model_name="BAAI/bge-small-en-v1.5"
|
22 |
)
|
@@ -69,12 +83,39 @@ with engine.connect() as connection:
|
|
69 |
|
70 |
# --------------------------------
|
71 |
# Part 1: Text-to-SQL Query Engine
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
)
|
75 |
query_str = "Which city has the highest population?"
|
76 |
response = query_engine.query(query_str)
|
77 |
-
|
78 |
-
print("\n
|
79 |
-
print(
|
80 |
-
|
|
|
|
|
|
9 |
insert,
|
10 |
text
|
11 |
)
|
12 |
+
import logging
|
13 |
+
import sys
|
14 |
+
from llama_index.core import SQLDatabase, VectorStoreIndex
|
15 |
from llama_index.core.query_engine import NLSQLTableQueryEngine
|
16 |
from llama_index.llms.ollama import Ollama
|
17 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
18 |
from llama_index.core import Settings
|
19 |
+
from llama_index.core.indices.struct_store.sql_query import (
|
20 |
+
SQLTableRetrieverQueryEngine,
|
21 |
+
)
|
22 |
+
from llama_index.core.objects import (
|
23 |
+
SQLTableNodeMapping,
|
24 |
+
ObjectIndex,
|
25 |
+
SQLTableSchema,
|
26 |
+
)
|
27 |
+
from rich import print
|
28 |
+
|
29 |
|
30 |
llm = Ollama(model="phi3", request_timeout=360.0)
|
31 |
|
32 |
+
Settings.llm = llm
|
33 |
+
|
34 |
Settings.embed_model = HuggingFaceEmbedding(
|
35 |
model_name="BAAI/bge-small-en-v1.5"
|
36 |
)
|
|
|
83 |
|
84 |
# --------------------------------
|
85 |
# Part 1: Text-to-SQL Query Engine
|
86 |
+
# --------------------------------
|
87 |
+
# Part 2: Query-Time Retrieval of Tables for Text-to-SQL
|
88 |
+
|
89 |
+
# log at INFO level; change level to logging.DEBUG for more detailed output
|
90 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
91 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))  # NOTE(review): basicConfig above is also given stream=sys.stdout, so this second stdout handler may print every record twice - confirm intended
|
92 |
+
|
93 |
+
# manually set context text
|
94 |
+
city_stats_text = (
|
95 |
+
"This table gives information regarding the population and country of a"
|
96 |
+
" given city.\nThe user will query with codewords, where 'foo' corresponds"
|
97 |
+
" to population and 'bar' corresponds to city."
|
98 |
+
)
|
99 |
+
|
100 |
+
table_node_mapping = SQLTableNodeMapping(sql_database)
|
101 |
+
table_schema_objs = [
|
102 |
+
(SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
|
103 |
+
] # add a SQLTableSchema for each table
|
104 |
+
|
105 |
+
obj_index = ObjectIndex.from_objects(
|
106 |
+
table_schema_objs,
|
107 |
+
table_node_mapping,
|
108 |
+
VectorStoreIndex,
|
109 |
+
)
|
110 |
+
|
111 |
+
query_engine = SQLTableRetrieverQueryEngine(
|
112 |
+
sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)  # top-1 table retrieval; NOTE(review): llm is passed to as_retriever here, not to the query engine - confirm the retriever accepts/uses it
|
113 |
)
|
114 |
query_str = "Which city has the highest population?"
|
115 |
response = query_engine.query(query_str)
|
116 |
+
print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
|
117 |
+
print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
|
118 |
+
print(
|
119 |
+
f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
|
120 |
+
print(
|
121 |
+
f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")
|