import os
import random
import math
import time
import datetime
import json

import numpy as np
import torch
# Per-format generation settings: human-readable name, tokenizer file, and the
# maximum token length allowed when decoding that format.
FORMAT_INFO = {
    "inchi": {
        "name": "InChI_text",
        "tokenizer": "tokenizer_inchi.json",
        "max_len": 300
    },
    "atomtok": {
        "name": "SMILES_atomtok",
        "tokenizer": "tokenizer_smiles_atomtok.json",
        "max_len": 256
    },
    "nodes": {"max_len": 384},
    "atomtok_coords": {"max_len": 480},
    "chartok_coords": {"max_len": 480}
}
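
# Usage sketch: training/decoding code is assumed to read its per-format budget
# from this table, e.g.
#   max_len = FORMAT_INFO['atomtok']['max_len']   # 256 tokens for atom-level SMILES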

def init_logger(log_file='train.log'):
    """Create a logger that writes identical messages to stdout and to log_file."""
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
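
# Usage sketch:
#   logger = init_logger('train.log')
#   logger.info('epoch 1 | loss 0.123')
# Note: each call adds fresh handlers to the same module-level logger, so call
# this once per process to avoid duplicated log lines.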

def init_summary_writer(save_path):
    """Create a tensorboardX SummaryWriter that logs under save_path."""
    from tensorboardX import SummaryWriter
    return SummaryWriter(save_path)
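
# Usage sketch (tensorboardX API):
#   summary = init_summary_writer('output/logs')
#   summary.add_scalar('train/loss', 0.5, global_step=100)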

def save_args(args):
    """Dump all parsed arguments to a timestamped log file under args.save_path."""
    dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M")
    path = os.path.join(args.save_path, f'train_{dt}.log')
    with open(path, 'w') as f:
        for k, v in vars(args).items():
            f.write(f"**** {k} = *{v}*\n")
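
# Usage sketch (`args` is whatever Namespace the training script builds; the
# attribute names below are illustrative, but save_path must already exist):
#   import argparse
#   args = argparse.Namespace(save_path='output', lr=1e-4)
#   save_args(args)   # writes output/train_<yymmdd-HHMM>.log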

def seed_torch(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
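
# Usage sketch: call once at startup, before building models or DataLoaders.
#   seed_torch(42)
# cudnn.deterministic=True trades some speed for reproducible kernels; fully
# deterministic runs may additionally need deterministic data loading.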

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
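
# Usage sketch: weight each update by the batch size so .avg is a true
# per-sample mean.
#   losses = AverageMeter()
#   losses.update(loss.item(), n=batch_size)
#   print(f'avg loss: {losses.avg:.4f}')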

class EpochMeter(AverageMeter):
    """AverageMeter that also keeps a second, epoch-level running average."""

    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()

    def update(self, val, n=1):
        super().update(val, n)
        self.epoch.update(val, n)

class LossMeter(EpochMeter):
    """EpochMeter for a total loss, with one sub-meter per named loss component."""

    def __init__(self):
        # self.subs must exist before super().__init__(), which calls reset().
        self.subs = {}
        super().__init__()

    def reset(self):
        super().reset()
        for k in self.subs:
            self.subs[k].reset()

    def update(self, loss, losses, n=1):
        # loss is the total loss tensor; losses maps component names to tensors.
        loss = loss.item()
        super().update(loss, n)
        losses = {k: v.item() for k, v in losses.items()}
        for k, v in losses.items():
            if k not in self.subs:
                self.subs[k] = EpochMeter()
            self.subs[k].update(v, n)
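
# Usage sketch, assuming the model returns a total loss tensor plus a dict of
# per-component loss tensors (the component names here are illustrative):
#   meter = LossMeter()
#   meter.update(loss, {'atomtok': loss_atom, 'edges': loss_edge}, n=batch_size)
#   print(meter.avg, {k: m.epoch.avg for k, m in meter.subs.items()})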

def asMinutes(s):
    """Format a duration in seconds as 'Xm Ys'."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    """Report elapsed time and an ETA, given the fraction of work completed."""
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
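
# Usage sketch:
#   start = time.time()
#   for step in range(1, total_steps + 1):
#       ...
#       print(timeSince(start, step / total_steps))  # e.g. '3m 10s (remain 12m 40s)'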

def print_rank_0(message):
    """Print only on rank 0 in distributed runs; print unconditionally otherwise."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)

def to_device(data, device):
    """Recursively move tensors (possibly nested in lists/dicts) to device."""
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, list):
        return [to_device(v, device) for v in data]
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    # Non-tensor leaves (ints, strings, None, ...) are returned unchanged
    # instead of silently becoming None.
    return data
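
# Usage sketch: moves an arbitrarily nested batch from a DataLoader in one call.
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   batch = {'image': torch.zeros(2, 3), 'refs': [torch.zeros(2)]}
#   batch = to_device(batch, device)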

def round_floats(o):
    """Recursively round every float to 3 decimal places (tuples become lists)."""
    if isinstance(o, float):
        return round(o, 3)
    if isinstance(o, dict):
        return {k: round_floats(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x) for x in o]
    return o
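
# Example:
#   round_floats({'x': 0.123456, 'y': [1.0, 2.56789]})
#   -> {'x': 0.123, 'y': [1.0, 2.568]}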

def format_df(df):
    """Serialize structured prediction columns to compact JSON strings in place."""
    def _dumps(obj):
        if obj is None:
            return obj
        return json.dumps(round_floats(obj)).replace(" ", "")
    for field in ['node_coords', 'node_symbols', 'edges']:
        if field in df.columns:
            df[field] = [_dumps(obj) for obj in df[field]]
    return df
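
# Usage sketch (pandas assumed, since format_df operates on a DataFrame):
#   import pandas as pd
#   df = pd.DataFrame({'node_coords': [[[0.1234, 0.5678]]], 'edges': [[[0, 1]]]})
#   format_df(df)   # node_coords -> '[[0.123,0.568]]', edges -> '[[0,1]]'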