Spaces:
Paused
Paused
| import logging | |
| import os | |
| import time | |
| import docker | |
| import pytest | |
| from docker import DockerClient | |
| from pytest_docker.plugin import get_docker_ip | |
| from fastapi.testclient import TestClient | |
| from sqlalchemy import text, create_engine | |
# Module-level logger; the integration-test base classes in this file report
# Postgres readiness, retry warnings and setup failures through it.
log = logging.getLogger(__name__)
def get_fast_api_client():
    """Create and return a ``TestClient`` wrapping the app defined in ``main``.

    ``main`` is imported inside the function so the import happens only when
    this is called, i.e. after the caller has prepared the environment (for
    example ``DATABASE_URL``).

    NOTE(review): the client's ``with`` context is entered and then exited
    before the client is handed back, so the app's startup *and* shutdown
    handling have both already run on the returned object. Confirm this is
    intended.
    """
    from main import app

    with TestClient(app) as client:
        pass
    return client
class AbstractIntegrationTest:
    """Base class for API integration tests.

    Subclasses set ``BASE_PATH`` to the URL prefix of the endpoints under
    test and build request URLs with :meth:`create_url`. The
    ``setup_*``/``teardown_*`` methods are pytest's xunit-style hooks. They
    are no-ops here so subclasses can always chain to ``super()``. The
    class-level hooks are classmethods, so ``super().setup_class()`` works
    from a subclass classmethod.
    """

    # URL prefix of the endpoints under test; must be set by subclasses.
    BASE_PATH = None

    def create_url(self, path="", query_params=None):
        """Join ``BASE_PATH`` and ``path`` into a relative URL with an optional query.

        Segments are split on ``/``, stripped, and empty ones dropped. Leading,
        trailing and doubled slashes in either argument are therefore ignored.

        Args:
            path: Path below ``BASE_PATH``, e.g. ``"/signin"``.
            query_params: Optional mapping of query parameter names to values.
                Names and values are URL-encoded, so characters such as
                ``&``, ``=`` or spaces cannot corrupt the query string.

        Returns:
            ``"seg/seg/...[?key=value&...]"`` without a leading slash.

        Raises:
            Exception: If ``BASE_PATH`` has not been set.
        """
        from urllib.parse import urlencode

        if self.BASE_PATH is None:
            raise Exception("BASE_PATH is not set")

        url = "/".join(self._split_path(self.BASE_PATH) + self._split_path(path))
        if query_params:
            url += "?" + urlencode(query_params)
        return url

    @staticmethod
    def _split_path(path):
        """Split ``path`` on ``/`` into stripped, non-empty segments."""
        return [part.strip() for part in path.split("/") if part.strip()]

    @classmethod
    def setup_class(cls):
        """pytest hook run once before the tests of the class; no-op here."""
        pass

    def setup_method(self):
        """pytest hook run before each test method; no-op here."""
        pass

    @classmethod
    def teardown_class(cls):
        """pytest hook run once after the tests of the class; no-op here."""
        pass

    def teardown_method(self):
        """pytest hook run after each test method; no-op here."""
        pass
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base backed by a throw-away Postgres Docker container.

    ``setup_class`` does the following:

    * starts a ``postgres:16.2`` container;
    * exports ``DATABASE_URL`` pointing at it;
    * waits until it accepts connections;
    * creates ``cls.fast_api_client``.

    ``teardown_class`` force-removes the container. ``setup_method`` waits for
    the application's DB session to respond. ``teardown_method`` empties the
    application tables so each test starts from a clean database.
    """

    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port that the container's Postgres port 5432 is published on.
    DOCKER_HOST_PORT = 8081
    docker_client: DockerClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Build the SQLAlchemy URL of the test container.

        Args:
            env_vars_postgres: Must contain ``POSTGRES_USER``,
                ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.

        Returns:
            ``postgresql://<user>:<password>@<docker ip>:<DOCKER_HOST_PORT>/<db>``.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        port = cls.DOCKER_HOST_PORT
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

    @classmethod
    def _wait_for_postgres(cls, database_url: str, retries: int = 10) -> bool:
        """Try to connect to ``database_url`` up to ``retries`` times, 3 s apart.

        Returns:
            True as soon as a connection succeeds, False if every attempt failed.
        """
        # One engine for all attempts, always disposed so no pooled
        # connection outlives the check.
        engine = create_engine(database_url, pool_pre_ping=True)
        try:
            for _ in range(retries):
                try:
                    # NOTE(review): OPEN_WEBUI_DIR is unused. The import
                    # presumably exists for its side effect of loading
                    # open_webui.config after DATABASE_URL is set. Confirm.
                    from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                    with engine.connect():
                        log.info("postgres is ready!")
                        return True
                except Exception as e:
                    log.warning(e)
                    time.sleep(3)
            return False
        finally:
            engine.dispose()

    @classmethod
    def setup_class(cls):
        """Start the Postgres container and create ``cls.fast_api_client``.

        Any failure removes the container again and fails the test class via
        ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("0.0.0.0", cls.DOCKER_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            if not cls._wait_for_postgres(database_url):
                raise Exception("Could not connect to Postgres")

            # import must be after setting env!
            cls.fast_api_client = get_fast_api_client()
        except Exception as ex:
            log.exception(ex)
            cls.teardown_class()
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Wait until the app's DB session answers ``SELECT 1`` (10 tries, 3 s apart)."""
        from open_webui.internal.db import Session

        attempts = 10
        for _ in range(attempts):
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                return
            except Exception as e:
                # Reset the failed transaction so the next attempt can run.
                Session.rollback()
                log.warning(e)
                time.sleep(3)
        log.error("database did not respond after %d attempts", attempts)

    def setup_method(self):
        """Before each test, make sure the database is reachable."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container.

        Safe to call from a failed ``setup_class``. It does nothing if no Docker
        client was created, and it ignores an already-missing container.
        """
        super().teardown_class()
        client = getattr(cls, "docker_client", None)
        if client is None:
            # setup_class failed before a Docker client existed.
            return

        from docker.errors import NotFound

        try:
            client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
        except NotFound:
            log.warning(
                "container %s not found; nothing to remove", cls.DOCKER_CONTAINER_NAME
            )

    def teardown_method(self):
        """After each test, discard uncommitted work and empty the app tables."""
        from open_webui.internal.db import Session

        super().teardown_method()
        # Roll back (not commit) anything the test left pending. This also
        # clears a failed transaction, which would otherwise make commit()
        # raise and skip the truncation below.
        Session.rollback()
        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            '"user"',  # quoted: user is a reserved word in Postgres
        ]
        for table in tables:
            Session.execute(text(f"TRUNCATE TABLE {table}"))
        Session.commit()