| """ | |
| In this file, the input functions for query and support set molecules are defined. | |
| Input is assumed to be either a SMILES string, a list of SMILES strings, or a pandas | |
| dataframe. | |
| """ | |
| #--------------------------------------------------------------------------------------- | |
| # Dependencies | |
from typing import List, Tuple, Union

import pandas as pd
import torch

from src.data_preprocessing.create_descriptors import preprocess_molecules
| #--------------------------------------------------------------------------------------- | |
| # Define main functions | |
def create_query_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> torch.Tensor:
    """
    Create the input tensor for the query molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            DataFrame. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A float32 tensor of shape ``(rows, 1, cols)``. It is the 2-D array
        returned by ``preprocess_molecules`` with a singleton axis inserted at
        position 1. Rows are presumably molecules and cols descriptor
        features (TODO confirm against ``preprocess_molecules``).

    Raises:
        ValueError: If ``preprocess_molecules`` does not return a 2-D array.
    """
    # Create vector representation.
    descriptors = preprocess_molecules(smiles_input)
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``,
    # letting a wrong-rank array through to the model.
    if descriptors.ndim != 2:
        raise ValueError(
            "Expected a 2-D array from preprocess_molecules, got shape "
            f"{descriptors.shape}."
        )
    # Insert the singleton axis at dim 1 and cast to float32 for the model.
    return torch.from_numpy(descriptors).unsqueeze(1).float()
def create_support_set_input(
    smiles_input: Union[str, List[str], pd.DataFrame],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create the input tensors for the support-set molecules.

    Args:
        smiles_input: A SMILES string, a list of SMILES strings, or a pandas
            DataFrame. It is passed unchanged to ``preprocess_molecules``.

    Returns:
        A pair ``(tensor, size)``:
          - ``tensor``: float32 tensor of shape ``(1, rows, cols)``. It is the
            2-D array returned by ``preprocess_molecules`` with a leading
            singleton axis. Rows are presumably molecules (TODO confirm).
          - ``size``: 0-D integer tensor holding ``rows``, the first-axis
            length of that array.

    Raises:
        ValueError: If ``preprocess_molecules`` does not return a 2-D array.
    """
    # Create vector representation.
    descriptors = preprocess_molecules(smiles_input)
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``,
    # letting a wrong-rank array through to the model.
    if descriptors.ndim != 2:
        raise ValueError(
            "Expected a 2-D array from preprocess_molecules, got shape "
            f"{descriptors.shape}."
        )
    n_rows = descriptors.shape[0]
    # Add a leading singleton axis and cast to float32 for the model.
    tensor = torch.from_numpy(descriptors).unsqueeze(0).float()
    size = torch.tensor(n_rows)
    return tensor, size