Spaces:
Running
Running
tttc3
commited on
Commit
·
bd90cfc
1
Parent(s):
3ef5500
Added pickle support
Browse files- pysr/export_numpy.py +4 -1
- pysr/sr.py +36 -0
- test/test.py +4 -4
pysr/export_numpy.py
CHANGED
|
@@ -13,7 +13,6 @@ class CallableEquation:
|
|
| 13 |
self._sympy_symbols = sympy_symbols
|
| 14 |
self._selection = selection
|
| 15 |
self._variable_names = variable_names
|
| 16 |
-
self._lambda = lambdify(sympy_symbols, eqn)
|
| 17 |
|
| 18 |
def __repr__(self):
|
| 19 |
return f"PySRFunction(X=>{self._sympy})"
|
|
@@ -35,3 +34,7 @@ class CallableEquation:
|
|
| 35 |
)
|
| 36 |
X = X[:, self._selection]
|
| 37 |
return self._lambda(*X.T) * np.ones(expected_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
self._sympy_symbols = sympy_symbols
|
| 14 |
self._selection = selection
|
| 15 |
self._variable_names = variable_names
|
|
|
|
| 16 |
|
| 17 |
def __repr__(self):
|
| 18 |
return f"PySRFunction(X=>{self._sympy})"
|
|
|
|
| 34 |
)
|
| 35 |
X = X[:, self._selection]
|
| 36 |
return self._lambda(*X.T) * np.ones(expected_shape)
|
| 37 |
+
|
| 38 |
+
@property
|
| 39 |
+
def _lambda(self):
|
| 40 |
+
return lambdify(self._sympy_symbols, self._sympy)
|
pysr/sr.py
CHANGED
|
@@ -816,6 +816,42 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 816 |
output += "]"
|
| 817 |
return output
|
| 818 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 819 |
@property
|
| 820 |
def equations(self): # pragma: no cover
|
| 821 |
warnings.warn(
|
|
|
|
| 816 |
output += "]"
|
| 817 |
return output
|
| 818 |
|
| 819 |
+
def __getstate__(self):
|
| 820 |
+
"""
|
| 821 |
+
Handles pickle serialization for PySRRegressor.
|
| 822 |
+
|
| 823 |
+
The Scikit-learn standard requires estimators to be serializable via
|
| 824 |
+
`pickle.dumps()`. However, `PyCall.jlwrap` does not support pickle
|
| 825 |
+
serialization.
|
| 826 |
+
|
| 827 |
+
Thus, for `PySRRegressor` to support pickle serialization, the
|
| 828 |
+
`raw_julia_state_` attribute must be hidden from pickle. This will
|
| 829 |
+
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
| 830 |
+
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
| 831 |
+
to be serialized. Note: Jax and Torch format equations are also removed
|
| 832 |
+
from the pickled instance.
|
| 833 |
+
"""
|
| 834 |
+
warnings.warn(
|
| 835 |
+
"raw_julia_state_ cannot be pickled and will be removed from the "
|
| 836 |
+
"serialized instance. This will prevent a `warm_start` fit of any "
|
| 837 |
+
"model that is deserialized via `pickle.loads()`."
|
| 838 |
+
)
|
| 839 |
+
state = self.__dict__
|
| 840 |
+
pickled_state = {
|
| 841 |
+
key: None if key == "raw_julia_state_" else value
|
| 842 |
+
for key, value in state.items()
|
| 843 |
+
}
|
| 844 |
+
if "equations_" in pickled_state:
|
| 845 |
+
pickled_state["output_torch_format"] = False
|
| 846 |
+
pickled_state["output_jax_format"] = False
|
| 847 |
+
pickled_columns = ~pickled_state["equations_"].columns.isin(
|
| 848 |
+
["jax_format", "torch_format"]
|
| 849 |
+
)
|
| 850 |
+
pickled_state["equations_"] = (
|
| 851 |
+
pickled_state["equations_"].loc[:, pickled_columns].copy()
|
| 852 |
+
)
|
| 853 |
+
return pickled_state
|
| 854 |
+
|
| 855 |
@property
|
| 856 |
def equations(self): # pragma: no cover
|
| 857 |
warnings.warn(
|
test/test.py
CHANGED
|
@@ -348,18 +348,18 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 348 |
max_evals=10000, verbosity=0, progress=False
|
| 349 |
) # Return early.
|
| 350 |
check_generator = check_estimator(model, generate_only=True)
|
|
|
|
| 351 |
for (_, check) in check_generator:
|
| 352 |
-
if "pickle" in check.func.__name__:
|
| 353 |
-
# Skip pickling tests.
|
| 354 |
-
continue
|
| 355 |
-
|
| 356 |
try:
|
| 357 |
with warnings.catch_warnings():
|
| 358 |
warnings.simplefilter("ignore")
|
| 359 |
check(model)
|
| 360 |
print("Passed", check.func.__name__)
|
| 361 |
except Exception as e:
|
|
|
|
| 362 |
print("Failed", check.func.__name__, "with:")
|
| 363 |
# Add a leading tab to error message, which
|
| 364 |
# might be multi-line:
|
| 365 |
print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
|
|
|
|
|
|
|
|
|
| 348 |
max_evals=10000, verbosity=0, progress=False
|
| 349 |
) # Return early.
|
| 350 |
check_generator = check_estimator(model, generate_only=True)
|
| 351 |
+
exception_messages = []
|
| 352 |
for (_, check) in check_generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
try:
|
| 354 |
with warnings.catch_warnings():
|
| 355 |
warnings.simplefilter("ignore")
|
| 356 |
check(model)
|
| 357 |
print("Passed", check.func.__name__)
|
| 358 |
except Exception as e:
|
| 359 |
+
exception_messages.append(f"{check.func.__name__}: {e}\n")
|
| 360 |
print("Failed", check.func.__name__, "with:")
|
| 361 |
# Add a leading tab to error message, which
|
| 362 |
# might be multi-line:
|
| 363 |
print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
|
| 364 |
+
# If any checks failed don't let the test pass.
|
| 365 |
+
self.assertEqual([], exception_messages)
|