Spaces:
Sleeping
Sleeping
Commit
·
78cdb0e
1
Parent(s):
4ae8a5c
Add test for loading from pickle file
Browse files- test/test.py +27 -0
test/test.py
CHANGED
|
@@ -309,6 +309,33 @@ class TestPipeline(unittest.TestCase):
|
|
| 309 |
|
| 310 |
np.testing.assert_allclose(y_truth, y_test)
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
class TestBest(unittest.TestCase):
|
| 314 |
def setUp(self):
|
|
|
|
| 309 |
|
| 310 |
np.testing.assert_allclose(y_truth, y_test)
|
| 311 |
|
| 312 |
+
def test_load_model_simple(self):
|
| 313 |
+
# Test that we can simply load a model from its equation file.
|
| 314 |
+
y = self.X[:, [0, 1]] ** 2
|
| 315 |
+
model = PySRRegressor(
|
| 316 |
+
# Test that passing a single operator works:
|
| 317 |
+
unary_operators="sq(x) = x^2",
|
| 318 |
+
binary_operators="plus",
|
| 319 |
+
extra_sympy_mappings={"sq": lambda x: x**2},
|
| 320 |
+
**self.default_test_kwargs,
|
| 321 |
+
procs=0,
|
| 322 |
+
denoise=True,
|
| 323 |
+
early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
|
| 324 |
+
)
|
| 325 |
+
rand_dir = Path(tempfile.mkdtemp())
|
| 326 |
+
equation_file = rand_dir / "equations.csv"
|
| 327 |
+
model.set_params(temp_equation_file=False)
|
| 328 |
+
model.set_params(equation_file=equation_file)
|
| 329 |
+
model.fit(self.X, y)
|
| 330 |
+
|
| 331 |
+
# lambda functions are removed from the pickling, so we need
|
| 332 |
+
# to pass it during the loading:
|
| 333 |
+
model2 = load(
|
| 334 |
+
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
|
| 338 |
+
|
| 339 |
|
| 340 |
class TestBest(unittest.TestCase):
|
| 341 |
def setUp(self):
|