Spaces:
Sleeping
Sleeping
Commit
·
2f296b6
1
Parent(s):
5ada6c7
Test that pickle works without equation file
Browse files- test/test.py +52 -18
test/test.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import inspect
|
| 2 |
import unittest
|
| 3 |
import numpy as np
|
|
@@ -8,13 +10,14 @@ from sklearn.utils.estimator_checks import check_estimator
|
|
| 8 |
import sympy
|
| 9 |
import pandas as pd
|
| 10 |
import warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
DEFAULT_NITERATIONS = (
|
| 13 |
-
inspect.signature(PySRRegressor.__init__).parameters["niterations"].default
|
| 14 |
-
)
|
| 15 |
-
DEFAULT_POPULATIONS = (
|
| 16 |
-
inspect.signature(PySRRegressor.__init__).parameters["populations"].default
|
| 17 |
-
)
|
| 18 |
|
| 19 |
class TestPipeline(unittest.TestCase):
|
| 20 |
def setUp(self):
|
|
@@ -399,14 +402,49 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 399 |
with self.assertRaises(ValueError):
|
| 400 |
model.fit(X, y)
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
def test_scikit_learn_compatibility(self):
|
| 403 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 404 |
model = PySRRegressor(
|
| 405 |
-
|
|
|
|
|
|
|
| 406 |
verbosity=0,
|
| 407 |
progress=False,
|
| 408 |
random_state=0,
|
| 409 |
-
deterministic=True,
|
| 410 |
procs=0,
|
| 411 |
multithreading=False,
|
| 412 |
warm_start=False,
|
|
@@ -419,20 +457,16 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 419 |
try:
|
| 420 |
with warnings.catch_warnings():
|
| 421 |
warnings.simplefilter("ignore")
|
| 422 |
-
# To ensure an equation file is written for each output in
|
| 423 |
-
# nout, set stop condition to niterations=1
|
| 424 |
-
if check.func.__name__ == "check_regressor_multioutput":
|
| 425 |
-
model.set_params(niterations=1, max_evals=None)
|
| 426 |
-
else:
|
| 427 |
-
model.set_params(max_evals=10000)
|
| 428 |
check(model)
|
| 429 |
print("Passed", check.func.__name__)
|
| 430 |
-
except Exception
|
| 431 |
-
error_message = str(
|
| 432 |
-
exception_messages.append(
|
|
|
|
|
|
|
| 433 |
print("Failed", check.func.__name__, "with:")
|
| 434 |
# Add a leading tab to error message, which
|
| 435 |
# might be multi-line:
|
| 436 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 437 |
# If any checks failed don't let the test pass.
|
| 438 |
-
self.assertEqual(
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
import inspect
|
| 4 |
import unittest
|
| 5 |
import numpy as np
|
|
|
|
| 10 |
import sympy
|
| 11 |
import pandas as pd
|
| 12 |
import warnings
|
| 13 |
+
import pickle as pkl
|
| 14 |
+
import tempfile
|
| 15 |
+
|
| 16 |
+
DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
|
| 17 |
+
DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
|
| 18 |
+
DEFAULT_POPULATIONS = DEFAULT_PARAMS["populations"].default
|
| 19 |
+
DEFAULT_NCYCLES = DEFAULT_PARAMS["ncyclesperiteration"].default
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class TestPipeline(unittest.TestCase):
|
| 23 |
def setUp(self):
|
|
|
|
| 402 |
with self.assertRaises(ValueError):
|
| 403 |
model.fit(X, y)
|
| 404 |
|
| 405 |
+
def test_pickle_with_temp_equation_file(self):
|
| 406 |
+
"""If we have a temporary equation file, unpickle the estimator."""
|
| 407 |
+
model = PySRRegressor(
|
| 408 |
+
populations=int(1 + DEFAULT_POPULATIONS / 5),
|
| 409 |
+
temp_equation_file=True,
|
| 410 |
+
procs=0,
|
| 411 |
+
multithreading=False,
|
| 412 |
+
)
|
| 413 |
+
nout = 3
|
| 414 |
+
X = np.random.randn(100, 2)
|
| 415 |
+
y = np.random.randn(100, nout)
|
| 416 |
+
model.fit(X, y)
|
| 417 |
+
contents = model.equation_file_contents_.copy()
|
| 418 |
+
|
| 419 |
+
y_predictions = model.predict(X)
|
| 420 |
+
|
| 421 |
+
equation_file_base = model.equation_file_
|
| 422 |
+
for i in range(1, nout + 1):
|
| 423 |
+
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
| 424 |
+
|
| 425 |
+
with tempfile.NamedTemporaryFile() as pickle_file:
|
| 426 |
+
pkl.dump(model, pickle_file)
|
| 427 |
+
pickle_file.seek(0)
|
| 428 |
+
model2 = pkl.load(pickle_file)
|
| 429 |
+
|
| 430 |
+
contents2 = model2.equation_file_contents_
|
| 431 |
+
cols_to_check = ["equation", "loss", "complexity"]
|
| 432 |
+
for frame1, frame2 in zip(contents, contents2):
|
| 433 |
+
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
| 434 |
+
|
| 435 |
+
y_predictions2 = model2.predict(X)
|
| 436 |
+
np.testing.assert_array_equal(y_predictions, y_predictions2)
|
| 437 |
+
|
| 438 |
def test_scikit_learn_compatibility(self):
|
| 439 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 440 |
model = PySRRegressor(
|
| 441 |
+
niterations=int(1 + DEFAULT_NITERATIONS / 10),
|
| 442 |
+
populations=int(1 + DEFAULT_POPULATIONS / 3),
|
| 443 |
+
ncyclesperiteration=int(2 + DEFAULT_NCYCLES / 10),
|
| 444 |
verbosity=0,
|
| 445 |
progress=False,
|
| 446 |
random_state=0,
|
| 447 |
+
deterministic=True, # Deterministic as tests require this.
|
| 448 |
procs=0,
|
| 449 |
multithreading=False,
|
| 450 |
warm_start=False,
|
|
|
|
| 457 |
try:
|
| 458 |
with warnings.catch_warnings():
|
| 459 |
warnings.simplefilter("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
check(model)
|
| 461 |
print("Passed", check.func.__name__)
|
| 462 |
+
except Exception:
|
| 463 |
+
error_message = str(traceback.format_exc())
|
| 464 |
+
exception_messages.append(
|
| 465 |
+
f"{check.func.__name__}:\n" + error_message + "\n"
|
| 466 |
+
)
|
| 467 |
print("Failed", check.func.__name__, "with:")
|
| 468 |
# Add a leading tab to error message, which
|
| 469 |
# might be multi-line:
|
| 470 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 471 |
# If any checks failed don't let the test pass.
|
| 472 |
+
self.assertEqual(len(exception_messages), 0)
|