Sonja Topf committed on
Commit
f484830
1 Parent(s): 101ce13

initial commit

Files changed (9)
  1. Dockerfile +16 -0
  2. README.md +76 -6
  3. app.py +78 -0
  4. assets/best_gin_model.pt +3 -0
  5. predict.py +68 -0
  6. requirements.txt +9 -0
  7. src/model.py +53 -0
  8. src/preprocess.py +101 -0
  9. src/seed.py +19 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.11.4
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,82 @@
---
- title: Tox21 Gin Classifier
- emoji: 💻
- colorFrom: pink
- colorTo: pink
+ title: Tox21 GIN Classifier
+ emoji: 🤖
+ colorFrom: green
+ colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
- short_description: GIN baseline for Tox21 dataset
+ short_description: Graph Isomorphism Network
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Tox21 Graph Isomorphism Network Classifier
+
+ This repository hosts a Hugging Face Space that provides an example API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
+
+ In this example, we trained a GIN classifier on the Tox21 targets and saved the trained model in the `assets/` folder.
+
+ **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary, with SMILES strings as keys and dictionaries of target-name/prediction pairs as values. Any preprocessing of SMILES strings must therefore be executed on the fly during inference.
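+
+ A minimal sketch of that skeleton (the target names and zero values below are placeholders; the real implementation lives in `predict.py`):
+
+ ```python
+ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
+     """Required contract: list of SMILES in, nested prediction dict out."""
+     predictions = {}
+     for smiles in smiles_list:
+         # Preprocess the SMILES string and run model inference here;
+         # the zeros are placeholders for per-target predictions.
+         predictions[smiles] = {"NR-AR": 0.0, "SR-p53": 0.0}
+     return predictions
+ ```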
+
+ # Repository Structure
+ - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+ - `app.py` - FastAPI application wrapper (can be used as-is).
+ - `src/` - Core model & preprocessing logic:
+   - `preprocess.py` - SMILES preprocessing pipeline
+   - `model.py` - GIN classifier
+   - `seed.py` - used to ensure reproducibility
+
+ # Quickstart with Spaces
+
+ You can easily adapt this project to your own Hugging Face account:
+
+ - Open this Space on Hugging Face.
+ - Click "Duplicate this Space" (top-right corner).
+ - Modify `src/` for your preprocessing pipeline and model class.
+ - Modify `predict()` inside `predict.py` to perform model inference, keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+ That’s it: your model will be available as an API endpoint for the Tox21 Leaderboard.
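+
+ Once the Space is running, you can smoke-test the endpoint yourself. A sketch using `requests` (not in `requirements.txt`); the URL is a placeholder for your own Space, and the Authorization header is only needed if you configured an `API_KEY` secret:
+
+ ```python
+ import requests
+
+ resp = requests.post(
+     "https://your-username-your-space.hf.space/predict",  # hypothetical Space URL
+     json={"smiles": ["CCO", "c1ccccc1"]},
+     headers={"Authorization": "Bearer <your-api-key>"},  # if an API key is set
+     timeout=60,
+ )
+ print(resp.json()["predictions"])
+ ```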
+
+ # Installation
+ To run the GIN classifier, clone the repository and install dependencies:
+
+ ```bash
+ git clone https://huggingface.co/spaces/tschouis/tox21_gin_classifier
+ cd tox21_gin_classifier
+ pip install -r requirements.txt
+ ```
+
+ # Inference
+
+ For inference, you only need `predict.py`.
+
+ Example usage inside Python:
+
+ ```python
+ from predict import predict
+
+ smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+ results = predict(smiles_list)
+
+ print(results)
+ ```
67
+
68
+ The output will be a nested dictionary in the format:
69
+
70
+ ```python
71
+ {
72
+ "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
73
+ "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
74
+ "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
75
+ }
76
+ ```
+
+ # Notes
+
+ - For leaderboard submission, you only need to adapt `predict.py` to run your model's inference.
+ - Preprocessing (here inside `src/preprocess.py`) must be applied on the fly at inference time; it cannot be precomputed ahead of the `predict()` call. A short sketch follows below.
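+
+ For reference, the on-the-fly conversion that `predict.py` performs at request time uses exactly these calls:
+
+ ```python
+ from torch_geometric.utils import from_rdmol
+ from src.preprocess import create_clean_mol_objects
+
+ # Standardize the SMILES string and build a PyG graph at request time
+ mols, _ = create_clean_mol_objects(["CCO"])
+ data = from_rdmol(mols[0])
+ ```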
app.py ADDED
@@ -0,0 +1,78 @@
+ """
+ This is the main entry point for the FastAPI application.
+ The app handles requests to predict toxicity for a list of SMILES strings.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies and global variable definition
+ import os
+ from typing import List, Dict, Optional
+ from fastapi import FastAPI, Header, HTTPException
+ from pydantic import BaseModel, Field
+
+ from predict import predict as predict_func
+
+ API_KEY = os.getenv("API_KEY")  # set via Space Secrets
+
+
+ # ---------------------------------------------------------------------------------------
+ class Request(BaseModel):
+     smiles: List[str] = Field(min_length=1, max_length=1000)  # list-length bounds (pydantic v2)
+
+
+ class Response(BaseModel):
+     predictions: dict
+     model_info: Dict[str, str] = {}
+
+
+ app = FastAPI(title="toxicity-api")
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Toxicity Prediction API",
+         "endpoints": {
+             "/metadata": "GET - API metadata and capabilities",
+             "/healthz": "GET - Health check",
+             "/predict": "POST - Predict toxicity for SMILES",
+         },
+         "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+     }
+
+
+ @app.get("/metadata")
+ def metadata():
+     return {
+         "name": "AwesomeTox",
+         "version": "1.0.0",
+         "max_batch_size": 256,
+         "tox_endpoints": [
+             "NR-AR",
+             "NR-AR-LBD",
+             "NR-AhR",
+             "NR-Aromatase",
+             "NR-ER",
+             "NR-ER-LBD",
+             "NR-PPAR-gamma",
+             "SR-ARE",
+             "SR-ATAD5",
+             "SR-HSE",
+             "SR-MMP",
+             "SR-p53",
+         ],
+     }
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"ok": True}
+
+
+ @app.post("/predict", response_model=Response)
+ def predict(request: Request, authorization: Optional[str] = Header(None)):
+     # Require a matching Bearer token whenever an API key is configured
+     if API_KEY and authorization != f"Bearer {API_KEY}":
+         raise HTTPException(status_code=401, detail="Invalid or missing API key")
+     predictions = predict_func(request.smiles)
+     return {
+         "predictions": predictions,
+         "model_info": {"name": "gin_classifier", "version": "1.0.0"},
+     }
assets/best_gin_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e49790b36de8674646c3c3eb6b35818f9d7319a61dbc0a483c13c2f78bcb210
+ size 634178
predict.py ADDED
@@ -0,0 +1,68 @@
+ from torch_geometric.data import Batch
+ from torch_geometric.utils import from_rdmol
+ import torch
+
+ from src.model import GIN
+ from src.preprocess import create_clean_mol_objects
+ from src.seed import set_seed
+
+
+ def predict(smiles_list):
+     """
+     Predict toxicity targets for a list of SMILES strings.
+
+     Args:
+         smiles_list (list[str]): SMILES strings
+
+     Returns:
+         dict: {smiles: {target_name: prediction_prob}}
+     """
+     set_seed(42)
+     # Tox21 targets
+     TARGET_NAMES = [
+         "NR-AR",
+         "NR-AR-LBD",
+         "NR-AhR",
+         "NR-Aromatase",
+         "NR-ER",
+         "NR-ER-LBD",
+         "NR-PPAR-gamma",
+         "SR-ARE",
+         "SR-ATAD5",
+         "SR-HSE",
+         "SR-MMP",
+         "SR-p53",
+     ]
+     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Received {len(smiles_list)} SMILES strings")
+
+     # Set up the model and load the trained weights
+     model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
+     model_path = "./assets/best_gin_model.pt"
+     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+     print(f"Loaded model from {model_path}")
+     model.to(DEVICE)
+     model.eval()
+     predictions = {}
+
+     for smiles in smiles_list:
+         try:
+             # Convert the SMILES string to a graph on the fly
+             mols, _ = create_clean_mol_objects([smiles])
+             data = from_rdmol(mols[0]).to(DEVICE)
+             batch = Batch.from_data_list([data])
+
+             # Forward pass; from_rdmol yields integer node features, so cast
+             # them to float before they reach the first linear layer
+             with torch.no_grad():
+                 logits = model(batch.x.float(), batch.edge_index, batch.batch)
+                 probs = torch.sigmoid(logits).cpu().numpy().flatten()
+
+             # Map predictions to targets
+             predictions[smiles] = {t: float(p) for t, p in zip(TARGET_NAMES, probs)}
+
+         except Exception as e:
+             # If a SMILES string cannot be processed, return zeros for all targets
+             print(f"Prediction failed for {smiles}: {e}")
+             predictions[smiles] = {t: 0.0 for t in TARGET_NAMES}
+
+     return predictions
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ uvicorn[standard]
+ torch==2.3.0
+ torch-geometric==2.6.1
+ numpy==1.26.2
+ pandas==2.2.2
+ rdkit==2024.3.6
+ pydantic
+ typing-extensions
src/model.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
+
+
+ class GIN(torch.nn.Module):
+     def __init__(self, num_features, num_classes, dropout, hidden_dim=64, num_layers=5, add_or_mean="add"):
+         super().__init__()
+         self.num_layers = num_layers
+         self.hidden_dim = hidden_dim
+         self.add_or_mean = add_or_mean
+         self.dropout = dropout
+
+         self.conv_layers = nn.ModuleList()
+
+         # input features → hidden_dim
+         mlp = nn.Sequential(
+             nn.Linear(num_features, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.BatchNorm1d(hidden_dim),
+         )
+         self.conv_layers.append(GINConv(mlp, train_eps=True))
+
+         # hidden GIN layers
+         for _ in range(num_layers - 1):
+             mlp = nn.Sequential(
+                 nn.Linear(hidden_dim, hidden_dim),
+                 nn.ReLU(),
+                 nn.Linear(hidden_dim, hidden_dim),
+                 nn.BatchNorm1d(hidden_dim),
+             )
+             self.conv_layers.append(GINConv(mlp, train_eps=True))
+
+         # Final classifier (after pooling)
+         self.fc = nn.Linear(hidden_dim, num_classes)
+
+     def forward(self, x, edge_index, batch):
+         for conv in self.conv_layers:
+             x = conv(x, edge_index)
+             x = F.relu(x)
+             x = F.dropout(x, p=self.dropout, training=self.training)
+         # Pool node embeddings to get a graph-level representation
+         if self.add_or_mean == "mean":
+             x = global_mean_pool(x, batch)
+         elif self.add_or_mean == "add":
+             x = global_add_pool(x, batch)
+
+         x = F.dropout(x, p=0.5, training=self.training)
+         return self.fc(x)
src/preprocess.py ADDED
@@ -0,0 +1,101 @@
+ import numpy as np
+ import torch
+ import pandas as pd
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+ from torch_geometric.data import InMemoryDataset
+ from torch_geometric.utils import from_rdmol
+
+
+ def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
+     """Create cleaned RDKit Mol objects from SMILES.
+
+     Returns (list of mols, mask of valid mols).
+     """
+     clean_mol_mask = []
+     mols = []
+
+     # Standardizer components
+     cleaner = rdMolStandardize.CleanupParameters()
+     tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
+
+     for smi in smiles:
+         try:
+             mol = Chem.MolFromSmiles(smi)
+             if mol is None:
+                 clean_mol_mask.append(False)
+                 continue
+
+             # Cleanup and canonicalize
+             mol = rdMolStandardize.Cleanup(mol, cleaner)
+             mol = tautomer_enumerator.Canonicalize(mol)
+
+             # Recompute canonical SMILES & reload
+             can_smi = Chem.MolToSmiles(mol)
+             mol = Chem.MolFromSmiles(can_smi)
+
+             if mol is not None:
+                 mols.append(mol)
+                 clean_mol_mask.append(True)
+             else:
+                 clean_mol_mask.append(False)
+
+         except Exception as e:
+             print(f"Failed to standardize {smi}: {e}")
+             clean_mol_mask.append(False)
+
+     return mols, np.array(clean_mol_mask, dtype=bool)
+
+
+ class Tox21Dataset(InMemoryDataset):
+     def __init__(self, dataframe):
+         super().__init__()
+         data_list = []
+
+         # Clean molecules & filter dataframe
+         mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
+         dataframe = dataframe[clean_mask].reset_index(drop=True)
+
+         # Now mols and dataframe are aligned, so we can zip
+         for mol, (_, row) in zip(mols, dataframe.iterrows()):
+             try:
+                 data = from_rdmol(mol)
+
+                 # Extract labels as a pandas Series
+                 drop_cols = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]
+                 labels = row.drop(drop_cols)
+
+                 # Mask for valid labels
+                 mask = ~labels.isna()
+
+                 # Explicit numeric conversion, replaces NaN with 0.0 safely
+                 labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values
+
+                 # Convert to tensors
+                 y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
+                 m = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
+
+                 data.y = y
+                 data.mask = m
+
+                 data_list.append(data)
+
+             except Exception as e:
+                 print(f"Skipping molecule {row['smiles']} due to error: {e}")
+
+         # Collate into dataset
+         self.data, self.slices = self.collate(data_list)
+
+
+ def get_graph_dataset(filepath: str):
+     """Returns an InMemoryDataset that can be used in dataloaders.
+
+     Args:
+         filepath (str): the filepath of the data csv
+
+     Returns:
+         Tox21Dataset: dataset for dataloaders
+     """
+     df = pd.read_csv(filepath)
+     dataset = Tox21Dataset(df)
+     return dataset
src/seed.py ADDED
@@ -0,0 +1,19 @@
+ import os
+ import random
+
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed: int = 42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)  # current GPU
+     torch.cuda.manual_seed_all(seed)  # all GPUs
+
+     # Ensure deterministic behavior
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     # For PyTorch >= 1.8
+     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+     torch.use_deterministic_algorithms(True, warn_only=True)