Spaces:
Running
Running
tttc3
commited on
Commit
·
c51257e
1
Parent(s):
4b56660
Fixed weight checking
Browse files- pysr/sr.py +4 -4
pysr/sr.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import sys
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
-
from sklearn.utils import check_array, check_random_state
|
| 6 |
import sympy
|
| 7 |
from sympy import sympify
|
| 8 |
import re
|
|
@@ -15,7 +15,7 @@ from multiprocessing import cpu_count
|
|
| 15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
| 16 |
from sklearn.utils.validation import (
|
| 17 |
_check_feature_names_in,
|
| 18 |
-
|
| 19 |
check_is_fitted,
|
| 20 |
)
|
| 21 |
|
|
@@ -1073,7 +1073,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1073 |
if Xresampled is not None:
|
| 1074 |
Xresampled = check_array(Xresampled)
|
| 1075 |
if weights is not None:
|
| 1076 |
-
weights =
|
|
|
|
| 1077 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
| 1078 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
| 1079 |
variable_names = self.feature_names_in_
|
|
@@ -1461,7 +1462,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1461 |
|
| 1462 |
mutated_params = self._validate_init_params()
|
| 1463 |
|
| 1464 |
-
# Parameter input validation (for parameters defined in __init__)
|
| 1465 |
X, y, Xresampled, weights, variable_names = self._validate_fit_params(
|
| 1466 |
X, y, Xresampled, weights, variable_names
|
| 1467 |
)
|
|
|
|
| 2 |
import sys
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
+
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
| 6 |
import sympy
|
| 7 |
from sympy import sympify
|
| 8 |
import re
|
|
|
|
| 15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
| 16 |
from sklearn.utils.validation import (
|
| 17 |
_check_feature_names_in,
|
| 18 |
+
check_X_y,
|
| 19 |
check_is_fitted,
|
| 20 |
)
|
| 21 |
|
|
|
|
| 1073 |
if Xresampled is not None:
|
| 1074 |
Xresampled = check_array(Xresampled)
|
| 1075 |
if weights is not None:
|
| 1076 |
+
weights = check_array(weights)
|
| 1077 |
+
check_consistent_length(weights, y)
|
| 1078 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
| 1079 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
| 1080 |
variable_names = self.feature_names_in_
|
|
|
|
| 1462 |
|
| 1463 |
mutated_params = self._validate_init_params()
|
| 1464 |
|
|
|
|
| 1465 |
X, y, Xresampled, weights, variable_names = self._validate_fit_params(
|
| 1466 |
X, y, Xresampled, weights, variable_names
|
| 1467 |
)
|