Spaces:
Sleeping
Sleeping
tttc3
committed on
Commit
·
6881818
1
Parent(s):
3e8d44d
Updated parameter validation
Browse files- pysr/sr.py +112 -97
pysr/sr.py
CHANGED
|
@@ -529,6 +529,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 529 |
List of indices for input features that are selected when
|
| 530 |
:param`select_k_features` is set.
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
|
| 533 |
The state for the julia SymbolicRegression.jl backend post fitting.
|
| 534 |
|
|
@@ -928,6 +934,71 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 928 |
else:
|
| 929 |
self.equation_file_ = self.equation_file
|
| 930 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
| 932 |
"""
|
| 933 |
Validates the parameters passed to the :term`fit` method.
|
|
@@ -965,39 +1036,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 965 |
|
| 966 |
"""
|
| 967 |
|
| 968 |
-
# Ensure instance parameters are allowable values:
|
| 969 |
-
if self.tournament_selection_n > self.population_size:
|
| 970 |
-
raise ValueError(
|
| 971 |
-
"tournament_selection_n parameter must be smaller than population_size."
|
| 972 |
-
)
|
| 973 |
-
|
| 974 |
-
if self.maxsize > 40:
|
| 975 |
-
warnings.warn(
|
| 976 |
-
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
| 977 |
-
)
|
| 978 |
-
elif self.maxsize < 7:
|
| 979 |
-
raise ValueError("PySR requires a maxsize of at least 7")
|
| 980 |
-
|
| 981 |
-
if self.extra_jax_mappings is not None:
|
| 982 |
-
for value in self.extra_jax_mappings.values():
|
| 983 |
-
if not isinstance(value, str):
|
| 984 |
-
raise ValueError(
|
| 985 |
-
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 986 |
-
)
|
| 987 |
-
|
| 988 |
-
if self.extra_torch_mappings is not None:
|
| 989 |
-
for value in self.extra_jax_mappings.values():
|
| 990 |
-
if not callable(value):
|
| 991 |
-
raise ValueError(
|
| 992 |
-
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 993 |
-
)
|
| 994 |
-
|
| 995 |
-
# NotImplementedError - Values that could be supported at a later time
|
| 996 |
-
if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
|
| 997 |
-
raise NotImplementedError(
|
| 998 |
-
f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
|
| 999 |
-
)
|
| 1000 |
-
|
| 1001 |
if isinstance(X, pd.DataFrame):
|
| 1002 |
if variable_names:
|
| 1003 |
variable_names = None
|
|
@@ -1020,13 +1058,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1020 |
"Spaces have been replaced with underscores. \n"
|
| 1021 |
"Please use valid names instead."
|
| 1022 |
)
|
| 1023 |
-
# Only numpy values are needed from Xresampled, column metadata is
|
| 1024 |
-
# provided by X
|
| 1025 |
-
if isinstance(Xresampled, pd.DataFrame):
|
| 1026 |
-
Xresampled = Xresampled.values
|
| 1027 |
|
| 1028 |
# Data validation and feature name fetching via sklearn
|
| 1029 |
# This method sets the n_features_in_ attribute
|
|
|
|
| 1030 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
| 1031 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
| 1032 |
variable_names = self.feature_names_in_
|
|
@@ -1126,7 +1161,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1126 |
|
| 1127 |
return X, y, variable_names
|
| 1128 |
|
| 1129 |
-
def _run(self, X, y, weights, seed):
|
| 1130 |
"""
|
| 1131 |
Run the symbolic regression fitting process on the julia backend.
|
| 1132 |
|
|
@@ -1138,10 +1173,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1138 |
y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
|
| 1139 |
Target values. Will be cast to X's dtype if necessary.
|
| 1140 |
|
| 1141 |
-
|
|
|
|
|
|
|
|
|
|
| 1142 |
Each element is how to weight the mean-square-error loss
|
| 1143 |
for that particular element of y.
|
| 1144 |
|
|
|
|
|
|
|
|
|
|
| 1145 |
Returns
|
| 1146 |
-------
|
| 1147 |
self : object
|
|
@@ -1159,66 +1200,17 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1159 |
|
| 1160 |
# These are the parameters which may be modified from the ones
|
| 1161 |
# specified in init, so we define them here locally:
|
| 1162 |
-
binary_operators =
|
| 1163 |
-
unary_operators =
|
| 1164 |
-
|
|
|
|
| 1165 |
nested_constraints = self.nested_constraints
|
| 1166 |
complexity_of_operators = self.complexity_of_operators
|
| 1167 |
-
multithreading =
|
| 1168 |
-
update_verbosity = self.update_verbosity
|
| 1169 |
-
maxdepth = self.maxdepth
|
| 1170 |
-
batch_size = self.batch_size
|
| 1171 |
-
progress = self.progress
|
| 1172 |
cluster_manager = self.cluster_manager
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
# Deal with default values, and type conversions:
|
| 1178 |
-
if binary_operators is None:
|
| 1179 |
-
binary_operators = "+ * - /".split(" ")
|
| 1180 |
-
elif isinstance(binary_operators, str):
|
| 1181 |
-
binary_operators = [binary_operators]
|
| 1182 |
-
|
| 1183 |
-
if unary_operators is None:
|
| 1184 |
-
unary_operators = []
|
| 1185 |
-
elif isinstance(unary_operators, str):
|
| 1186 |
-
unary_operators = [unary_operators]
|
| 1187 |
-
|
| 1188 |
-
assert len(unary_operators) + len(binary_operators) > 0
|
| 1189 |
-
|
| 1190 |
-
if constraints is None:
|
| 1191 |
-
constraints = {}
|
| 1192 |
-
|
| 1193 |
-
if multithreading is None:
|
| 1194 |
-
# Default is multithreading=True, unless explicitly set,
|
| 1195 |
-
# or procs is set to 0 (serial mode).
|
| 1196 |
-
multithreading = self.procs != 0 and cluster_manager is None
|
| 1197 |
-
|
| 1198 |
-
if update_verbosity is None:
|
| 1199 |
-
update_verbosity = self.verbosity
|
| 1200 |
-
|
| 1201 |
-
if maxdepth is None:
|
| 1202 |
-
maxdepth = self.maxsize
|
| 1203 |
-
|
| 1204 |
-
# Warn if instance parameters are not sensible values:
|
| 1205 |
-
if batch_size < 1:
|
| 1206 |
-
warnings.warn(
|
| 1207 |
-
"Given :param`batch_size` must be greater than or equal to one. "
|
| 1208 |
-
":param`batch_size` has been increased to equal one."
|
| 1209 |
-
)
|
| 1210 |
-
batch_size = 1
|
| 1211 |
-
|
| 1212 |
-
# Handle presentation of the progress bar:
|
| 1213 |
-
buffer_available = "buffer" in sys.stdout.__dir__()
|
| 1214 |
-
if progress is not None:
|
| 1215 |
-
if progress and not buffer_available:
|
| 1216 |
-
warnings.warn(
|
| 1217 |
-
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
| 1218 |
-
)
|
| 1219 |
-
progress = False
|
| 1220 |
-
else:
|
| 1221 |
-
progress = buffer_available
|
| 1222 |
|
| 1223 |
# Start julia backend processes
|
| 1224 |
if Main is None:
|
|
@@ -1455,6 +1447,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1455 |
|
| 1456 |
self._setup_equation_file()
|
| 1457 |
|
|
|
|
|
|
|
| 1458 |
# Parameter input validation (for parameters defined in __init__)
|
| 1459 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
| 1460 |
X, y, Xresampled, variable_names
|
|
@@ -1505,7 +1499,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1505 |
)
|
| 1506 |
|
| 1507 |
# Fitting procedure
|
| 1508 |
-
return self._run(X
|
| 1509 |
|
| 1510 |
def refresh(self, checkpoint_file=None):
|
| 1511 |
"""
|
|
@@ -1736,6 +1730,27 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1736 |
"Couldn't find equation file! The equation search likely exited before a single iteration completed."
|
| 1737 |
)
|
| 1738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1739 |
ret_outputs = []
|
| 1740 |
|
| 1741 |
for output in all_outputs:
|
|
|
|
| 529 |
List of indices for input features that are selected when
|
| 530 |
:param`select_k_features` is set.
|
| 531 |
|
| 532 |
+
tempdir_ : Path
|
| 533 |
+
Path to the temporary equations directory.
|
| 534 |
+
|
| 535 |
+
equation_file_ : str
|
| 536 |
+
Output equation file name produced by the julia backend.
|
| 537 |
+
|
| 538 |
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
|
| 539 |
The state for the julia SymbolicRegression.jl backend post fitting.
|
| 540 |
|
|
|
|
| 934 |
else:
|
| 935 |
self.equation_file_ = self.equation_file
|
| 936 |
|
| 937 |
+
def _validate_init_params(self):
    """
    Validate the parameters set in ``__init__`` and resolve their defaults.

    The estimator's attributes are never modified here. Parameters whose
    effective value may differ from the one stored on the instance (a
    ``None`` that stands for a default, a single operator given as a
    string, an out-of-range ``batch_size``, ...) are collected into a
    separate dictionary that is returned to the caller.

    Returns
    -------
    packed_modified_params : dict[str, Any]
        Effective values for ``binary_operators``, ``unary_operators``,
        ``maxdepth``, ``constraints``, ``multithreading``, ``batch_size``,
        ``update_verbosity`` and ``progress``.

    Raises
    ------
    ValueError
        If ``tournament_selection_n`` exceeds ``population_size`` or
        ``maxsize`` is below 7.
    NotImplementedError
        If ``optimizer_algorithm`` is not in ``VALID_OPTIMIZER_ALGORITHMS``.
    AssertionError
        If no unary or binary operator remains after defaults are applied.
    """
    # Immutable parameter validation
    # Ensure instance parameters are allowable values:
    if self.tournament_selection_n > self.population_size:
        raise ValueError(
            "tournament_selection_n parameter must be smaller than population_size."
        )

    if self.maxsize > 40:
        warnings.warn(
            "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
        )
    elif self.maxsize < 7:
        raise ValueError("PySR requires a maxsize of at least 7")

    # NotImplementedError - Values that could be supported at a later time
    if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
        raise NotImplementedError(
            f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
        )

    # 'Mutable' parameter validation
    # The progress bar is only usable when stdout exposes a ``buffer``.
    buffer_available = "buffer" in sys.stdout.__dir__()
    # Default used for each parameter when the instance value is None.
    modifiable_params = {
        "binary_operators": "+ * - /".split(" "),
        "unary_operators": [],
        "maxdepth": self.maxsize,
        "constraints": {},
        "multithreading": self.procs != 0 and self.cluster_manager is None,
        "batch_size": 1,
        "update_verbosity": self.verbosity,
        "progress": buffer_available,
    }
    packed_modified_params = {}
    for parameter, default_value in modifiable_params.items():
        parameter_value = getattr(self, parameter)
        if parameter_value is None:
            parameter_value = default_value
        # Special cases such as when binary_operators is a string
        elif parameter in ("binary_operators", "unary_operators") and isinstance(
            parameter_value, str
        ):
            parameter_value = [parameter_value]
        # Strings are compared with ``==``: identity (``is``) only works by
        # accident of interning and triggers a SyntaxWarning on Python 3.8+.
        elif parameter == "batch_size" and parameter_value < 1:
            warnings.warn(
                "Given :param`batch_size` must be greater than or equal to one. "
                ":param`batch_size` has been increased to equal one."
            )
            parameter_value = 1
        # Only warn when the progress bar was actually requested; an explicit
        # ``progress=False`` needs no warning and is passed through unchanged.
        elif parameter == "progress" and parameter_value and not buffer_available:
            warnings.warn(
                "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
            )
            parameter_value = False
        packed_modified_params[parameter] = parameter_value

    assert (
        len(packed_modified_params["binary_operators"])
        + len(packed_modified_params["unary_operators"])
        > 0
    ), "At least one binary or unary operator must be provided."
    return packed_modified_params
|
| 1001 |
+
|
| 1002 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
| 1003 |
"""
|
| 1004 |
Validates the parameters passed to the :term`fit` method.
|
|
|
|
| 1036 |
|
| 1037 |
"""
|
| 1038 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1039 |
if isinstance(X, pd.DataFrame):
|
| 1040 |
if variable_names:
|
| 1041 |
variable_names = None
|
|
|
|
| 1058 |
"Spaces have been replaced with underscores. \n"
|
| 1059 |
"Please use valid names instead."
|
| 1060 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1061 |
|
| 1062 |
# Data validation and feature name fetching via sklearn
|
| 1063 |
# This method sets the n_features_in_ attribute
|
| 1064 |
+
Xresampled = check_array(Xresampled)
|
| 1065 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
| 1066 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
| 1067 |
variable_names = self.feature_names_in_
|
|
|
|
| 1161 |
|
| 1162 |
return X, y, variable_names
|
| 1163 |
|
| 1164 |
+
def _run(self, X, y, mutated_params, weights, seed):
|
| 1165 |
"""
|
| 1166 |
Run the symbolic regression fitting process on the julia backend.
|
| 1167 |
|
|
|
|
| 1173 |
y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
|
| 1174 |
Target values. Will be cast to X's dtype if necessary.
|
| 1175 |
|
| 1176 |
+
mutated_params : dict[str, Any]
|
| 1177 |
+
Dictionary of mutated versions of some parameters passed in __init__.
|
| 1178 |
+
|
| 1179 |
+
weights : {ndarray | pandas.DataFrame} of the same shape as y
|
| 1180 |
Each element is how to weight the mean-square-error loss
|
| 1181 |
for that particular element of y.
|
| 1182 |
|
| 1183 |
+
seed : int
|
| 1184 |
+
Random seed for julia backend process.
|
| 1185 |
+
|
| 1186 |
Returns
|
| 1187 |
-------
|
| 1188 |
self : object
|
|
|
|
| 1200 |
|
| 1201 |
# These are the parameters which may be modified from the ones
|
| 1202 |
# specified in init, so we define them here locally:
|
| 1203 |
+
binary_operators = mutated_params["binary_operators"]
|
| 1204 |
+
unary_operators = mutated_params["unary_operators"]
|
| 1205 |
+
maxdepth = mutated_params["maxdepth"]
|
| 1206 |
+
constraints = mutated_params["constraints"]
|
| 1207 |
nested_constraints = self.nested_constraints
|
| 1208 |
complexity_of_operators = self.complexity_of_operators
|
| 1209 |
+
multithreading = mutated_params["multithreading"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1210 |
cluster_manager = self.cluster_manager
|
| 1211 |
+
batch_size = mutated_params["batch_size"]
|
| 1212 |
+
update_verbosity = mutated_params["update_verbosity"]
|
| 1213 |
+
progress = mutated_params["progress"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1214 |
|
| 1215 |
# Start julia backend processes
|
| 1216 |
if Main is None:
|
|
|
|
| 1447 |
|
| 1448 |
self._setup_equation_file()
|
| 1449 |
|
| 1450 |
+
mutated_params = self._validate_init_params()
|
| 1451 |
+
|
| 1452 |
# Parameter input validation (for parameters defined in __init__)
|
| 1453 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
| 1454 |
X, y, Xresampled, variable_names
|
|
|
|
| 1499 |
)
|
| 1500 |
|
| 1501 |
# Fitting procedure
|
| 1502 |
+
return self._run(X, y, mutated_params, weights=weights, seed=seed)
|
| 1503 |
|
| 1504 |
def refresh(self, checkpoint_file=None):
|
| 1505 |
"""
|
|
|
|
| 1730 |
"Couldn't find equation file! The equation search likely exited before a single iteration completed."
|
| 1731 |
)
|
| 1732 |
|
| 1733 |
+
# It is expected extra_jax/torch_mappings will be updated after fit.
|
| 1734 |
+
# Thus, validation is performed here instead of in _validate_init_params
|
| 1735 |
+
extra_jax_mappings = self.extra_jax_mappings
|
| 1736 |
+
extra_torch_mappings = self.extra_torch_mappings
|
| 1737 |
+
if extra_jax_mappings is not None:
|
| 1738 |
+
for value in self.extra_jax_mappings.values():
|
| 1739 |
+
if not isinstance(value, str):
|
| 1740 |
+
raise ValueError(
|
| 1741 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 1742 |
+
)
|
| 1743 |
+
else:
|
| 1744 |
+
extra_jax_mappings = {}
|
| 1745 |
+
if extra_torch_mappings is not None:
|
| 1746 |
+
for value in self.extra_jax_mappings.values():
|
| 1747 |
+
if not callable(value):
|
| 1748 |
+
raise ValueError(
|
| 1749 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 1750 |
+
)
|
| 1751 |
+
else:
|
| 1752 |
+
extra_torch_mappings = {}
|
| 1753 |
+
|
| 1754 |
ret_outputs = []
|
| 1755 |
|
| 1756 |
for output in all_outputs:
|