Spaces:
Sleeping
Sleeping
Commit
·
73aff8b
1
Parent(s):
ab66141
Add early_stop_condition to stop earlier
Browse files- pysr/sr.py +6 -1
pysr/sr.py
CHANGED
|
@@ -420,6 +420,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 420 |
cluster_manager=None,
|
| 421 |
skip_mutation_failures=True,
|
| 422 |
max_evals=None,
|
|
|
|
| 423 |
# To support deprecated kwargs:
|
| 424 |
**kwargs,
|
| 425 |
):
|
|
@@ -562,6 +563,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 562 |
:type skip_mutation_failures: bool
|
| 563 |
:param max_evals: Limits the total number of evaluations of expressions to this number.
|
| 564 |
:type max_evals: int
|
|
|
|
|
|
|
| 565 |
:param kwargs: Supports deprecated keyword arguments. Other arguments will result
|
| 566 |
in an error
|
| 567 |
:type kwargs: dict
|
|
@@ -749,6 +752,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 749 |
cluster_manager=cluster_manager,
|
| 750 |
skip_mutation_failures=skip_mutation_failures,
|
| 751 |
max_evals=max_evals,
|
|
|
|
| 752 |
),
|
| 753 |
}
|
| 754 |
|
|
@@ -1313,8 +1317,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1313 |
progress=self.params["progress"],
|
| 1314 |
timeout_in_seconds=self.params["timeout_in_seconds"],
|
| 1315 |
crossoverProbability=self.params["crossover_probability"],
|
| 1316 |
-
max_evals=self.params["max_evals"],
|
| 1317 |
skip_mutation_failures=self.params["skip_mutation_failures"],
|
|
|
|
|
|
|
| 1318 |
)
|
| 1319 |
|
| 1320 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
|
|
|
| 420 |
cluster_manager=None,
|
| 421 |
skip_mutation_failures=True,
|
| 422 |
max_evals=None,
|
| 423 |
+
early_stop_condition=None,
|
| 424 |
# To support deprecated kwargs:
|
| 425 |
**kwargs,
|
| 426 |
):
|
|
|
|
| 563 |
:type skip_mutation_failures: bool
|
| 564 |
:param max_evals: Limits the total number of evaluations of expressions to this number.
|
| 565 |
:type max_evals: int
|
| 566 |
+
:param early_stop_condition: Stop the search early if this loss is reached.
|
| 567 |
+
:type early_stop_condition: float
|
| 568 |
:param kwargs: Supports deprecated keyword arguments. Other arguments will result
|
| 569 |
in an error
|
| 570 |
:type kwargs: dict
|
|
|
|
| 752 |
cluster_manager=cluster_manager,
|
| 753 |
skip_mutation_failures=skip_mutation_failures,
|
| 754 |
max_evals=max_evals,
|
| 755 |
+
early_stop_condition=early_stop_condition,
|
| 756 |
),
|
| 757 |
}
|
| 758 |
|
|
|
|
| 1317 |
progress=self.params["progress"],
|
| 1318 |
timeout_in_seconds=self.params["timeout_in_seconds"],
|
| 1319 |
crossoverProbability=self.params["crossover_probability"],
|
|
|
|
| 1320 |
skip_mutation_failures=self.params["skip_mutation_failures"],
|
| 1321 |
+
max_evals=self.params["max_evals"],
|
| 1322 |
+
earlyStopCondition=self.params["early_stop_condition"],
|
| 1323 |
)
|
| 1324 |
|
| 1325 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|