Draichi commited on
Commit
c8e83ce
·
unverified ·
1 Parent(s): c1981ee

feat(text-to-SQL.py): add part 2

Browse files
multi-agents-analysis/text_to_SQL.py CHANGED
@@ -9,14 +9,28 @@ from sqlalchemy import (
9
  insert,
10
  text
11
  )
12
- from llama_index.core import SQLDatabase
 
 
13
  from llama_index.core.query_engine import NLSQLTableQueryEngine
14
  from llama_index.llms.ollama import Ollama
15
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
16
  from llama_index.core import Settings
 
 
 
 
 
 
 
 
 
 
17
 
18
  llm = Ollama(model="phi3", request_timeout=360.0)
19
 
 
 
20
  Settings.embed_model = HuggingFaceEmbedding(
21
  model_name="BAAI/bge-small-en-v1.5"
22
  )
@@ -69,12 +83,39 @@ with engine.connect() as connection:
69
 
70
  # --------------------------------
71
  # Part 1: Text-to-SQL Query Engine
72
- query_engine = NLSQLTableQueryEngine(
73
- sql_database=sql_database, tables=["city_stats"], llm=llm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
  query_str = "Which city has the highest population?"
76
  response = query_engine.query(query_str)
77
-
78
- print("\n--------------------------\n")
79
- print(f"Question: {query_str}\n")
80
- print(f"Response: {response}")
 
 
 
9
  insert,
10
  text
11
  )
12
+ import logging
13
+ import sys
14
+ from llama_index.core import SQLDatabase, VectorStoreIndex
15
  from llama_index.core.query_engine import NLSQLTableQueryEngine
16
  from llama_index.llms.ollama import Ollama
17
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
18
  from llama_index.core import Settings
19
+ from llama_index.core.indices.struct_store.sql_query import (
20
+ SQLTableRetrieverQueryEngine,
21
+ )
22
+ from llama_index.core.objects import (
23
+ SQLTableNodeMapping,
24
+ ObjectIndex,
25
+ SQLTableSchema,
26
+ )
27
+ from rich import print
28
+
29
 
30
  llm = Ollama(model="phi3", request_timeout=360.0)
31
 
32
+ Settings.llm = llm
33
+
34
  Settings.embed_model = HuggingFaceEmbedding(
35
  model_name="BAAI/bge-small-en-v1.5"
36
  )
 
83
 
84
  # --------------------------------
85
  # Part 1: Text-to-SQL Query Engine
86
+ # --------------------------------
87
+ # Part 2: Query-Time Retrieval of Tables for Text-to-SQL
88
+
89
+ # set Logging to DEBUG for more detailed outputs
90
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
91
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
92
+
93
+ # manually set context text
94
+ city_stats_text = (
95
+ "This table gives information regarding the population and country of a"
96
+ " given city.\nThe user will query with codewords, where 'foo' corresponds"
97
+ " to population and 'bar' corresponds to city."
98
+ )
99
+
100
+ table_node_mapping = SQLTableNodeMapping(sql_database)
101
+ table_schema_objs = [
102
+ (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
103
+ ] # add a SQLTableSchema for each table
104
+
105
+ obj_index = ObjectIndex.from_objects(
106
+ table_schema_objs,
107
+ table_node_mapping,
108
+ VectorStoreIndex,
109
+ )
110
+
111
+ query_engine = SQLTableRetrieverQueryEngine(
112
+ sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)
113
  )
114
  query_str = "Which city has the highest population?"
115
  response = query_engine.query(query_str)
116
+ print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
117
+ print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
118
+ print(
119
+ f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
120
+ print(
121
+ f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")