"""Schema-aware database assistant demo built on a Hugging Face causal LM.

A schema description plus a fixed set of instructions is rendered into a
prompt, and the model is asked to answer database questions about it.
"""

import sys
from functools import lru_cache

model_name = "EleutherAI/gpt-neo-2.7B"  # Replace with a suitable model

# Example schema: table name -> {"columns": [...], "relations": str or None}.
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}

# Instruction paragraphs appended to every prompt. They are joined with
# newlines so each paragraph starts on its own line; plain adjacent string
# literals would run sentences together ("...(if any).Example: ...").
# NOTE(review): the schema carries no column types, and the last example
# mentions a customers table and an order total that the sample schema does
# not define -- consider aligning the wording with the real schema.
_INSTRUCTIONS = "\n".join((
    "Your task is to understand and explain the structure of the database schema, including tables, fields, and relationships between them. You should provide descriptions of fields, their types, and the relationships between tables. Additionally, if a query involves fetching data, you need to identify the appropriate table(s) and return the relevant data based on the query.",
    "Describing a Table: When asked about a table, you need to tell about the columns of that table, their data types, and the relationships with other tables (if any).",
    "Example: If a user asks about the product table, describe fields like product_id, name, price, category_id, and the relationship between product and category.",
    "Fetching Data: If a user asks to fetch specific data (e.g., products, orders), identify which table(s) store the data and provide the relevant records or values based on the conditions specified.",
    "Example: If a user asks for products in a specific price range, you should filter the product table based on the price column.",
    "Handling Relationships: Identify when multiple tables need to be joined or related based on foreign keys or other relationships. You should know how to join tables and fetch the right data from them.",
    "Example: If a user asks for products belonging to a specific category, you need to understand that category_id in the product table links to the category table, and use that relationship to fetch the products in the requested category.",
    "Answering Database-Related Queries: For any query regarding data in the database, determine the appropriate table(s) to query, join them if needed, and return the relevant data. Queries may involve conditions like filtering, ordering, or aggregation.",
    "Example: 'Show me all customers who have placed orders over $500.' This involves understanding the customer and order tables and filtering the orders where the total value is over $500.",
)) + "\n"


@lru_cache(maxsize=1)
def _load_model(name=model_name):
    """Load and cache the ``(tokenizer, model)`` pair for ``name``.

    ``transformers`` is imported here instead of at module level, so the
    prompt-building helpers can be imported without loading the library
    or downloading model weights.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    return tokenizer, model


def generate_context(schema):
    """Render ``schema`` and the assistant instructions into prompt text.

    Args:
        schema: Mapping of table name -> dict with a ``"columns"`` list of
            column names and an optional ``"relations"`` string (omitted
            from the output when missing or falsy).

    Returns:
        The prompt context: one ``Table:``/``Columns:`` (and optional
        ``Relations:``) stanza per table, each followed by a blank line,
        then the ``### Instructions ###`` section ending in a newline.
    """
    parts = ["### Database Schema ###\n\n"]
    for table, details in schema.items():
        parts.append(f"Table: {table}\nColumns: {', '.join(details['columns'])}\n")
        if details.get("relations"):
            parts.append(f"Relations: {details['relations']}\n")
        parts.append("\n")
    parts.append("### Instructions ###\n")
    parts.append(_INSTRUCTIONS)
    return "".join(parts)


# Generate dynamic context once; it only depends on the static schema.
context = generate_context(schema)


def answer_question(context, question, *, tokenizer=None, model=None,
                    max_new_tokens=512):
    """Ask the model for an SQL query or database-related answer.

    Args:
        context: Prompt preamble, typically from ``generate_context``.
        question: The user's question.
        tokenizer, model: Optional preloaded Hugging Face tokenizer/model.
            Either one left as ``None`` falls back to the cached default
            for ``model_name``.
        max_new_tokens: Maximum number of tokens to generate, excluding the
            prompt. The prompt length does not eat into this budget.

    Returns:
        Only the generated continuation, decoded without special tokens
        and stripped. The prompt itself is not echoed back.

    Raises:
        ValueError: If the prompt plus ``max_new_tokens`` exceeds the
            tokenizer's ``model_max_length``. Silently truncating would
            cut off the question at the end of the prompt.
    """
    if tokenizer is None or model is None:
        default_tokenizer, default_model = _load_model()
        tokenizer = default_tokenizer if tokenizer is None else tokenizer
        model = default_model if model is None else model

    prompt = f"{context}\n\nUser Question: {question}\nSQL Query or Answer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = len(inputs["input_ids"][0])

    limit = getattr(tokenizer, "model_max_length", None)
    if limit is not None and prompt_len + max_new_tokens > limit:
        raise ValueError(
            f"Prompt is {prompt_len} tokens; with max_new_tokens={max_new_tokens} "
            f"it exceeds the model limit of {limit} tokens. Shorten the schema "
            "or question, or lower max_new_tokens."
        )

    # Models such as GPT-Neo define no pad token; fall back to EOS so that
    # generate() has an explicit pad id.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    outputs = model.generate(
        **inputs,  # passes attention_mask along with input_ids
        max_new_tokens=max_new_tokens,
        num_beams=5,
        early_stopping=True,
        pad_token_id=pad_token_id,
    )
    # Decoder-only models return prompt + continuation; keep the continuation.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def main():
    """Load the model, then answer a fixed list of example questions."""
    tokenizer, model = _load_model()
    print("Database Assistant is ready. Ask your questions!")

    # Example questions; swap in input() for a truly interactive session.
    questions = [
        "Tell me about the products table, what kind of data it is storing?"
    ]
    for user_question in questions:
        print(f"Question: {user_question}")
        try:
            response = answer_question(
                context, user_question, tokenizer=tokenizer, model=model
            )
        except Exception as exc:  # top-level boundary: report and continue
            print(f"Error: {type(exc).__name__}: {exc}", file=sys.stderr)
            continue
        print("\nGenerated Response:\n", response, "\n")


if __name__ == "__main__":
    main()