Spaces:
Running
Running
Commit
·
045bdb1
1
Parent(s):
af8ab17
Add tests for determinism warnings
Browse files- test/test.py +19 -0
test/test.py
CHANGED
|
@@ -358,6 +358,25 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 358 |
model.fit(X, y)
|
| 359 |
self.assertIn("with 10 features or more", str(context.exception))
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
def test_scikit_learn_compatibility(self):
|
| 362 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 363 |
model = PySRRegressor(
|
|
|
|
| 358 |
model.fit(X, y)
|
| 359 |
self.assertIn("with 10 features or more", str(context.exception))
|
| 360 |
|
| 361 |
+
def test_deterministic_warnings(self):
|
| 362 |
+
"""Ensure that warnings are given for determinism"""
|
| 363 |
+
model = PySRRegressor(random_state=0)
|
| 364 |
+
X = np.random.randn(100, 2)
|
| 365 |
+
y = np.random.randn(100)
|
| 366 |
+
with warnings.catch_warnings():
|
| 367 |
+
warnings.simplefilter("error")
|
| 368 |
+
with self.assertRaises(Exception) as context:
|
| 369 |
+
model.fit(X, y)
|
| 370 |
+
self.assertIn("`deterministic`", str(context.exception))
|
| 371 |
+
|
| 372 |
+
def test_deterministic_errors(self):
|
| 373 |
+
"""Setting deterministic without random_state should error"""
|
| 374 |
+
model = PySRRegressor(deterministic=True)
|
| 375 |
+
X = np.random.randn(100, 2)
|
| 376 |
+
y = np.random.randn(100)
|
| 377 |
+
with self.assertRaises(ValueError):
|
| 378 |
+
model.fit(X, y)
|
| 379 |
+
|
| 380 |
def test_scikit_learn_compatibility(self):
|
| 381 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 382 |
model = PySRRegressor(
|