Spaces:
Running
Running
Commit
·
c9cead8
1
Parent(s):
7cda629
Make torch custom operator test deterministic
Browse files- test/test_torch.py +10 -3
test/test_torch.py
CHANGED
|
@@ -160,9 +160,10 @@ class TestTorch(unittest.TestCase):
|
|
| 160 |
)
|
| 161 |
|
| 162 |
def test_feature_selection_custom_operators(self):
|
| 163 |
-
|
|
|
|
| 164 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
| 165 |
-
y = X["k15"] ** 2 + cos_approx(X["k20"])
|
| 166 |
|
| 167 |
model = PySRRegressor(
|
| 168 |
progress=False,
|
|
@@ -172,7 +173,12 @@ class TestTorch(unittest.TestCase):
|
|
| 172 |
early_stop_condition=1e-5,
|
| 173 |
extra_sympy_mappings={"cos_approx": cos_approx},
|
| 174 |
extra_torch_mappings={"cos_approx": cos_approx},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
|
|
|
| 176 |
model.fit(X.values, y.values)
|
| 177 |
torch_module = model.pytorch()
|
| 178 |
|
|
@@ -180,4 +186,5 @@ class TestTorch(unittest.TestCase):
|
|
| 180 |
|
| 181 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
| 182 |
|
| 183 |
-
np.testing.assert_almost_equal(
|
|
|
|
|
|
| 160 |
)
|
| 161 |
|
| 162 |
def test_feature_selection_custom_operators(self):
|
| 163 |
+
rstate = np.random.RandomState(0)
|
| 164 |
+
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
| 165 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
| 166 |
+
y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
|
| 167 |
|
| 168 |
model = PySRRegressor(
|
| 169 |
progress=False,
|
|
|
|
| 173 |
early_stop_condition=1e-5,
|
| 174 |
extra_sympy_mappings={"cos_approx": cos_approx},
|
| 175 |
extra_torch_mappings={"cos_approx": cos_approx},
|
| 176 |
+
random_state=0,
|
| 177 |
+
deterministic=True,
|
| 178 |
+
procs=0,
|
| 179 |
+
multithreading=False,
|
| 180 |
)
|
| 181 |
+
np.random.seed(0)
|
| 182 |
model.fit(X.values, y.values)
|
| 183 |
torch_module = model.pytorch()
|
| 184 |
|
|
|
|
| 186 |
|
| 187 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
| 188 |
|
| 189 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=4)
|
| 190 |
+
np.testing.assert_almost_equal(y.values, torch_output, decimal=4)
|