import logging
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple

import mols2grid
import pandas as pd
# Module logger. A NullHandler is attached, following the library-logging
# convention from the Python logging HOWTO, so this module emits nothing unless
# the application configures logging itself.
logging.getLogger(__name__).addHandler(logging.NullHandler())
logger = logging.getLogger(__name__)
def draw_grid_generate(
    samples: Sequence[str],
    seeds: Optional[Sequence[str]] = None,
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """Draw an HTML grid of seed and generated molecules using mols2grid.

    Seed molecules (if any) come first, followed by the generated samples.
    Each molecule is labelled ``Seed_<i>`` or ``Generated_<i>``. Both the
    SMILES string and that label appear in the grid tooltip.

    Args:
        samples: SMILES strings of the generated molecules.
        seeds: SMILES strings of the seed molecules, shown before the samples.
            Defaults to no seeds.
        n_cols: Number of columns in the grid. Defaults to 3.
        size: Size of each molecule drawing, passed to ``mols2grid.display``.
            Defaults to (140, 200).
        height: Height of the grid, passed to ``mols2grid.display``.
            Defaults to 1100.

    Returns:
        The HTML of the grid, ready to display.

    Raises:
        TypeError: If ``samples`` or ``seeds`` is a single ``str`` rather than
            a sequence of SMILES strings.
    """
    # A bare string would otherwise be split into single characters by list().
    # The original code also raised TypeError here (list + str concatenation).
    for arg_name, arg in (("samples", samples), ("seeds", seeds)):
        if isinstance(arg, str):
            raise TypeError(
                f"{arg_name} must be a sequence of SMILES strings, not a str"
            )

    # Use None rather than a shared mutable [] default. Copy both inputs into
    # fresh lists so that any sequence type (e.g. tuple) is accepted.
    seeds = list(seeds) if seeds is not None else []
    samples = list(samples)

    result = {
        "SMILES": seeds + samples,
        "Name": [f"Seed_{i}" for i in range(len(seeds))]
        + [f"Generated_{i}" for i in range(len(samples))],
    }
    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result),  # show every column ("SMILES", "Name") on hover
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    # The ``-> str`` contract relies on ``.data`` holding the HTML markup of
    # the displayable object returned by mols2grid.
    return obj.data