diff --git a/.env.copy b/.env.copy index e83fac3..af58625 100644 --- a/.env.copy +++ b/.env.copy @@ -6,3 +6,7 @@ DB_USER=postgres DB_PASSWORD=postgres OPENAI_API_KEY=your_actual_openai_api_key_here + +QDRANT_URL=http://localhost:6333 +QDRANT_COLLECTION=northwind_rag +GROQ_API_KEY= diff --git a/app.py b/app.py index 9a1d285..5a61b8e 100644 --- a/app.py +++ b/app.py @@ -235,7 +235,8 @@ {"label": "Schema-Based Querying", "value": "schema"}, {"label": "RAG (Retrieval-Augmented Generation)", "value": "rag"}, {"label": "Visualize", "value": "visualize"}, - {"label": "Multi-Table Join", "value": "multitablejoin"} + {"label": "Multi-Table Join", "value": "multitablejoin"}, + {"label": "Simple RAG", "value": "simple_rag"} ], value="schema", size="sm", @@ -360,6 +361,8 @@ def toggle_modal(n1, n2, n3, is_open): return not is_open return is_open + + # Callback for chat functionality @app.callback( [Output("chat-messages", "children"), diff --git a/database/query_engine.py b/database/query_engine.py index 005418a..da50484 100644 --- a/database/query_engine.py +++ b/database/query_engine.py @@ -542,6 +542,49 @@ def validate_query(self, sql_query: str, context: Dict[str, Any]) -> Tuple[bool, logger.error(f"Security validation error: {e}") return False, "Blocked by guardrails due to internal validation error", sql_query +class SimpleRAGQueryEngine(QueryEngine): + """Simple RAG-based query generation using Qdrant and Groq""" + + def __init__(self, qdrant_config: Dict, groq_api_key: str): + from simple_rag.rag_logic import SQLAgent + self.sql_agent = SQLAgent(qdrant_config, groq_api_key) + logger.info("Initialized SimpleRAGQueryEngine") + + def get_name(self) -> str: + return "Simple RAG Querying" + + def generate_query(self, user_query: str, context: Dict[str, Any]) -> Tuple[bool, str]: + """Generate SQL query using RAG logic""" + try: + result = self.sql_agent.process_query(user_query) + if result['success']: + return True, result['sql_query'] + else: + return False, result['error'] + except Exception as e: + error_msg = f"Failed to generate query: {str(e)}" + logger.error(error_msg) + return False, error_msg + + def execute_query(self, sql_query: str) -> Tuple[bool, Any]: + """Execute the generated SQL query""" + try: + # _execute_sql_query returns (results, columns) tuple + results, columns = self.sql_agent._execute_sql_query(sql_query) + + # Convert to DataFrame for consistency with other engines + import pandas as pd + if results and columns: + df = pd.DataFrame(results, columns=columns) + return True, df + else: + return True, pd.DataFrame() # Empty DataFrame if no results + + except Exception as e: + error_msg = f"Failed to execute query: {str(e)}" + logger.error(error_msg) + return False, error_msg + class QueryEngineFactory: """Factory for creating query engines""" @@ -560,6 +603,8 @@ def create_query_engine(engine_type: str, config: Dict[str, Any]) -> QueryEngine if not api_key: raise ValueError("Groq API key required for visualization") return VisualizationQueryEngine(api_key) + elif engine_type == "simple_rag": + return SimpleRAGQueryEngine(config.get('qdrant_config'), config.get('groq_api_key')) else: raise ValueError(f"Unknown query engine type: {engine_type}") diff --git a/pyproject.toml b/pyproject.toml index dd7741a..2f57d1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,6 @@ dependencies = [ "seaborn>=0.13.0", "matplotlib>=3.8.0", "langchain-groq>=0.3.7", + "sentence-transformers>=5.1.0", + "qdrant-client>=1.12.1", ] diff --git a/simple_rag/Readme.md b/simple_rag/Readme.md new file mode 100644 index 0000000..46eb9ec --- /dev/null +++ b/simple_rag/Readme.md @@ -0,0 +1,47 @@ +# Demo Platform + +## Setup Instructions + +1. Clone the repository: + ``` + git clone https://github.com/pthom/northwind_psql.git + ``` + +2. Navigate into the cloned directory: + ``` + cd northwind_psql + ``` + +3. Run Docker Compose to start the services: + ``` + docker compose up + ``` + +4. Set up Qdrant using Docker: + ``` + docker run -p 6333:6333 -p 6334:6334 \ + -v "$(pwd)/qdrant_storage:/qdrant/storage:z" \ + qdrant/qdrant + ``` + +5. Navigate to the parent directory: + ``` + cd .. + ``` + + +6. Set up the `.env` file with the following content: + ``` + QDRANT_URL=http://localhost:6333 + QDRANT_COLLECTION=northwind_rag + GROQ_API_KEY= + ``` + +7. Run the embedding insertion script (first time only): + ``` + python embedd_insert.py + ``` + +## Sample Query + +- **Query:** List the top 5 customers by total order value. diff --git a/simple_rag/embedd_insert.py b/simple_rag/embedd_insert.py new file mode 100644 index 0000000..146e9f5 --- /dev/null +++ b/simple_rag/embedd_insert.py @@ -0,0 +1,65 @@ +import os +import numpy as np +from sqlalchemy import create_engine, inspect, text +from sentence_transformers import SentenceTransformer +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams, PointStruct + +PG_CONN = os.getenv("PG_CONN", "postgresql://postgres:postgres@localhost:55432/northwind") +QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333") +COLLECTION = "northwind_rag" + +# 1. Connect to Postgres +engine = create_engine(PG_CONN) +insp = inspect(engine) + +# 2. Build table cards +def build_table_cards(engine, sample_rows=3): + cards = [] + with engine.connect() as conn: + for table in insp.get_table_names(): + cols = insp.get_columns(table) + col_lines = [f"- {c['name']} ({c['type']})" for c in cols] + fks = insp.get_foreign_keys(table) + fk_lines = [f"- {fk['constrained_columns']} -> {fk['referred_table']}.{fk['referred_columns']}" for fk in fks] or ["- none"] + try: + rows = conn.execute(text(f"SELECT * FROM {table} LIMIT {sample_rows}")).fetchall() + except: + rows = [] + card = f"""Table: {table} +Columns: +{chr(10).join(col_lines)} +Foreign Keys: +{chr(10).join(fk_lines)} +Row samples ({len(rows)}): +{chr(10).join([str(r) for r in rows])} +""" + cards.append({"table": table, "text": card}) + return cards + +cards = build_table_cards(engine) + +# 3. Create embeddings +model = SentenceTransformer("all-MiniLM-L6-v2", device='cpu') +embs = model.encode([c["text"] for c in cards], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32) + +# 4. Push into Qdrant +client = QdrantClient(url=QDRANT_URL) + +client.recreate_collection( + collection_name=COLLECTION, + vectors_config=VectorParams(size=embs.shape[1], distance=Distance.COSINE) +) + +points = [ + PointStruct( + id=i, + vector=emb.tolist(), + payload={"table": c["table"], "text": c["text"], "type": "table_card"} + ) + for i, (c, emb) in enumerate(zip(cards, embs)) +] + +client.upsert(collection_name=COLLECTION, points=points) + +print(f"Uploaded {len(points)} table cards into Qdrant collection '{COLLECTION}'") diff --git a/simple_rag/rag_logic.py b/simple_rag/rag_logic.py new file mode 100644 index 0000000..070732a --- /dev/null +++ b/simple_rag/rag_logic.py @@ -0,0 +1,290 @@ +import psycopg2 +from sentence_transformers import SentenceTransformer +import pandas as pd +from typing import List, Dict, Tuple +import qdrant_client +from dotenv import load_dotenv +from langchain_groq import ChatGroq # Fixed import - use langchain_groq + +class SQLAgent: + def __init__(self, qdrant_config: Dict, groq_api_key: str): + """ + Initialize SQL Agent with Qdrant and GPT OSS 20B on Groq + """ + # Load environment variables + load_dotenv() + + self.qdrant_config = qdrant_config + self.groq_api_key = groq_api_key # Updated variable name + + # Configure Qdrant client + self.qdrant_client = qdrant_client.QdrantClient( + url="http://localhost:6333", + ) + self.collection_name = "northwind_rag" + + # Configure ChatGroq from langchain_groq + self.model = ChatGroq( + groq_api_key=groq_api_key, + model_name="openai/gpt-oss-120b" # Changed to GPT model + ) + + # Initialize sentence transformer for embeddings + self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu') + + # Store schema information + self.schema_info = [] + self.schema_embeddings = None + + # Initialize database connection and load schema + self._load_database_schema() + + def _get_db_connection(self): + """Create database connection""" + try: + conn = psycopg2.connect( + host="localhost", + database="northwind", + user="postgres", + password="postgres", + port=55432 + ) + return conn + except Exception as e: + raise Exception(f"Database connection failed: {str(e)}") + + def _load_database_schema(self): + """Load and embed database schema information""" + try: + conn = self._get_db_connection() + cursor = conn.cursor() + + # Get all tables and their columns + schema_query = """ + SELECT + t.table_name, + c.column_name, + c.data_type, + c.is_nullable, + tc.constraint_type + FROM information_schema.tables t + LEFT JOIN information_schema.columns c ON t.table_name = c.table_name + LEFT JOIN information_schema.key_column_usage kcu ON c.table_name = kcu.table_name + AND c.column_name = kcu.column_name + LEFT JOIN information_schema.table_constraints tc ON kcu.constraint_name = tc.constraint_name + WHERE t.table_schema = 'public' AND t.table_type = 'BASE TABLE' + ORDER BY t.table_name, c.ordinal_position; + """ + + cursor.execute(schema_query) + results = cursor.fetchall() + + # Organize schema information + tables_info = {} + for row in results: + table_name, column_name, data_type, is_nullable, constraint_type = row + + if table_name not in tables_info: + tables_info[table_name] = { + 'columns': [], + 'description': f"Table: {table_name}" + } + + if column_name: + column_info = f"{column_name} ({data_type})" + if constraint_type: + column_info += f" - {constraint_type}" + tables_info[table_name]['columns'].append(column_info) + + # Create schema descriptions for embedding + self.schema_info = [] + for table_name, info in tables_info.items(): + schema_text = f"Table: {table_name}\nColumns: {', '.join(info['columns'])}" + self.schema_info.append({ + 'table': table_name, + 'description': schema_text, + 'columns': info['columns'] + }) + + # Create embeddings for schema + schema_texts = [item['description'] for item in self.schema_info] + self.schema_embeddings = self.embedder.encode(schema_texts) + + cursor.close() + conn.close() + + except Exception as e: + raise Exception(f"Failed to load schema: {str(e)}") + + def _find_relevant_tables(self, query: str, top_k: int = 3) -> List[Dict]: + """Find most relevant tables based on query using Qdrant""" + query_embedding = self.embedder.encode([query])[0] + + # Search in Qdrant + search_results = self.qdrant_client.search( + collection_name=self.collection_name, + query_vector=query_embedding, + limit=top_k + ) + + relevant_tables = [] + for result in search_results: + table_info = result.payload + relevant_tables.append({ + 'table_info': { + 'description': table_info.get('text', 'No description available'), + 'table': table_info.get('table', 'Unknown table') + }, + 'similarity': result.score + }) + + return relevant_tables + + def _generate_sql_with_gpt(self, user_query: str, relevant_schema: str) -> str: + """Generate SQL query using ChatGroq""" + prompt = f""" + You are an expert SQL developer. Given the user question and database schema, generate a precise PostgreSQL query. + + Database Schema: + {relevant_schema} + + User Question: {user_query} + + Rules: + 1. Generate only valid PostgreSQL syntax + 2. Use appropriate JOINs when needed + 3. Include proper WHERE clauses for filtering + 4. Use LIMIT when appropriate + 5. Return only the SQL query without explanations + + SQL Query: + """ + + try: + response = self.model.invoke(prompt) # Use invoke method instead of generate + sql_query = response.content.strip() + + # Clean up the response to extract just the SQL + if sql_query.startswith("```sql"): + sql_query = sql_query[6:] + if sql_query.endswith("```"): + sql_query = sql_query[:-3] + + return sql_query.strip() + + except Exception as e: + raise Exception(f"Failed to generate SQL: {str(e)}") + + def _execute_sql_query(self, sql_query: str) -> Tuple[List, List]: + """Execute SQL query and return results""" + try: + conn = self._get_db_connection() + cursor = conn.cursor() + + cursor.execute(sql_query) + results = cursor.fetchall() + column_names = [desc[0] for desc in cursor.description] if cursor.description else [] + + cursor.close() + conn.close() + + return results, column_names + + except Exception as e: + # Return empty results on error to maintain tuple structure + print(f"SQL execution error: {str(e)}") + return [], [] + + def _format_results(self, results: List, columns: List) -> str: + """Format query results for display""" + if not results: + return "No results found." + + # Convert to pandas DataFrame for better formatting + df = pd.DataFrame(results, columns=columns) + + # Limit to first 10 rows for display + if len(df) > 10: + display_df = df.head(10) + result_text = f"Showing first 10 of {len(df)} results:\n\n" + else: + display_df = df + result_text = f"Found {len(df)} results:\n\n" + + result_text += display_df.to_string(index=False) + + return result_text + + def process_query(self, user_query: str) -> Dict: + """Main method to process user query using RAG""" + try: + # Step 1: Find relevant tables using RAG + relevant_tables = self._find_relevant_tables(user_query) + if not relevant_tables: + return { + 'success': False, + 'sql_query': None, + 'results': None, + 'relevant_tables': None, + 'error': "No relevant tables found for the query." + } + + # Step 2: Create schema context for LLM + schema_context = "\n\n".join([ + table['table_info']['description'] + for table in relevant_tables if table.get('table_info') + ]) + if not schema_context: + return { + 'success': False, + 'sql_query': None, + 'results': None, + 'relevant_tables': None, + 'error': "Failed to create schema context from relevant tables." + } + + # Step 3: Generate SQL using GPT OSS 20B + sql_query = self._generate_sql_with_gpt(user_query, schema_context) + if not sql_query: + return { + 'success': False, + 'sql_query': None, + 'results': None, + 'relevant_tables': None, + 'error': "Failed to generate SQL query." + } + + # Step 4: Execute SQL query + results, columns = self._execute_sql_query(sql_query) + if results is None or columns is None: + return { + 'success': False, + 'sql_query': sql_query, + 'results': None, + 'relevant_tables': [t['table_info']['table'] for t in relevant_tables], + 'error': "Failed to execute SQL query." + } + + # Step 5: Format results + formatted_results = self._format_results(results, columns) + + return { + 'success': True, + 'sql_query': sql_query, + 'results': formatted_results, + 'relevant_tables': [t['table_info']['table'] for t in relevant_tables], + 'error': None + } + + except Exception as e: + return { + 'success': False, + 'sql_query': None, + 'results': None, + 'relevant_tables': None, + 'error': str(e) + } + +def create_sql_agent(db_config: Dict, groq_api_key: str) -> SQLAgent: + """Factory function to create SQL Agent""" + return SQLAgent(db_config, groq_api_key) # Updated parameter name