File size: 3,816 Bytes
4c1b5d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4eb4266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c1b5d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Tests for the FastAPI application
"""

import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
import os

from app import app


@pytest.fixture
def client():
    """Provide a new ``TestClient`` wrapping the FastAPI ``app`` for each test.

    The client is returned directly rather than entered with ``with``, so the
    app's startup/shutdown (lifespan) handlers are not triggered here.
    """
    test_client = TestClient(app)
    return test_client


class TestRootEndpoint:
    """Tests for ``GET /``, the API's index route."""

    def test_root(self, client):
        resp = client.get("/")
        assert resp.status_code == 200
        body = resp.json()

        assert body["message"] == "Toxicity Prediction API"
        for section in ("endpoints", "usage"):
            assert section in body
        # The index must advertise every public route.
        for route in ("/metadata", "/healthz", "/predict"):
            assert route in body["endpoints"]


class TestMetadataEndpoint:
    """Tests for ``GET /metadata``."""

    # Field name -> exact value the endpoint is expected to report.
    EXPECTED_FIELDS = {
        "name": "AwesomeTox",
        "version": "1.0.0",
        "max_batch_size": 256,
        "tox_endpoints": ["mutagenicity", "hepatotoxicity"],
    }

    def test_metadata(self, client):
        resp = client.get("/metadata")
        assert resp.status_code == 200
        body = resp.json()

        for field, expected in self.EXPECTED_FIELDS.items():
            assert body[field] == expected


class TestHealthzEndpoint:
    """Tests for the ``GET /healthz`` health-check route."""

    def test_healthz(self, client):
        resp = client.get("/healthz")
        expected_body = {"ok": True}

        assert resp.status_code == 200
        assert resp.json() == expected_body


@patch.dict(os.environ, {"API_KEY": "test-key"})
class TestPredictEndpoint:
    """Tests for ``POST /predict``: bearer-token auth and request validation.

    The class-level ``patch.dict`` sets ``API_KEY=test-key`` in ``os.environ``
    for the duration of every ``test*`` method and restores it afterwards.
    """

    # Header carrying the token that matches the patched API_KEY.
    VALID_HEADERS = {"Authorization": "Bearer test-key"}

    def test_predict_with_valid_auth(self, client):
        payload = {"smiles": ["CCO"]}

        resp = client.post("/predict", json=payload, headers=self.VALID_HEADERS)
        assert resp.status_code == 200

        body = resp.json()
        assert "predictions" in body
        assert "model_info" in body
        model_info = body["model_info"]
        assert model_info["name"] == "random_clf"
        assert model_info["version"] == "1.0.0"

    def test_predict_without_auth(self, client):
        resp = client.post("/predict", json={"smiles": ["CCO"]})

        assert resp.status_code == 401
        assert resp.json()["detail"] == "Unauthorized"

    def test_predict_with_invalid_auth(self, client):
        bad_headers = {"Authorization": "Bearer wrong-key"}

        resp = client.post(
            "/predict", json={"smiles": ["CCO"]}, headers=bad_headers
        )
        assert resp.status_code == 401

    def test_predict_empty_smiles_list(self, client):
        # An empty batch must fail request validation (presumably a
        # min_items=1 constraint on the request model).
        resp = client.post(
            "/predict", json={"smiles": []}, headers=self.VALID_HEADERS
        )
        assert resp.status_code == 422

    def test_predict_too_many_smiles(self, client):
        # 1001 entries: one past the max_items=1000 the request model is
        # expected to enforce.
        # NOTE(review): /metadata advertises max_batch_size=256 — confirm
        # which limit /predict actually applies and test that boundary.
        oversized = {"smiles": ["CCO"] * 1001}

        resp = client.post("/predict", json=oversized, headers=self.VALID_HEADERS)
        assert resp.status_code == 422

    def test_predict_multiple_smiles(self, client):
        molecules = ["CCO", "CCN", "CCC"]

        resp = client.post(
            "/predict", json={"smiles": molecules}, headers=self.VALID_HEADERS
        )
        assert resp.status_code == 200

        predictions = resp.json()["predictions"]
        for molecule in molecules:
            # Predictions are keyed by the input SMILES string, with 12
            # values expected per molecule.
            assert molecule in predictions
            assert len(predictions[molecule]) == 12