initial commit
Author: Sonja Topf
Commit: f484830
Parent(s): 101ce13
Files changed:
- Dockerfile +16 -0
- README.md +76 -6
- app.py +78 -0
- assets/best_gin_model.pt +3 -0
- predict.py +68 -0
- requirements.txt +9 -0
- src/model.py +53 -0
- src/preprocess.py +101 -0
- src/seed.py +19 -0
Dockerfile
ADDED
@@ -0,0 +1,16 @@
+# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+FROM python:3.11.4
+
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+
+WORKDIR /app
+
+COPY --chown=user ./requirements.txt requirements.txt
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+COPY --chown=user . /app
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,12 +1,82 @@
 ---
-title: Tox21
-emoji:
-colorFrom:
-colorTo:
+title: Tox21 GIN Classifier
+emoji: 🤖
+colorFrom: green
+colorTo: blue
 sdk: docker
 pinned: false
 license: apache-2.0
-short_description:
+short_description: Graph Isomorphism Network
 ---
 
-
+# Tox21 Graph Isomorphism Network Classifier
+
+This repository hosts a Hugging Face Space that provides an exemplary API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
+
+In this example, we trained a GIN classifier on the Tox21 targets and saved the trained model in the `assets/` folder.
+
+**Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary, with SMILES strings as keys and dictionaries of target-name/prediction pairs as values. Therefore, any preprocessing of SMILES strings must be executed on the fly during inference.
+
+# Repository Structure
+- `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+- `app.py` - FastAPI application wrapper (can be used as-is).
+
+- `src/` - Core model & preprocessing logic:
+  - `preprocess.py` - SMILES preprocessing pipeline
+  - `model.py` - GIN classifier
+  - `seed.py` - used to ensure reproducibility
+
+# Quickstart with Spaces
+
+You can easily adapt this project in your own Hugging Face account:
+
+- Open this Space on Hugging Face.
+
+- Click "Duplicate this Space" (top-right corner).
+
+- Modify `src/` for your preprocessing pipeline and model class.
+
+- Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+That’s it, your model will be available as an API endpoint for the Tox21 Leaderboard.
+
+# Installation
+To run the GIN classifier, clone the repository and install dependencies:
+
+```bash
+git clone https://huggingface.co/spaces/tschouis/tox21_gin_classifier
+cd tox21_gin_classifier
+pip install -r requirements.txt
+```
+
+
+# Inference
+
+For inference, you only need `predict.py`.
+
+Example usage inside Python:
+
+```python
+from predict import predict
+
+smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+results = predict(smiles_list)
+
+print(results)
+```
+
+The output will be a nested dictionary in the format (target names and values shown are placeholders):
+
+```python
+{
+    "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
+    "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
+    "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
+}
+```
+
+# Notes
+
+- Only adapting `predict.py` for your model inference is required for leaderboard submission.
+
+- Preprocessing (here inside `src/preprocess.py`) must be applied at inference time, not only during training.
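The README describes the return value of `predict()` as a nested dictionary keyed by SMILES string. A minimal sketch of a format check follows, assuming the twelve target names listed in `app.py` and numeric prediction values; the helper name `check_prediction_format` is hypothetical and not part of this commit or of the leaderboard.

```python
from typing import Dict, List

# Target names as listed under "tox_endpoints" in app.py.
TARGET_NAMES = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]


def check_prediction_format(
    smiles_list: List[str],
    predictions: Dict[str, Dict[str, float]],
    target_names: List[str] = TARGET_NAMES,
) -> List[str]:
    """Hypothetical helper: return a list of problems (empty if none found)."""
    problems = []
    for smi in smiles_list:
        if smi not in predictions:
            problems.append(f"missing SMILES key: {smi!r}")
            continue
        per_target = predictions[smi]
        if not isinstance(per_target, dict):
            problems.append(f"value for {smi!r} is not a dict")
            continue
        for name in target_names:
            if name not in per_target:
                problems.append(f"{smi!r} lacks target {name!r}")
        for name, value in per_target.items():
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                problems.append(f"{smi!r}/{name!r} is not numeric")
    return problems
```

Running such a check on the output of `predict(["CCO"])` before submitting could catch a missing target or a non-numeric value early.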
app.py
ADDED
@@ -0,0 +1,78 @@
+"""
+This is the main entry point for the FastAPI application.
+The app handles the request to predict toxicity for a list of SMILES strings.
+"""
+
+# ---------------------------------------------------------------------------------------
+# Dependencies and global variable definition
+import os
+from typing import List, Dict, Optional
+from fastapi import FastAPI, Header, HTTPException
+from pydantic import BaseModel, Field
+
+from predict import predict as predict_func
+
+API_KEY = os.getenv("API_KEY")  # set via Space Secrets
+
+
+# ---------------------------------------------------------------------------------------
+class Request(BaseModel):
+    smiles: List[str] = Field(min_items=1, max_items=1000)
+
+
+class Response(BaseModel):
+    predictions: dict
+    model_info: Dict[str, str] = {}
+
+
+app = FastAPI(title="toxicity-api")
+
+
+@app.get("/")
+def root():
+    return {
+        "message": "Toxicity Prediction API",
+        "endpoints": {
+            "/metadata": "GET - API metadata and capabilities",
+            "/healthz": "GET - Health check",
+            "/predict": "POST - Predict toxicity for SMILES",
+        },
+        "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+    }
+
+
+@app.get("/metadata")
+def metadata():
+    return {
+        "name": "AwesomeTox",
+        "version": "1.0.0",
+        "max_batch_size": 256,
+        "tox_endpoints": [
+            "NR-AR",
+            "NR-AR-LBD",
+            "NR-AhR",
+            "NR-Aromatase",
+            "NR-ER",
+            "NR-ER-LBD",
+            "NR-PPAR-gamma",
+            "SR-ARE",
+            "SR-ATAD5",
+            "SR-HSE",
+            "SR-MMP",
+            "SR-p53",
+        ],
+    }
+
+
+@app.get("/healthz")
+def healthz():
+    return {"ok": True}
+
+
+@app.post("/predict", response_model=Response)
+def predict(request: Request):
+    predictions = predict_func(request.smiles)
+    return {
+        "predictions": predictions,
+        "model_info": {"name": "random_clf", "version": "1.0.0"},
+    }
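`/metadata` advertises `max_batch_size: 256`, while the `Request` model accepts up to 1000 SMILES per call. A client that wants to stay within the advertised size could split its input before posting. The following is a rough sketch using only the standard library; the function name `batched_payloads` is hypothetical, and whether the leaderboard batches requests this way is not stated in the commit.

```python
import json
from typing import Iterator, List


def batched_payloads(smiles: List[str], max_batch_size: int = 256) -> Iterator[str]:
    """Yield JSON request bodies of the form {"smiles": [...]}.

    Each body holds at most max_batch_size SMILES strings, matching the
    shape that the /predict endpoint's Request model expects.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")
    for start in range(0, len(smiles), max_batch_size):
        chunk = smiles[start:start + max_batch_size]
        yield json.dumps({"smiles": chunk})
```

Each yielded string can be sent as the body of a `POST /predict` request with a JSON content type.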
assets/best_gin_model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e49790b36de8674646c3c3eb6b35818f9d7319a61dbc0a483c13c2f78bcb210
+size 634178
predict.py
ADDED
@@ -0,0 +1,68 @@
+from torch_geometric.data import Batch
+from torch_geometric.utils import from_rdmol
+import torch
+
+from src.model import GIN
+from src.preprocess import create_clean_mol_objects
+from src.seed import set_seed
+
+def predict(smiles_list):
+    """
+    Predict toxicity targets for a list of SMILES strings.
+
+    Args:
+        smiles_list (list[str]): SMILES strings
+
+    Returns:
+        dict: {smiles: {target_name: prediction_prob}}
+    """
+    set_seed(42)
+    # tox21 targets
+    TARGET_NAMES = [
+        "NR-AR",
+        "NR-AR-LBD",
+        "NR-AhR",
+        "NR-Aromatase",
+        "NR-ER",
+        "NR-ER-LBD",
+        "NR-PPAR-gamma",
+        "SR-ARE",
+        "SR-ATAD5",
+        "SR-HSE",
+        "SR-MMP",
+        "SR-p53",
+    ]
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Received {len(smiles_list)} SMILES strings")
+
+    # setup model
+    model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
+    model_path = "./assets/best_gin_model.pt"
+    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+    print(f"Loaded model from {model_path}")
+    model.to(DEVICE)
+    model.eval()
+    predictions = {}
+
+    for smiles in smiles_list:
+        try:
+            # Convert SMILES to graph
+            mol, _ = create_clean_mol_objects([smiles])
+            data = from_rdmol(mol[0]).to(DEVICE)
+            batch = Batch.from_data_list([data])
+
+            # Forward pass (from_rdmol yields integer node features; cast to float for nn.Linear)
+            with torch.no_grad():
+                logits = model(batch.x.float(), batch.edge_index, batch.batch)
+                probs = torch.sigmoid(logits).cpu().numpy().flatten()
+
+            # Map predictions to targets
+            pred_dict = {t: float(p) for t, p in zip(TARGET_NAMES, probs)}
+            predictions[smiles] = pred_dict
+
+        except Exception:
+            # If SMILES fails, return zeros
+            pred_dict = {t: 0.0 for t in TARGET_NAMES}
+            predictions[smiles] = pred_dict
+
+    return predictions
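The last steps of the prediction loop turn one logit per target into a probability with a sigmoid, pair each value with its target name, and fall back to zeros when a molecule cannot be processed. The sketch below mirrors that mapping in plain Python without torch; `to_prediction_dict` and its `None`-means-failure convention are illustrative assumptions, not functions from this commit.

```python
import math
from typing import Dict, Optional, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def to_prediction_dict(
    logits: Optional[Sequence[float]], target_names: Sequence[str]
) -> Dict[str, float]:
    """Map per-target logits to probabilities.

    Passing None stands in for a failed molecule and yields 0.0 for every
    target, like the except branch in predict.py.
    """
    if logits is None:
        return {t: 0.0 for t in target_names}
    if len(logits) != len(target_names):
        raise ValueError("expected one logit per target")
    return {t: sigmoid(z) for t, z in zip(target_names, logits)}
```

Note that a fallback of 0.0 is indistinguishable from a confident negative prediction, so logging the failure alongside it may be worth considering.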
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+fastapi
+uvicorn[standard]
+torch==2.3.0
+torch-geometric==2.6.1
+numpy==1.26.2
+pandas==2.2.2
+rdkit-pypi==2024.3.6
+pydantic
+typing-extensions
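pip pins an exact version with `==`; a single `=` between a package name and a version is not a valid specifier and makes `pip install -r requirements.txt` fail. A small check along the following lines could flag such lines before a Docker build. The helper name and the regular expression are illustrative assumptions and do not cover the full requirements-file grammar.

```python
import re
from typing import List

# A package name (optionally with extras) followed by a lone "=".
_SINGLE_EQUALS = re.compile(r"^[A-Za-z0-9._\[\]-]+\s*=(?!=)\s*\S+")


def find_invalid_pins(lines: List[str]) -> List[str]:
    """Return requirement lines that use "=" where "==" was likely meant."""
    bad = []
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if line and _SINGLE_EQUALS.match(line):
            bad.append(line)
    return bad
```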
src/model.py
ADDED
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
+import torch.nn.functional as F
+
+import numpy as np
+
+
+class GIN(torch.nn.Module):
+    def __init__(self, num_features, num_classes, dropout, hidden_dim=64, num_layers=5, add_or_mean="add"):
+        super().__init__()
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+        self.add_or_mean = add_or_mean
+        self.dropout = dropout
+
+        self.conv_layers = nn.ModuleList()
+
+        # input features → hidden_dim
+        mlp = nn.Sequential(
+            nn.Linear(num_features, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.BatchNorm1d(hidden_dim)
+        )
+        self.conv_layers.append(GINConv(mlp, train_eps=True))
+
+        # hidden GIN layers
+        for _ in range(num_layers - 1):
+            mlp = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.ReLU(),
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.BatchNorm1d(hidden_dim)
+            )
+            self.conv_layers.append(GINConv(mlp, train_eps=True))
+
+        # Final classifier (after pooling)
+        self.fc = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, x, edge_index, batch):
+        for conv in self.conv_layers:
+            x = conv(x, edge_index)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        # Pool to get graph-level representation
+        if self.add_or_mean == "mean":
+            x = global_mean_pool(x, batch)
+        elif self.add_or_mean == "add":
+            x = global_add_pool(x, batch)
+
+        x = F.dropout(x, p=0.5, training=self.training)
+        return self.fc(x)
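Each `GINConv` layer combines a node's own features with the sum of its neighbors' features, scaled as `(1 + eps) * x_v + sum(x_u)`, before applying its MLP; the graph-level vector is then obtained by add or mean pooling over nodes. The toy sketch below shows only the aggregation and pooling arithmetic on plain Python lists. It omits the MLP, batch norm, and dropout, and the function names are illustrative rather than PyTorch Geometric APIs.

```python
from typing import Dict, List

Vector = List[float]


def gin_aggregate(
    x: Dict[int, Vector], neighbors: Dict[int, List[int]], eps: float = 0.0
) -> Dict[int, Vector]:
    """Compute (1 + eps) * x_v + sum of neighbor features for every node v."""
    out = {}
    for v, feat in x.items():
        agg = [(1.0 + eps) * f for f in feat]
        for u in neighbors.get(v, []):
            agg = [a + b for a, b in zip(agg, x[u])]
        out[v] = agg
    return out


def pool(node_feats: Dict[int, Vector], mode: str = "mean") -> Vector:
    """Collapse node features into a single graph-level vector."""
    rows = list(node_feats.values())
    total = [sum(col) for col in zip(*rows)]
    if mode == "add":
        return total
    if mode == "mean":
        return [t / len(rows) for t in total]
    raise ValueError(f"unknown pooling mode: {mode!r}")
```

Unlike `GIN.forward`, which skips pooling silently when `add_or_mean` is neither `"add"` nor `"mean"`, this sketch raises on an unknown mode; that is a design choice of the sketch, not behavior of the committed model.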
src/preprocess.py
ADDED
@@ -0,0 +1,101 @@
+import numpy as np
+import torch
+import pandas as pd
+
+from rdkit import Chem
+from rdkit.Chem.MolStandardize import rdMolStandardize
+from rdkit import Chem
+from torch_geometric.data import InMemoryDataset
+from torch_geometric.utils import from_rdmol
+
+def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
+    """Create cleaned RDKit Mol objects from SMILES.
+    Returns (list of mols, mask of valid mols).
+    """
+    clean_mol_mask = []
+    mols = []
+
+    # Standardizer components
+    cleaner = rdMolStandardize.CleanupParameters()
+    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
+
+    for smi in smiles:
+        try:
+            mol = Chem.MolFromSmiles(smi)
+            if mol is None:
+                clean_mol_mask.append(False)
+                continue
+
+            # Cleanup and canonicalize
+            mol = rdMolStandardize.Cleanup(mol, cleaner)
+            mol = tautomer_enumerator.Canonicalize(mol)
+
+            # Recompute canonical SMILES & reload
+            can_smi = Chem.MolToSmiles(mol)
+            mol = Chem.MolFromSmiles(can_smi)
+
+            if mol is not None:
+                mols.append(mol)
+                clean_mol_mask.append(True)
+            else:
+                clean_mol_mask.append(False)
+
+        except Exception as e:
+            print(f"Failed to standardize {smi}: {e}")
+            clean_mol_mask.append(False)
+
+    return mols, np.array(clean_mol_mask, dtype=bool)
+
+
+class Tox21Dataset(InMemoryDataset):
+    def __init__(self, dataframe):
+        super().__init__()
+        data_list = []
+
+        # Clean molecules & filter dataframe
+        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
+        dataframe = dataframe[clean_mask].reset_index(drop=True)
+
+        # Now mols and dataframe are aligned, so we can zip
+        for mol, (_, row) in zip(mols, dataframe.iterrows()):
+            try:
+                data = from_rdmol(mol)
+
+                # Extract labels as a pandas Series
+                drop_cols = ["ID","smiles","inchikey","sdftitle","order","set","CVfold"]
+                labels = row.drop(drop_cols)
+
+                # Mask for valid labels
+                mask = ~labels.isna()
+
+                # Explicit numeric conversion, replaces NaN with 0.0 safely
+                labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values
+
+                # Convert to tensors
+                y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
+                m = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
+
+                data.y = y
+                data.mask = m
+
+                data_list.append(data)
+
+            except Exception as e:
+                print(f"Skipping molecule {row['smiles']} due to error: {e}")
+
+        # Collate into dataset
+        self.data, self.slices = self.collate(data_list)
+
+
+def get_graph_dataset(filepath:str):
+    """returns an InMemoryDataset that can be used in dataloaders
+
+    Args:
+        filepath (str): the filepath of the data csv
+
+    Returns:
+        Tox21Dataset: dataset for dataloaders
+    """
+    df = pd.read_csv(filepath)
+    dataset = Tox21Dataset(df)
+    return dataset
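`Tox21Dataset` stores, for every molecule, a label vector in which missing measurements are replaced by 0.0 and a boolean mask that records which labels were actually measured. The sketch below reproduces that bookkeeping on plain lists and shows one way the mask might be used when averaging a per-label loss. The training code is not part of this commit, so the masked average is an assumption, and both function names are hypothetical.

```python
import math
from typing import List, Optional, Tuple


def labels_and_mask(raw: List[Optional[float]]) -> Tuple[List[float], List[bool]]:
    """Replace missing labels (None or NaN) with 0.0 and record validity."""
    labels, mask = [], []
    for value in raw:
        missing = value is None or math.isnan(value)
        mask.append(not missing)
        labels.append(0.0 if missing else float(value))
    return labels, mask


def masked_mean(per_label_loss: List[float], mask: List[bool]) -> float:
    """Average only over labels that were measured (assumed usage of the mask)."""
    kept = [loss for loss, keep in zip(per_label_loss, mask) if keep]
    return sum(kept) / len(kept) if kept else 0.0
```

Without the mask, the 0.0 placeholders would be treated as true negatives, which is why the dataset keeps both tensors.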
src/seed.py
ADDED
@@ -0,0 +1,19 @@
+import random
+import torch
+import numpy as np
+import os
+
+def set_seed(seed: int = 42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)  # current GPU
+    torch.cuda.manual_seed_all(seed)  # all GPUs
+
+    # Ensure deterministic behavior
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    # For PyTorch >= 1.8
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    torch.use_deterministic_algorithms(True, warn_only=True)
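The idea behind `set_seed` is that a random number generator started from the same seed produces the same sequence, so repeated runs give the same results. A minimal illustration with the standard library only is sketched here; it uses a local `random.Random` instance instead of the global seeding that `set_seed` performs, and the function name is hypothetical.

```python
import random
from typing import List


def sample_with_seed(seed: int, n: int = 3) -> List[float]:
    """Draw n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)  # local generator; global state is untouched
    return [rng.random() for _ in range(n)]
```

Calling `sample_with_seed(42)` twice returns identical lists, while a different seed gives a different sequence.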