Tschoui's picture
πŸ›
4eb4266
"""
Tests for the FastAPI application
"""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
import os
from app import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the FastAPI ``app`` for one test.

    The fixture has no explicit scope, so pytest builds a new client for each
    test function that requests it.

    NOTE(review): the client is not entered as a context manager. Per the
    FastAPI testing docs, that means any startup/shutdown (lifespan) handlers
    on ``app`` are not run. Confirm this is intended.
    """
    test_client = TestClient(app)
    return test_client
class TestRootEndpoint:
    """Tests for ``GET /``, the API's index document."""

    def test_root(self, client):
        """The index returns its title and advertises every public route."""
        resp = client.get("/")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["message"] == "Toxicity Prediction API"

        # Both top-level sections must be present before inspecting them.
        for section in ("endpoints", "usage"):
            assert section in payload

        # Each public route must be listed under "endpoints".
        for route in ("/metadata", "/healthz", "/predict"):
            assert route in payload["endpoints"]
class TestMetadataEndpoint:
    """Tests for ``GET /metadata``."""

    def test_metadata(self, client):
        """Metadata reports the expected model identity, batch limit and endpoints."""
        resp = client.get("/metadata")
        assert resp.status_code == 200

        meta = resp.json()
        # Checked in insertion order; a missing key raises KeyError, as before.
        expected = {
            "name": "AwesomeTox",
            "version": "1.0.0",
            "max_batch_size": 256,
            "tox_endpoints": ["mutagenicity", "hepatotoxicity"],
        }
        for field, value in expected.items():
            assert meta[field] == value
class TestHealthzEndpoint:
    """Tests for ``GET /healthz``."""

    def test_healthz(self, client):
        """The health check answers 200 with exactly ``{"ok": True}``."""
        resp = client.get("/healthz")
        assert resp.status_code == 200

        expected_body = {"ok": True}
        assert resp.json() == expected_body
@patch.dict(os.environ, {"API_KEY": "test-key"})
class TestPredictEndpoint:
    """Tests for ``POST /predict``: bearer-token auth and request validation.

    The class-level ``patch.dict`` sets ``API_KEY=test-key`` in ``os.environ``.
    It applies to every method whose name starts with ``test`` and is undone
    after each one. This replaces six identical per-method decorators.
    """

    # Bearer header whose token matches the patched API_KEY above.
    AUTH_HEADERS = {"Authorization": "Bearer test-key"}

    def test_predict_with_valid_auth(self, client):
        """A correctly authenticated single-SMILES request succeeds."""
        response = client.post(
            "/predict", json={"smiles": ["CCO"]}, headers=self.AUTH_HEADERS
        )
        assert response.status_code == 200
        result = response.json()
        assert "predictions" in result
        assert "model_info" in result
        assert result["model_info"]["name"] == "random_clf"
        assert result["model_info"]["version"] == "1.0.0"

    def test_predict_without_auth(self, client):
        """A request with no Authorization header is rejected with 401."""
        response = client.post("/predict", json={"smiles": ["CCO"]})
        assert response.status_code == 401
        assert response.json()["detail"] == "Unauthorized"

    def test_predict_with_invalid_auth(self, client):
        """A bearer token that does not match API_KEY is rejected with 401."""
        headers = {"Authorization": "Bearer wrong-key"}
        response = client.post("/predict", json={"smiles": ["CCO"]}, headers=headers)
        assert response.status_code == 401

    def test_predict_empty_smiles_list(self, client):
        """An empty ``smiles`` list fails request validation."""
        response = client.post(
            "/predict", json={"smiles": []}, headers=self.AUTH_HEADERS
        )
        assert response.status_code == 422  # Validation error due to min_items=1

    def test_predict_too_many_smiles(self, client):
        """An oversized ``smiles`` list fails request validation."""
        # NOTE(review): the original comment cites max_items=1000, but /metadata
        # reports max_batch_size == 256. 1001 exceeds both. Confirm which limit
        # the request model actually enforces.
        response = client.post(
            "/predict", json={"smiles": ["CCO"] * 1001}, headers=self.AUTH_HEADERS
        )
        assert response.status_code == 422

    def test_predict_multiple_smiles(self, client):
        """Every submitted SMILES gets its own entry in ``predictions``."""
        data = {"smiles": ["CCO", "CCN", "CCC"]}
        response = client.post("/predict", json=data, headers=self.AUTH_HEADERS)
        assert response.status_code == 200
        predictions = response.json()["predictions"]
        for smiles in data["smiles"]:
            assert smiles in predictions
            # NOTE(review): 12 values per molecule, while /metadata lists only
            # two tox_endpoints. Confirm what the 12 entries correspond to.
            assert len(predictions[smiles]) == 12