Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| """Automatically get correct model type. | |
| """ | |
| from lmflow.models.hf_decoder_model import HFDecoderModel | |
| from lmflow.models.text_regression_model import TextRegressionModel | |
| from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel | |
| class AutoModel: | |
| def get_model(self, model_args, *args, **kwargs): | |
| arch_type = model_args.arch_type | |
| if arch_type == "decoder_only": | |
| return HFDecoderModel(model_args, *args, **kwargs) | |
| elif arch_type == "text_regression": | |
| return TextRegressionModel(model_args, *args, **kwargs) | |
| elif arch_type == "encoder_decoder": | |
| return HFEncoderDecoderModel(model_args, *args, **kwargs) | |
| else: | |
| raise NotImplementedError( | |
| f"model architecture type \"{arch_type}\" is not supported" | |
| ) | |