| import transformers | |
| from transformers.models.auto import CONFIG_MAPPING | |
class CXRMateEDConfig(transformers.PretrainedConfig):
    """Configuration for the CXRMate-ED model.

    Composite config that holds a vision sub-config, a text sub-config and a
    set of CXRMate-ED specific options. The specific options are stored on the
    config unchanged; how they are interpreted lives in the model code, which
    is not part of this class.

    Args:
        vision_config: Vision sub-config. A ``dict`` is expanded by loading the
            config of the hard-coded ``'aehrc/uniformer_base_tl_384'``
            checkpoint via ``transformers.AutoConfig.from_pretrained`` with
            ``trust_remote_code=True``, using the dict entries as overrides.
            That call may download from the Hugging Face Hub and run code from
            that repository. Any other value (a config object or ``None``) is
            stored as given.
        text_config: Text sub-config. A ``dict`` is turned into the config
            class that ``CONFIG_MAPPING`` registers for its ``'model_type'``
            key, which defaults to ``'llama'`` when absent. The caller's dict
            is not modified. Any other value is stored as given.
        index_value_encoder_intermediate_size: Stored unchanged (default 2048).
        include_time_delta: Stored unchanged (default True).
        time_delta_monotonic_inversion: Stored unchanged (default True).
        add_time_deltas: Stored unchanged (default True).
        history: Stored unchanged (default 0).
        tables_filter: List of table names. ``None`` (the default) yields a
            fresh ``['mimic_cxr_sectioned', 'triage', 'medrecon']`` list.
        prompt_report_sections_filter: List of report section names. ``None``
            (the default) yields a fresh ``['indication', 'history']`` list.
        pad_token_id: Padding token id (default 4).
        **kwargs: Forwarded to ``transformers.PretrainedConfig.__init__``.
    """

    model_type = 'cxrmate-ed'

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        index_value_encoder_intermediate_size: int = 2048,
        include_time_delta: bool = True,
        time_delta_monotonic_inversion: bool = True,
        add_time_deltas: bool = True,
        history: int = 0,
        tables_filter=None,
        prompt_report_sections_filter=None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Build the default lists on every call. List literals used as
        # default arguments would be one shared, mutable object for every
        # instance.
        if tables_filter is None:
            tables_filter = ['mimic_cxr_sectioned', 'triage', 'medrecon']
        if prompt_report_sections_filter is None:
            prompt_report_sections_filter = ['indication', 'history']

        self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
        self.include_time_delta = include_time_delta
        self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
        self.add_time_deltas = add_time_deltas
        self.history = history
        self.tables_filter = tables_filter
        self.prompt_report_sections_filter = prompt_report_sections_filter
        # Assigned after super().__init__() so this value overrides whatever
        # the base initialiser set for pad_token_id.
        self.pad_token_id = pad_token_id

        if isinstance(vision_config, dict):
            # NOTE(review): the checkpoint id is hard-coded and
            # trust_remote_code=True executes code from that repository.
            vision_config = transformers.AutoConfig.from_pretrained(
                'aehrc/uniformer_base_tl_384',
                trust_remote_code=True,
                **vision_config,
            )
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            # Shallow-copy so the caller's dict is not mutated, then default
            # the sub-model type to Llama when it is not specified.
            text_config = dict(text_config)
            text_config.setdefault('model_type', 'llama')
            text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
        # Always store the text config, so config objects and None are kept
        # rather than dropped.
        self.text_config = text_config