Spaces:
Running
Running
| # Standard Library Imports | |
| from typing import Callable, Optional, TypeAlias, Union | |
| # Third-Party Library Imports | |
| from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine | |
| from sqlalchemy.orm import DeclarativeBase | |
| # Local Application Imports | |
| from src.common import Config, logger | |
| # Define the SQLAlchemy Base | |
| class Base(DeclarativeBase): | |
| pass | |
| class DummyAsyncSession: | |
| is_dummy = True # Flag to indicate this is a dummy session. | |
| async def __enter__(self): | |
| return self | |
| async def __exit__(self, exc_type, exc_value, traceback): | |
| pass | |
| async def add(self, _instance, _warn=True): | |
| # No-op: simply ignore adding the instance. | |
| pass | |
| async def commit(self): | |
| # Raise an exception to simulate failure when attempting a write. | |
| raise RuntimeError("DummyAsyncSession does not support commit operations.") | |
| async def refresh(self, _instance): | |
| # Raise an exception to simulate failure when attempting to refresh. | |
| raise RuntimeError("DummyAsyncSession does not support refresh operations.") | |
| async def rollback(self): | |
| # No-op: there's nothing to roll back. | |
| pass | |
| async def close(self): | |
| # No-op: nothing to close. | |
| pass | |
| AsyncDBSessionMaker: TypeAlias = Union[async_sessionmaker[AsyncSession], Callable[[], DummyAsyncSession]] | |
| engine: Optional[AsyncEngine] = None | |
| def init_db(config: Config) -> AsyncDBSessionMaker: | |
| """ | |
| Initialize the database engine and return a session factory based on the provided configuration. | |
| In production, a valid DATABASE_URL is required. In development, if a DATABASE_URL is provided, | |
| it is used; otherwise, a dummy session factory is returned to allow graceful failure. | |
| Args: | |
| config (Config): The application configuration. | |
| Returns: | |
| AsyncDBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory. | |
| """ | |
| # ruff doesn't like setting global variables, but this is practical here | |
| global engine # noqa | |
| if config.app_env == "prod": | |
| # In production, a valid DATABASE_URL is required. | |
| if not config.database_url: | |
| raise ValueError("DATABASE_URL must be set in production!") | |
| async_db_url = config.database_url | |
| engine = create_async_engine(async_db_url) | |
| return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession) | |
| # In development, if a DATABASE_URL is provided, use it. | |
| if config.database_url: | |
| async_db_url = config.database_url | |
| engine = create_async_engine(async_db_url) | |
| return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession) | |
| # No DATABASE_URL is provided; use a DummyAsyncSession that does nothing. | |
| engine = None | |
| logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession") | |
| def async_dummy_session_factory() -> DummyAsyncSession: | |
| return DummyAsyncSession() | |
| return async_dummy_session_factory | |