Spaces:
Runtime error
Runtime error
| import copy | |
| from anonymous_demo.functional.config.config_manager import ConfigManager | |
| from anonymous_demo.core.tad.classic.__bert__.models import TADBERT | |
# Template preset: the baseline that every ``get_tad_config_*`` helper starts
# from. The language-specific presets below are merged on top of it, so any key
# they omit (e.g. ``use_amp``) keeps the value given here.
_tad_config_template = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Original note: "split train and test datasets into 5 folds and repeat 3
    # training"; -1 presumably disables k-fold cross validation — TODO confirm.
    cross_validate_fold=-1,
    use_amp=False,
)
# Preset behind ``get_tad_config_base``: the "microsoft/deberta-v3-base"
# checkpoint. It is merged over the template, so ``use_amp`` (absent here)
# comes from the template.
_tad_config_base = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    patience=99999,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Original note: "split train and test datasets into 5 folds and repeat 3
    # training"; -1 presumably disables k-fold cross validation — TODO confirm.
    cross_validate_fold=-1,
)
# Preset behind ``get_tad_config_english``: the "microsoft/deberta-v3-base"
# checkpoint. Keys not listed here fall back to the template when merged.
_tad_config_english = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/deberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Original note: "split train and test datasets into 5 folds and repeat 3
    # training"; -1 presumably disables k-fold cross validation — TODO confirm.
    cross_validate_fold=-1,
)
# Preset behind ``get_tad_config_multilingual``: the
# "microsoft/mdeberta-v3-base" checkpoint. Keys not listed here fall back to the
# template when merged.
_tad_config_multilingual = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    pretrained_bert="microsoft/mdeberta-v3-base",
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Original note: "split train and test datasets into 5 folds and repeat 3
    # training"; -1 presumably disables k-fold cross validation — TODO confirm.
    cross_validate_fold=-1,
)
# Preset behind ``get_tad_config_chinese``: the "bert-base-chinese" checkpoint.
# Keys not listed here fall back to the template when merged.
_tad_config_chinese = dict(
    model=TADBERT,
    optimizer="adamw",
    learning_rate=2e-5,
    patience=99999,
    cache_dataset=True,
    warmup_step=-1,
    show_metric=False,
    pretrained_bert="bert-base-chinese",
    max_seq_len=80,
    dropout=0,
    l2reg=1e-6,
    num_epoch=10,
    batch_size=16,
    initializer="xavier_uniform_",
    seed=52,
    polarities_dim=3,
    log_step=10,
    evaluate_begin=0,
    # Original note: "split train and test datasets into 5 folds and repeat 3
    # training"; -1 presumably disables k-fold cross validation — TODO confirm.
    cross_validate_fold=-1,
)
class TADConfigManager(ConfigManager):
    """Config manager for TAD models, plus helpers around the module-level presets.

    ``set_tad_config*`` update one of the module-level preset dicts in place.
    ``get_tad_config_*`` return a new ``TADConfigManager`` built from a deep copy
    of ``_tad_config_template`` overlaid with a deep copy of the requested preset.
    Neither the template nor the preset is modified by a getter.
    """

    def __init__(self, args, **kwargs):
        """
        Available Params (defaults taken from ``_tad_config_template``):
            {'model': TADBERT,
             'optimizer': "adamw",
             'learning_rate': 0.00002,
             'patience': 99999,
             'pretrained_bert': "microsoft/mdeberta-v3-base",
             'cache_dataset': True,
             'warmup_step': -1,
             'show_metric': False,
             'max_seq_len': 80,
             'dropout': 0,
             'l2reg': 0.000001,
             'num_epoch': 10,
             'batch_size': 16,
             'initializer': 'xavier_uniform_',
             'seed': 52,
             'polarities_dim': 3,
             'log_step': 10,
             'evaluate_begin': 0,
             'cross_validate_fold': -1,
             'use_amp': False,
            }

        :param args: the configuration, e.g. the merged dict built by the
            ``get_tad_config_*`` helpers; forwarded unchanged to ``ConfigManager``.
        :param kwargs: forwarded unchanged to ``ConfigManager``.
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _resolve_tad_config(configType: str) -> dict:
        """Return the module-level preset dict registered under ``configType``.

        :raises ValueError: if ``configType`` is not a known preset name.
        :raises NotImplementedError: for ``"glove"`` when no module-level
            ``_tad_config_glove`` preset exists.
        """
        presets = {
            "template": _tad_config_template,
            "base": _tad_config_base,
            "english": _tad_config_english,
            "chinese": _tad_config_chinese,
            "multilingual": _tad_config_multilingual,
            # No GloVe preset is defined next to the others. Look it up lazily so
            # a missing definition gives a clear error instead of a NameError.
            "glove": globals().get("_tad_config_glove"),
        }
        if configType not in presets:
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
            )
        preset = presets[configType]
        if preset is None:
            raise NotImplementedError(
                "No GloVe config preset (_tad_config_glove) is defined for TAD"
            )
        return preset

    @staticmethod
    def _build_tad_config(configType: str) -> ConfigManager:
        """Return a manager for template defaults overridden by preset ``configType``.

        Merges deep copies so that neither the shared template nor the preset is
        mutated. Updating the template in place would let one preset's values
        leak into every later ``get_tad_config_*`` call.
        """
        config = copy.deepcopy(_tad_config_template)
        config.update(copy.deepcopy(TADConfigManager._resolve_tad_config(configType)))
        return TADConfigManager(config)

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Merge ``newitem`` in place into the preset named ``configType``.

        :param configType: one of "template", "base", "english", "chinese",
            "multilingual", "glove".
        :param newitem: mapping of config keys to new values.
        :raises TypeError: if ``newitem`` is not a dict (checked first).
        :raises ValueError: if ``configType`` is unknown.
        :raises NotImplementedError: for "glove" when no GloVe preset is defined.
        """
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._resolve_tad_config(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Merge ``newitem`` into the template preset."""
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Merge ``newitem`` into the base preset."""
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Merge ``newitem`` into the English preset."""
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Merge ``newitem`` into the Chinese preset."""
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Merge ``newitem`` into the multilingual preset."""
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Merge ``newitem`` into the GloVe preset, if one is defined."""
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a manager holding a copy of the template preset."""
        return TADConfigManager._build_tad_config("template")

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return a manager for the template overlaid with the base preset."""
        return TADConfigManager._build_tad_config("base")

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return a manager for the template overlaid with the English preset."""
        return TADConfigManager._build_tad_config("english")

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return a manager for the template overlaid with the Chinese preset."""
        return TADConfigManager._build_tad_config("chinese")

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return a manager for the template overlaid with the multilingual preset."""
        return TADConfigManager._build_tad_config("multilingual")

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Return a manager for the template overlaid with the GloVe preset, if defined."""
        return TADConfigManager._build_tad_config("glove")