"""FastAPI service that turns natural-language text into SQL and executes it.

A POST to ``/generate_sql/`` feeds the request text to a text-to-SQL language
model, runs the generated statement against a local SQLite database, and
returns both the statement and whatever the database produced.
"""
import logging
import sqlite3
from contextlib import closing
from typing import Any, List, Tuple, Union

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)

app = FastAPI()

# Load fine-tuned text-to-SQL model
MODEL_NAME = "budecosystem/sql-millennials-13b"
# SQLite file the generated statements are executed against.
DB_PATH = "./ecommerce.db"
# Upper bound on generated tokens. Without an explicit limit, generate() falls
# back to its default length budget, which can truncate a statement mid-way.
MAX_NEW_TOKENS = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Loaded as a causal LM; an earlier revision used AutoModelForSeq2SeqLM here,
# which is why that class is still imported.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


def generate_sql(query: str) -> str:
    """Generate a SQL statement for a natural-language request.

    Args:
        query: Text passed to the tokenizer as the model prompt.

    Returns:
        The decoded model output with special tokens removed and surrounding
        whitespace stripped. For a decoder-only model only the newly generated
        tokens are decoded, not the echoed prompt.
    """
    logger.debug("Prompt: %s", query)
    inputs = tokenizer(query, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    logger.debug("Raw generation output: %s", outputs)

    generated_ids = outputs[0]
    # A decoder-only model returns prompt tokens followed by the continuation;
    # drop the prompt so the user's text is not sent to SQLite as part of the
    # statement. Encoder-decoder outputs contain no prompt and are left as-is.
    if not getattr(model.config, "is_encoder_decoder", False):
        prompt_length = inputs["input_ids"].shape[-1]
        generated_ids = generated_ids[prompt_length:]

    sql_query = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    logger.info("Generated SQL: %s", sql_query)
    return sql_query


def execute_sql(sql_query: str) -> Union[List[Tuple[Any, ...]], str]:
    """Execute a SQL statement against the SQLite database at ``DB_PATH``.

    NOTE(security): the statement comes from a language model driven by
    untrusted request text and is executed and committed as-is, so it can
    modify or delete data. Consider a read-only connection or a statement
    allow-list before exposing this service.

    Args:
        sql_query: The SQL text to execute.

    Returns:
        The rows from ``fetchall()`` on success, or the exception message as a
        string when execution fails.
    """
    # closing() guarantees the connection is released on every exit path.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
            conn.commit()
        except Exception as e:
            # Deliberate best-effort boundary: the failure is reported to the
            # caller as text instead of surfacing as a server error.
            logger.warning("SQL execution failed: %s", e)
            result = str(e)
    return result


class QueryRequest(BaseModel):
    """Request body for the ``/generate_sql/`` endpoint."""

    # Natural-language request to translate into SQL.
    text: str


@app.post("/generate_sql/")
def get_sql(query: QueryRequest):
    """Translate the request text to SQL, execute it, and return both.

    Args:
        query: Request body carrying the natural-language text.

    Returns:
        A dict with ``"sql"`` (the generated statement) and ``"result"`` (the
        fetched rows, or an error message string if execution failed).
    """
    sql_query = generate_sql(query.text)
    result = execute_sql(sql_query)
    return {"sql": sql_query, "result": result}


if __name__ == "__main__":
    # NOTE(review): binds to all interfaces; restrict the host if the service
    # should not be reachable from outside the machine.
    uvicorn.run(app, host="0.0.0.0", port=7860)