"""Tests for the FastAPI application's HTTP endpoints."""

import os
from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient

from app import app

# Header sets reused by the /predict tests.
_VALID_AUTH = {"Authorization": "Bearer test-key"}
_INVALID_AUTH = {"Authorization": "Bearer wrong-key"}


def _post_predict(client, smiles, headers=None):
    """POST ``{"smiles": smiles}`` to /predict.

    The ``headers`` keyword is only forwarded when a value is supplied, so an
    unauthenticated call is issued exactly as ``client.post(url, json=...)``.
    """
    extra = {} if headers is None else {"headers": dict(headers)}
    return client.post("/predict", json={"smiles": smiles}, **extra)


@pytest.fixture
def client():
    """Provide a TestClient bound to the application under test."""
    test_client = TestClient(app)
    return test_client


class TestRootEndpoint:
    """Checks for ``GET /``."""

    def test_root(self, client):
        """The root document names the API and lists its endpoints."""
        resp = client.get("/")
        assert resp.status_code == 200

        body = resp.json()
        assert body["message"] == "Toxicity Prediction API"
        assert "endpoints" in body
        assert "usage" in body

        advertised = body["endpoints"]
        for path in ("/metadata", "/healthz", "/predict"):
            assert path in advertised


class TestMetadataEndpoint:
    """Checks for ``GET /metadata``."""

    def test_metadata(self, client):
        """The metadata document carries the expected fixed values."""
        resp = client.get("/metadata")
        assert resp.status_code == 200

        body = resp.json()
        expected_fields = (
            ("name", "AwesomeTox"),
            ("version", "1.0.0"),
            ("max_batch_size", 256),
            ("tox_endpoints", ["mutagenicity", "hepatotoxicity"]),
        )
        for field, expected in expected_fields:
            assert body[field] == expected


class TestHealthzEndpoint:
    """Checks for ``GET /healthz``."""

    def test_healthz(self, client):
        """The health probe answers 200 with ``{"ok": True}``."""
        resp = client.get("/healthz")
        assert resp.status_code == 200
        assert resp.json() == {"ok": True}


# Applied at class level: patch.dict wraps every method whose name starts with
# patch.TEST_PREFIX ("test"), so each test below runs with API_KEY set and the
# environment is restored afterwards.
@patch.dict(os.environ, {"API_KEY": "test-key"})
class TestPredictEndpoint:
    """Checks for ``POST /predict`` (authentication, validation, payload)."""

    def test_predict_with_valid_auth(self, client):
        """A correctly authenticated request returns predictions and model info."""
        resp = _post_predict(client, ["CCO"], _VALID_AUTH)
        assert resp.status_code == 200

        body = resp.json()
        assert "predictions" in body
        assert "model_info" in body

        model_info = body["model_info"]
        assert model_info["name"] == "random_clf"
        assert model_info["version"] == "1.0.0"

    def test_predict_without_auth(self, client):
        """A request with no Authorization header is rejected with 401."""
        resp = _post_predict(client, ["CCO"])
        assert resp.status_code == 401
        assert resp.json()["detail"] == "Unauthorized"

    def test_predict_with_invalid_auth(self, client):
        """A request with the wrong bearer token is rejected with 401."""
        resp = _post_predict(client, ["CCO"], _INVALID_AUTH)
        assert resp.status_code == 401

    def test_predict_empty_smiles_list(self, client):
        """An empty SMILES list is answered with 422."""
        resp = _post_predict(client, [], _VALID_AUTH)
        # Validation error due to min_items=1
        assert resp.status_code == 422

    def test_predict_too_many_smiles(self, client):
        """An oversized SMILES list is answered with 422."""
        oversized = ["CCO"] * 1001  # Exceeds max_items=1000
        resp = _post_predict(client, oversized, _VALID_AUTH)
        # Validation error due to max_items=1000
        assert resp.status_code == 422

    def test_predict_multiple_smiles(self, client):
        """Every submitted SMILES string gets an entry holding 12 values."""
        submitted = ["CCO", "CCN", "CCC"]
        resp = _post_predict(client, submitted, _VALID_AUTH)
        assert resp.status_code == 200

        predictions = resp.json()["predictions"]
        for molecule in submitted:
            assert molecule in predictions
            assert len(predictions[molecule]) == 12