Spaces:
Running
Running
Commit
·
e29a6da
1
Parent(s):
d170500
Ensure that variable names are not sympy functions
Browse files- pysr/sr.py +6 -0
- test/test.py +7 -0
pysr/sr.py
CHANGED
|
@@ -169,6 +169,12 @@ def _check_assertions(
|
|
| 169 |
assert X.shape[0] == weights.shape[0]
|
| 170 |
if use_custom_variable_names:
|
| 171 |
assert len(variable_names) == X.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
def best(*args, **kwargs): # pragma: no cover
|
|
|
|
| 169 |
assert X.shape[0] == weights.shape[0]
|
| 170 |
if use_custom_variable_names:
|
| 171 |
assert len(variable_names) == X.shape[1]
|
| 172 |
+
# Check none of the variable names are function names:
|
| 173 |
+
for var_name in variable_names:
|
| 174 |
+
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
|
| 175 |
+
raise ValueError(
|
| 176 |
+
f"Variable name {var_name} is already a function name."
|
| 177 |
+
)
|
| 178 |
|
| 179 |
|
| 180 |
def best(*args, **kwargs): # pragma: no cover
|
test/test.py
CHANGED
|
@@ -546,6 +546,13 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 546 |
with self.assertRaises(ValueError):
|
| 547 |
model.fit(X, y)
|
| 548 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
def test_pickle_with_temp_equation_file(self):
|
| 550 |
"""If we have a temporary equation file, unpickle the estimator."""
|
| 551 |
model = PySRRegressor(
|
|
|
|
| 546 |
with self.assertRaises(ValueError):
|
| 547 |
model.fit(X, y)
|
| 548 |
|
| 549 |
+
def test_sympy_function_fails_as_variable(self):
|
| 550 |
+
model = PySRRegressor()
|
| 551 |
+
X = np.random.randn(100, 2)
|
| 552 |
+
y = np.random.randn(100)
|
| 553 |
+
with self.assertRaises(ValueError):
|
| 554 |
+
model.fit(X, y, variable_names=["x1", "N"])
|
| 555 |
+
|
| 556 |
def test_pickle_with_temp_equation_file(self):
|
| 557 |
"""If we have a temporary equation file, unpickle the estimator."""
|
| 558 |
model = PySRRegressor(
|