|
6 | 6 |
|
7 | 7 | from api.utils.settings import BASE_DIR, settings |
8 | 8 |
|
9 | | -DB_HOST = settings.DB_HOST |
10 | | -DB_PORT = settings.DB_PORT |
11 | | -DB_USER = settings.DB_USER |
12 | | -DB_PASSWORD = settings.DB_PASSWORD |
13 | | -DB_NAME = settings.DB_NAME |
14 | | -DB_TYPE = settings.DB_TYPE |
15 | | - |
16 | 9 |
|
17 | 10 | def get_db_engine(test_mode: bool = False): # type: ignore |
18 | | - if DB_TYPE == "sqlite" or test_mode: |
| 11 | + """Create a SQLAlchemy engine instance. |
| 12 | +
|
| 13 | + Args: |
| 14 | + test_mode (bool): If True, use a SQLite database for testing purposes. |
| 15 | +
|
| 16 | + Returns: |
| 17 | + Engine: SQLAlchemy engine instance. |
| 18 | + """ |
| 19 | + if settings.DB_TYPE == "sqlite" or test_mode: |
| 20 | + # If the database type is SQLite or we are in test mode, use a SQLite database |
19 | 21 | BASE_PATH = f"sqlite:///{BASE_DIR}" |
20 | 22 | DATABASE_URL = BASE_PATH + ("/test.db" if test_mode else "/db.sqlite3") |
21 | | - |
22 | | - return create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) |
23 | | - elif DB_TYPE == "mysql": |
24 | | - DATABASE_URL = ( |
25 | | - f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" |
26 | | - ) |
27 | | - elif DB_TYPE == "postgresql": |
28 | | - DATABASE_URL = ( |
29 | | - f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" |
30 | | - ) |
31 | 23 | else: |
32 | | - raise ValueError(f"Unsupported DB_TYPE: {DB_TYPE}") |
| 24 | + # For other database types, use the provided DATABASE_URL |
| 25 | + DATABASE_URL = settings.DB_URL |
33 | 26 |
|
34 | | - return create_engine(DATABASE_URL) |
| 27 | + return create_engine( |
| 28 | + DATABASE_URL, |
| 29 | + connect_args={"check_same_thread": False} |
| 30 | + if settings.DB_TYPE == "sqlite" |
| 31 | + else {}, |
| 32 | + ) |
35 | 33 |
|
36 | 34 |
|
| 35 | +# Create an engine instance |
37 | 36 | engine = get_db_engine() |
38 | 37 |
|
| 38 | +# Create a configured "Session" class |
39 | 39 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
40 | 40 |
|
| 41 | +# Create a configured "scoped_session" factory |
41 | 42 | db_session = scoped_session(SessionLocal) |
42 | 43 |
|
| 44 | +# Create a base class for declarative class definitions |
43 | 45 | Base = declarative_base() |
44 | 46 |
|
45 | 47 |
|
46 | 48 | def create_database(): |
| 49 | + """Create all tables in the database.""" |
47 | 50 | return Base.metadata.create_all(bind=engine) |
48 | 51 |
|
49 | 52 |
|
50 | 53 | def get_db(): |
| 54 | + """Yield a database session. |
| 55 | +
|
| 56 | + This function provides a context manager for database sessions, ensuring that |
| 57 | + each session is properly closed after use. |
| 58 | +
|
| 59 | + Yields: |
| 60 | + Session: SQLAlchemy session instance. |
| 61 | + """ |
51 | 62 | db = db_session() |
52 | 63 | try: |
53 | 64 | yield db |
|
0 commit comments