Spaces:
Sleeping
Sleeping
| """ | |
| Tests for the FastAPI application | |
| """ | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from unittest.mock import patch | |
| import os | |
| from app import app | |
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the application under test.

    Registered as a pytest fixture so that test methods declaring a
    ``client`` parameter receive it by injection; without the decorator
    pytest reports "fixture 'client' not found" for every such test.

    Returns:
        TestClient: a new client wrapping ``app``.
    """
    return TestClient(app)
class TestRootEndpoint:
    """Checks for the landing route ``GET /``."""

    def test_root(self, client):
        """The root route answers 200 and lists the API's name, usage and routes."""
        resp = client.get("/")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["message"] == "Toxicity Prediction API"

        # Top-level sections that must be present in the payload.
        for section in ("endpoints", "usage"):
            assert section in payload

        # Every route exercised by this test module must be advertised.
        advertised = payload["endpoints"]
        for route in ("/metadata", "/healthz", "/predict"):
            assert route in advertised
class TestMetadataEndpoint:
    """Checks for ``GET /metadata``."""

    def test_metadata(self, client):
        """The metadata route answers 200 with the expected model description."""
        resp = client.get("/metadata")
        assert resp.status_code == 200

        meta = resp.json()
        # Field -> value pairs the response is required to contain, checked
        # one at a time in this order.
        expected_fields = {
            "name": "AwesomeTox",
            "version": "1.0.0",
            "max_batch_size": 256,
            "tox_endpoints": ["mutagenicity", "hepatotoxicity"],
        }
        for field, wanted in expected_fields.items():
            assert meta[field] == wanted
class TestHealthzEndpoint:
    """Checks for the health probe ``GET /healthz``."""

    def test_healthz(self, client):
        """The health probe answers 200 with exactly ``{"ok": true}``."""
        resp = client.get("/healthz")
        assert resp.status_code == 200

        healthy_body = {"ok": True}
        assert resp.json() == healthy_body
class TestPredictEndpoint:
    """Checks for ``POST /predict``: authentication, input validation, output shape.

    NOTE(review): these tests expect the bearer token ``test-key`` to be
    accepted, but nothing visible in this class configures it — presumably
    the app or the test environment sets the key; confirm.
    """

    def test_predict_with_valid_auth(self, client):
        """A request with the ``test-key`` token returns predictions and model info."""
        resp = client.post(
            "/predict",
            json={"smiles": ["CCO"]},
            headers={"Authorization": "Bearer test-key"},
        )
        assert resp.status_code == 200

        body = resp.json()
        for section in ("predictions", "model_info"):
            assert section in body
        assert body["model_info"]["name"] == "random_clf"
        assert body["model_info"]["version"] == "1.0.0"

    def test_predict_without_auth(self, client):
        """A request with no Authorization header is rejected with 401."""
        resp = client.post("/predict", json={"smiles": ["CCO"]})
        assert resp.status_code == 401
        assert resp.json()["detail"] == "Unauthorized"

    def test_predict_with_invalid_auth(self, client):
        """A request with an unexpected bearer token is rejected with 401."""
        resp = client.post(
            "/predict",
            json={"smiles": ["CCO"]},
            headers={"Authorization": "Bearer wrong-key"},
        )
        assert resp.status_code == 401

    def test_predict_empty_smiles_list(self, client):
        """An empty SMILES list is rejected with 422."""
        resp = client.post(
            "/predict",
            json={"smiles": []},
            headers={"Authorization": "Bearer test-key"},
        )
        # Presumably the request schema enforces min_items=1 — confirm in the app.
        assert resp.status_code == 422

    def test_predict_too_many_smiles(self, client):
        """A list of 1001 SMILES strings is rejected with 422."""
        # Presumably the request schema enforces max_items=1000 — confirm in
        # the app. NOTE(review): the metadata test expects max_batch_size 256.
        oversized = ["CCO"] * 1001
        resp = client.post(
            "/predict",
            json={"smiles": oversized},
            headers={"Authorization": "Bearer test-key"},
        )
        assert resp.status_code == 422

    def test_predict_multiple_smiles(self, client):
        """Each submitted SMILES string gets its own entry holding 12 items."""
        molecules = ["CCO", "CCN", "CCC"]
        resp = client.post(
            "/predict",
            json={"smiles": molecules},
            headers={"Authorization": "Bearer test-key"},
        )
        assert resp.status_code == 200

        predictions = resp.json()["predictions"]
        for molecule in molecules:
            assert molecule in predictions
            assert len(predictions[molecule]) == 12