Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| """ | |
| A model maps "text_only" data to float. | |
| """ | |
| from lmflow.models.regression_model import RegressionModel | |
| from lmflow.datasets.dataset import Dataset | |
| class TextRegressionModel(RegressionModel): | |
| r""" | |
| Initializes a TextRegressionModel instance. | |
| Parameters | |
| ------------ | |
| model_args : | |
| Model arguments such as model name, path, revision, etc. | |
| args : Optional. | |
| Positional arguments. | |
| kwargs : Optional. | |
| Keyword arguments. | |
| """ | |
| def __init__( | |
| self, | |
| model_args, | |
| *args, | |
| **kwargs | |
| ): | |
| """ | |
| Initializes a TextRegressionModel instance. | |
| :param model_args: dictionary with model arguments such as model name, path, revision, etc. | |
| """ | |
| self.inference_func = None | |
| def register_inference_function(self, inference_func): | |
| """ | |
| Registers a regression function. | |
| """ | |
| self.inference_func = inference_func | |
| def inference(self, inputs: Dataset): | |
| """ | |
| Gets regression results of a given dataset. | |
| :inputs: Dataset object, only accept type "text_only". | |
| """ | |
| if self.inference_func is not None: | |
| return self.inference_func(inputs) | |
| else: | |
| pass | |