Spaces:
Sleeping
Sleeping
Commit
·
a190947
1
Parent(s):
5f19660
Add more reliable test of feature selection
Browse files- test/test.py +18 -1
test/test.py
CHANGED
|
@@ -190,7 +190,6 @@ class TestPipeline(unittest.TestCase):
|
|
| 190 |
**self.default_test_kwargs,
|
| 191 |
Xresampled=Xresampled,
|
| 192 |
denoise=True,
|
| 193 |
-
select_k_features=2,
|
| 194 |
nested_constraints={"/": {"+": 1, "-": 1}, "+": {"*": 4}},
|
| 195 |
)
|
| 196 |
model.fit(X, y)
|
|
@@ -210,6 +209,24 @@ class TestPipeline(unittest.TestCase):
|
|
| 210 |
self.assertLess(np.average((fn(X2) - true_fn(X2)) ** 2), 1e-1)
|
| 211 |
self.assertLess(np.average((model.predict(X2) - true_fn(X2)) ** 2), 1e-1)
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
class TestBest(unittest.TestCase):
|
| 215 |
def setUp(self):
|
|
|
|
| 190 |
**self.default_test_kwargs,
|
| 191 |
Xresampled=Xresampled,
|
| 192 |
denoise=True,
|
|
|
|
| 193 |
nested_constraints={"/": {"+": 1, "-": 1}, "+": {"*": 4}},
|
| 194 |
)
|
| 195 |
model.fit(X, y)
|
|
|
|
| 209 |
self.assertLess(np.average((fn(X2) - true_fn(X2)) ** 2), 1e-1)
|
| 210 |
self.assertLess(np.average((model.predict(X2) - true_fn(X2)) ** 2), 1e-1)
|
| 211 |
|
| 212 |
+
def test_high_dim_selection_early_stop(self):
|
| 213 |
+
X = pd.DataFrame({f"k{i}": self.rstate.randn(10000) for i in range(10)})
|
| 214 |
+
Xresampled = pd.DataFrame({f"k{i}": self.rstate.randn(100) for i in range(10)})
|
| 215 |
+
y = X["k7"] ** 2 + np.cos(X["k9"]) * 3
|
| 216 |
+
|
| 217 |
+
model = PySRRegressor(
|
| 218 |
+
unary_operators=["cos"],
|
| 219 |
+
select_k_features=3,
|
| 220 |
+
early_stop_condition=1e-4, # Stop once most accurate equation is <1e-4 MSE
|
| 221 |
+
Xresampled=Xresampled,
|
| 222 |
+
maxsize=12,
|
| 223 |
+
**self.default_test_kwargs,
|
| 224 |
+
)
|
| 225 |
+
model.fit(X, y)
|
| 226 |
+
model.set_params(model_selection="accuracy")
|
| 227 |
+
model.predict(X)
|
| 228 |
+
self.assertLess(np.average((model.predict(X) - y) ** 2), 1e-4)
|
| 229 |
+
|
| 230 |
|
| 231 |
class TestBest(unittest.TestCase):
|
| 232 |
def setUp(self):
|