Spaces:
Running
Running
tttc3
committed on
Commit
·
ad1c492
1
Parent(s):
3182a3b
Addressed some DeepSource issues
Browse files
- pysr/sr.py +44 -30
pysr/sr.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
| 2 |
import sys
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
-
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
| 6 |
import sympy
|
| 7 |
from sympy import sympify
|
| 8 |
import re
|
|
@@ -13,6 +12,7 @@ from datetime import datetime
|
|
| 13 |
import warnings
|
| 14 |
from multiprocessing import cpu_count
|
| 15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
|
|
|
| 16 |
from sklearn.utils.validation import (
|
| 17 |
_check_feature_names_in,
|
| 18 |
check_is_fitted,
|
|
@@ -76,7 +76,8 @@ sympy_mappings = {
|
|
| 76 |
|
| 77 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
| 78 |
warnings.warn(
|
| 79 |
-
"Calling `pysr` is deprecated.
|
|
|
|
| 80 |
FutureWarning,
|
| 81 |
)
|
| 82 |
model = PySRRegressor(**kwargs)
|
|
@@ -95,7 +96,8 @@ def _process_constraints(binary_operators, unary_operators, constraints):
|
|
| 95 |
if op in ["plus", "sub", "+", "-"]:
|
| 96 |
if constraints[op][0] != constraints[op][1]:
|
| 97 |
raise NotImplementedError(
|
| 98 |
-
"You need equal constraints on both sides for - and +,
|
|
|
|
| 99 |
)
|
| 100 |
elif op in ["mult", "*"]:
|
| 101 |
# Make sure the complex expression is in the left side.
|
|
@@ -128,7 +130,8 @@ def _maybe_create_inline_operators(binary_operators, unary_operators):
|
|
| 128 |
if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
|
| 129 |
raise ValueError(
|
| 130 |
f"Invalid function name {function_name}. "
|
| 131 |
-
"Only alphanumeric characters, numbers,
|
|
|
|
| 132 |
)
|
| 133 |
op_list[i] = function_name
|
| 134 |
return binary_operators, unary_operators
|
|
@@ -154,25 +157,32 @@ def _check_assertions(
|
|
| 154 |
|
| 155 |
def best(*args, **kwargs): # pragma: no cover
|
| 156 |
raise NotImplementedError(
|
| 157 |
-
"`best` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
|
|
| 158 |
)
|
| 159 |
|
| 160 |
|
| 161 |
def best_row(*args, **kwargs): # pragma: no cover
|
| 162 |
raise NotImplementedError(
|
| 163 |
-
"`best_row` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
|
| 167 |
def best_tex(*args, **kwargs): # pragma: no cover
|
| 168 |
raise NotImplementedError(
|
| 169 |
-
"`best_tex` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
|
|
| 170 |
)
|
| 171 |
|
| 172 |
|
| 173 |
def best_callable(*args, **kwargs): # pragma: no cover
|
| 174 |
raise NotImplementedError(
|
| 175 |
-
"`best_callable` has been deprecated. Please use the `PySRRegressor`
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
|
|
@@ -775,7 +785,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 775 |
setattr(self, updated_kwarg_name, v)
|
| 776 |
warnings.warn(
|
| 777 |
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
| 778 |
-
"
|
| 779 |
FutureWarning,
|
| 780 |
)
|
| 781 |
# Handle kwargs that have been moved to the fit method
|
|
@@ -787,7 +797,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 787 |
)
|
| 788 |
else:
|
| 789 |
raise TypeError(
|
| 790 |
-
f"{k} is not a valid keyword argument for PySRRegressor"
|
| 791 |
)
|
| 792 |
|
| 793 |
def __repr__(self):
|
|
@@ -964,7 +974,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 964 |
values. For example, default parameters are set here
|
| 965 |
when a parameter is left set to `None`.
|
| 966 |
"""
|
| 967 |
-
|
| 968 |
# Immutable parameter validation
|
| 969 |
# Ensure instance parameters are allowable values:
|
| 970 |
if self.tournament_selection_n > self.population_size:
|
|
@@ -974,27 +983,29 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 974 |
|
| 975 |
if self.maxsize > 40:
|
| 976 |
warnings.warn(
|
| 977 |
-
"Note: Using a large maxsize for the equation search will be
|
|
|
|
|
|
|
| 978 |
)
|
| 979 |
elif self.maxsize < 7:
|
| 980 |
raise ValueError("PySR requires a maxsize of at least 7")
|
| 981 |
|
| 982 |
-
if self.deterministic
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
)
|
| 992 |
|
| 993 |
-
if self.random_state
|
|
|
|
|
|
|
| 994 |
warnings.warn(
|
| 995 |
"Note: Setting `random_state` without also setting `deterministic` "
|
| 996 |
-
"to True and `procs` to 0 "
|
| 997 |
-
"will result in non-deterministic searches. "
|
| 998 |
)
|
| 999 |
|
| 1000 |
# NotImplementedError - Values that could be supported at a later time
|
|
@@ -1035,7 +1046,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1035 |
parameter_value = 1
|
| 1036 |
elif parameter == "progress" and not buffer_available:
|
| 1037 |
warnings.warn(
|
| 1038 |
-
"Note: it looks like you are running in Jupyter.
|
|
|
|
| 1039 |
)
|
| 1040 |
parameter_value = False
|
| 1041 |
packed_modified_params[parameter] = parameter_value
|
|
@@ -1087,7 +1099,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1087 |
Validated list of variable names for each feature in `X`.
|
| 1088 |
|
| 1089 |
"""
|
| 1090 |
-
|
| 1091 |
if isinstance(X, pd.DataFrame):
|
| 1092 |
if variable_names:
|
| 1093 |
variable_names = None
|
|
@@ -1803,7 +1814,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1803 |
)
|
| 1804 |
except FileNotFoundError:
|
| 1805 |
raise RuntimeError(
|
| 1806 |
-
"Couldn't find equation file! The equation search likely exited
|
|
|
|
| 1807 |
)
|
| 1808 |
|
| 1809 |
# It is expected extra_jax/torch_mappings will be updated after fit.
|
|
@@ -1814,7 +1826,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1814 |
for value in extra_jax_mappings.values():
|
| 1815 |
if not isinstance(value, str):
|
| 1816 |
raise ValueError(
|
| 1817 |
-
"extra_jax_mappings must have keys that are strings!
|
|
|
|
| 1818 |
)
|
| 1819 |
else:
|
| 1820 |
extra_jax_mappings = {}
|
|
@@ -1822,7 +1835,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1822 |
for value in extra_jax_mappings.values():
|
| 1823 |
if not callable(value):
|
| 1824 |
raise ValueError(
|
| 1825 |
-
"extra_torch_mappings must be callable functions!
|
|
|
|
| 1826 |
)
|
| 1827 |
else:
|
| 1828 |
extra_torch_mappings = {}
|
|
|
|
| 2 |
import sys
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
|
|
| 5 |
import sympy
|
| 6 |
from sympy import sympify
|
| 7 |
import re
|
|
|
|
| 12 |
import warnings
|
| 13 |
from multiprocessing import cpu_count
|
| 14 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
| 15 |
+
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
| 16 |
from sklearn.utils.validation import (
|
| 17 |
_check_feature_names_in,
|
| 18 |
check_is_fitted,
|
|
|
|
| 76 |
|
| 77 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
| 78 |
warnings.warn(
|
| 79 |
+
"Calling `pysr` is deprecated. "
|
| 80 |
+
"Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.",
|
| 81 |
FutureWarning,
|
| 82 |
)
|
| 83 |
model = PySRRegressor(**kwargs)
|
|
|
|
| 96 |
if op in ["plus", "sub", "+", "-"]:
|
| 97 |
if constraints[op][0] != constraints[op][1]:
|
| 98 |
raise NotImplementedError(
|
| 99 |
+
"You need equal constraints on both sides for - and +, "
|
| 100 |
+
"due to simplification strategies."
|
| 101 |
)
|
| 102 |
elif op in ["mult", "*"]:
|
| 103 |
# Make sure the complex expression is in the left side.
|
|
|
|
| 130 |
if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
|
| 131 |
raise ValueError(
|
| 132 |
f"Invalid function name {function_name}. "
|
| 133 |
+
"Only alphanumeric characters, numbers, "
|
| 134 |
+
"and underscores are allowed."
|
| 135 |
)
|
| 136 |
op_list[i] = function_name
|
| 137 |
return binary_operators, unary_operators
|
|
|
|
| 157 |
|
| 158 |
def best(*args, **kwargs): # pragma: no cover
|
| 159 |
raise NotImplementedError(
|
| 160 |
+
"`best` has been deprecated. Please use the `PySRRegressor` interface. "
|
| 161 |
+
"After fitting, you can return `.sympy()` to get the sympy representation "
|
| 162 |
+
"of the best equation."
|
| 163 |
)
|
| 164 |
|
| 165 |
|
| 166 |
def best_row(*args, **kwargs): # pragma: no cover
|
| 167 |
raise NotImplementedError(
|
| 168 |
+
"`best_row` has been deprecated. Please use the `PySRRegressor` interface. "
|
| 169 |
+
"After fitting, you can run `print(model)` to view the best equation, or "
|
| 170 |
+
"`model.get_best()` to return the best equation's row in `model.equations`."
|
| 171 |
)
|
| 172 |
|
| 173 |
|
| 174 |
def best_tex(*args, **kwargs): # pragma: no cover
|
| 175 |
raise NotImplementedError(
|
| 176 |
+
"`best_tex` has been deprecated. Please use the `PySRRegressor` interface. "
|
| 177 |
+
"After fitting, you can return `.latex()` to get the sympy representation "
|
| 178 |
+
"of the best equation."
|
| 179 |
)
|
| 180 |
|
| 181 |
|
| 182 |
def best_callable(*args, **kwargs): # pragma: no cover
|
| 183 |
raise NotImplementedError(
|
| 184 |
+
"`best_callable` has been deprecated. Please use the `PySRRegressor` "
|
| 185 |
+
"interface. After fitting, you can use `.predict(X)` to use the best callable."
|
| 186 |
)
|
| 187 |
|
| 188 |
|
|
|
|
| 785 |
setattr(self, updated_kwarg_name, v)
|
| 786 |
warnings.warn(
|
| 787 |
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
| 788 |
+
"Please use that instead.",
|
| 789 |
FutureWarning,
|
| 790 |
)
|
| 791 |
# Handle kwargs that have been moved to the fit method
|
|
|
|
| 797 |
)
|
| 798 |
else:
|
| 799 |
raise TypeError(
|
| 800 |
+
f"{k} is not a valid keyword argument for PySRRegressor."
|
| 801 |
)
|
| 802 |
|
| 803 |
def __repr__(self):
|
|
|
|
| 974 |
values. For example, default parameters are set here
|
| 975 |
when a parameter is left set to `None`.
|
| 976 |
"""
|
|
|
|
| 977 |
# Immutable parameter validation
|
| 978 |
# Ensure instance parameters are allowable values:
|
| 979 |
if self.tournament_selection_n > self.population_size:
|
|
|
|
| 983 |
|
| 984 |
if self.maxsize > 40:
|
| 985 |
warnings.warn(
|
| 986 |
+
"Note: Using a large maxsize for the equation search will be "
|
| 987 |
+
"exponentially slower and use significant memory. You should consider "
|
| 988 |
+
"turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
| 989 |
)
|
| 990 |
elif self.maxsize < 7:
|
| 991 |
raise ValueError("PySR requires a maxsize of at least 7")
|
| 992 |
|
| 993 |
+
if self.deterministic and not (
|
| 994 |
+
self.multithreading in [False, None]
|
| 995 |
+
and self.procs == 0
|
| 996 |
+
and self.random_state is not None
|
| 997 |
+
):
|
| 998 |
+
raise ValueError(
|
| 999 |
+
"To ensure deterministic searches, you must set `random_state` to a seed, "
|
| 1000 |
+
"`procs` to `0`, and `multithreading` to `False` or `None`."
|
| 1001 |
+
)
|
|
|
|
| 1002 |
|
| 1003 |
+
if self.random_state is not None and (
|
| 1004 |
+
not self.deterministic or self.procs != 0
|
| 1005 |
+
):
|
| 1006 |
warnings.warn(
|
| 1007 |
"Note: Setting `random_state` without also setting `deterministic` "
|
| 1008 |
+
"to True and `procs` to 0 will result in non-deterministic searches. "
|
|
|
|
| 1009 |
)
|
| 1010 |
|
| 1011 |
# NotImplementedError - Values that could be supported at a later time
|
|
|
|
| 1046 |
parameter_value = 1
|
| 1047 |
elif parameter == "progress" and not buffer_available:
|
| 1048 |
warnings.warn(
|
| 1049 |
+
"Note: it looks like you are running in Jupyter. "
|
| 1050 |
+
"The progress bar will be turned off."
|
| 1051 |
)
|
| 1052 |
parameter_value = False
|
| 1053 |
packed_modified_params[parameter] = parameter_value
|
|
|
|
| 1099 |
Validated list of variable names for each feature in `X`.
|
| 1100 |
|
| 1101 |
"""
|
|
|
|
| 1102 |
if isinstance(X, pd.DataFrame):
|
| 1103 |
if variable_names:
|
| 1104 |
variable_names = None
|
|
|
|
| 1814 |
)
|
| 1815 |
except FileNotFoundError:
|
| 1816 |
raise RuntimeError(
|
| 1817 |
+
"Couldn't find equation file! The equation search likely exited "
|
| 1818 |
+
"before a single iteration completed."
|
| 1819 |
)
|
| 1820 |
|
| 1821 |
# It is expected extra_jax/torch_mappings will be updated after fit.
|
|
|
|
| 1826 |
for value in extra_jax_mappings.values():
|
| 1827 |
if not isinstance(value, str):
|
| 1828 |
raise ValueError(
|
| 1829 |
+
"extra_jax_mappings must have keys that are strings! "
|
| 1830 |
+
"e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 1831 |
)
|
| 1832 |
else:
|
| 1833 |
extra_jax_mappings = {}
|
|
|
|
| 1835 |
for value in extra_jax_mappings.values():
|
| 1836 |
if not callable(value):
|
| 1837 |
raise ValueError(
|
| 1838 |
+
"extra_torch_mappings must be callable functions! "
|
| 1839 |
+
"e.g., {sympy.sqrt: torch.sqrt}."
|
| 1840 |
)
|
| 1841 |
else:
|
| 1842 |
extra_torch_mappings = {}
|