Spaces:
Running
Running
| from torch_geometric.data import Batch | |
| from torch_geometric.utils import from_rdmol | |
| import torch | |
| from src.model import GIN | |
| from src.preprocess import create_clean_mol_objects | |
| from src.seed import set_seed | |
def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The model checkpoint is loaded from ``./checkpoints/model.pt`` on every
    call, and each SMILES string is scored individually (one graph per batch).

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. A SMILES string that
        cannot be converted or scored maps every target to the neutral
        fallback probability 0.5 (the failure is reported on stdout).
    """
    set_seed(0)

    # tox21 targets
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]
    # Neutral probability reported when a molecule cannot be scored.
    FALLBACK_PROB = 0.5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model
    model = GIN(
        num_features=9,
        num_classes=len(TARGET_NAMES),  # one output per tox21 target
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./checkpoints/model.pt"
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()

    predictions = {}
    for smiles in smiles_list:
        # The result is keyed by SMILES, so a repeated string would only
        # overwrite its own entry; skip the redundant forward pass.
        if smiles in predictions:
            continue
        try:
            # Convert SMILES to graph
            mol, _ = create_clean_mol_objects([smiles])
            data = from_rdmol(mol[0]).to(DEVICE)
            batch = Batch.from_data_list([data])
            # Forward pass
            with torch.no_grad():
                logits = model(batch.x, batch.edge_index, batch.batch)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
        except Exception as e:
            # Best-effort per molecule: one bad SMILES must not abort the whole
            # request, but the failure is reported instead of being hidden.
            print(f"Failed to predict for SMILES {smiles!r}: {e}")
            predictions[smiles] = {t: FALLBACK_PROB for t in TARGET_NAMES}
        else:
            # Map predictions to targets
            predictions[smiles] = {
                t: float(p) for t, p in zip(TARGET_NAMES, probs)
            }
    return predictions