Spaces:
Sleeping
Sleeping
added app
Browse files
app.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import requests
|
3 |
+
|
4 |
+
def generate_prompt(question, schema):
|
5 |
+
instruction = f"""Your task is to generate valid duckdb SQL to answer the following question{"" if (schema == "") else ", given a duckdb database schema."}"""
|
6 |
+
input_text = f"""
|
7 |
+
{schema}
|
8 |
+
|
9 |
+
Generate a SQL query that answers the question `{question}`.
|
10 |
+
"""
|
11 |
+
return f"""### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n### Response:\n"""
|
12 |
+
|
13 |
+
|
14 |
+
def generate_sql(question, schema):
|
15 |
+
prompt = generate_prompt(question, schema)
|
16 |
+
s = requests.Session()
|
17 |
+
api_base = "https://text-motherduck-cl7b-chat-fp16-4vycuix6qcp2.octoai.run/v1"
|
18 |
+
url = f"{api_base}/completions"
|
19 |
+
body = {
|
20 |
+
"model": "motherduck-sql-fp16",
|
21 |
+
"prompt": prompt,
|
22 |
+
"temperature": 0.1,
|
23 |
+
"max_tokens": 200,
|
24 |
+
"stop":'<s>'
|
25 |
+
}
|
26 |
+
|
27 |
+
with s.post(url, json=body) as resp:
|
28 |
+
return resp.json()["choices"][0]["text"]
|
29 |
+
|
30 |
+
st.title("DuckDB-NSQL-7B Demo")
|
31 |
+
|
32 |
+
expander = st.expander("Customize Schema (Optional)")
|
33 |
+
expander.text("Execute this query in your DuckDB database to get your current schema:")
|
34 |
+
expander.code("SELECT array_to_string(list(sql), '\\n') from duckdb_tables()", language="sql")
|
35 |
+
|
36 |
+
# Input field for text prompt
|
37 |
+
default_schema = 'CREATE TABLE hn.hacker_news(title VARCHAR, url VARCHAR, "text" VARCHAR, dead BOOLEAN, "by" VARCHAR, score BIGINT, "time" BIGINT, "timestamp" TIMESTAMP, "type" VARCHAR, id BIGINT, parent BIGINT, descendants BIGINT, ranking BIGINT, deleted BOOLEAN);\nCREATE TABLE nyc.rideshare(hvfhs_license_num VARCHAR, dispatching_base_num VARCHAR, originating_base_num VARCHAR, request_datetime TIMESTAMP, on_scene_datetime TIMESTAMP, pickup_datetime TIMESTAMP, dropoff_datetime TIMESTAMP, PULocationID BIGINT, DOLocationID BIGINT, trip_miles DOUBLE, trip_time BIGINT, base_passenger_fare DOUBLE, tolls DOUBLE, bcf DOUBLE, sales_tax DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE, tips DOUBLE, driver_pay DOUBLE, shared_request_flag VARCHAR, shared_match_flag VARCHAR, access_a_ride_flag VARCHAR, wav_request_flag VARCHAR, wav_match_flag VARCHAR);\nCREATE TABLE nyc.taxi(VendorID BIGINT, tpep_pickup_datetime TIMESTAMP, tpep_dropoff_datetime TIMESTAMP, passenger_count DOUBLE, trip_distance DOUBLE, RatecodeID DOUBLE, store_and_fwd_flag VARCHAR, PULocationID BIGINT, DOLocationID BIGINT, payment_type BIGINT, fare_amount DOUBLE, extra DOUBLE, mta_tax DOUBLE, tip_amount DOUBLE, tolls_amount DOUBLE, improvement_surcharge DOUBLE, total_amount DOUBLE, congestion_surcharge DOUBLE, airport_fee DOUBLE);\nCREATE TABLE nyc.service_requests(unique_key BIGINT, created_date TIMESTAMP, closed_date TIMESTAMP, agency VARCHAR, agency_name VARCHAR, complaint_type VARCHAR, descriptor VARCHAR, location_type VARCHAR, incident_zip VARCHAR, incident_address VARCHAR, street_name VARCHAR, cross_street_1 VARCHAR, cross_street_2 VARCHAR, intersection_street_1 VARCHAR, intersection_street_2 VARCHAR, address_type VARCHAR, city VARCHAR, landmark VARCHAR, facility_type VARCHAR, status VARCHAR, due_date TIMESTAMP, resolution_description VARCHAR, resolution_action_updated_date TIMESTAMP, community_board VARCHAR, bbl VARCHAR, borough VARCHAR, x_coordinate_state_plane VARCHAR, y_coordinate_state_plane VARCHAR, open_data_channel_type VARCHAR, park_facility_name VARCHAR, park_borough VARCHAR, vehicle_type VARCHAR, taxi_company_borough VARCHAR, taxi_pick_up_location VARCHAR, bridge_highway_name VARCHAR, bridge_highway_direction VARCHAR, road_ramp VARCHAR, bridge_highway_segment VARCHAR, latitude DOUBLE, longitude DOUBLE);\nCREATE TABLE who.ambient_air_quality(who_region VARCHAR, iso3 VARCHAR, country_name VARCHAR, city VARCHAR, "year" BIGINT, "version" VARCHAR, pm10_concentration BIGINT, pm25_concentration BIGINT, no2_concentration BIGINT, pm10_tempcov BIGINT, pm25_tempcov BIGINT, no2_tempcov BIGINT, type_of_stations VARCHAR, reference VARCHAR, web_link VARCHAR, population VARCHAR, population_source VARCHAR, latitude FLOAT, longitude FLOAT, who_ms BIGINT)'
|
38 |
+
schema = expander.text_input("Current schema:", value=default_schema)
|
39 |
+
|
40 |
+
# Input field for text prompt
|
41 |
+
text_prompt = st.text_input("What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv")
|
42 |
+
|
43 |
+
if text_prompt:
|
44 |
+
sql_query = generate_sql(text_prompt, schema)
|
45 |
+
st.code(sql_query, language="sql")
|
46 |
+
|