Spaces:
Running
Running
Commit
·
6a4fa2c
1
Parent(s):
97e6589
Fix issue with lambda getting redefined; add test
Browse files- pysr/sr.py +15 -2
- test/test.py +28 -4
pysr/sr.py
CHANGED
|
@@ -61,6 +61,19 @@ sympy_mappings = {
|
|
| 61 |
'gamma': lambda x : sympy.gamma(x),
|
| 62 |
}
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def pysr(X, y, weights=None,
|
| 65 |
binary_operators=None,
|
| 66 |
unary_operators=None,
|
|
@@ -774,8 +787,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 774 |
if output_jax_format:
|
| 775 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 776 |
jax_format.append({'callable': func, 'parameters': params})
|
| 777 |
-
|
| 778 |
-
lambda_format.append(
|
| 779 |
curMSE = output.loc[i, 'MSE']
|
| 780 |
curComplexity = output.loc[i, 'Complexity']
|
| 781 |
|
|
|
|
| 61 |
'gamma': lambda x : sympy.gamma(x),
|
| 62 |
}
|
| 63 |
|
| 64 |
+
class CallableEquation(object):
|
| 65 |
+
"""Simple wrapper for numpy lambda functions built with sympy"""
|
| 66 |
+
def __init__(self, sympy_symbols, eqn):
|
| 67 |
+
self._sympy = eqn
|
| 68 |
+
self._sympy_symbols = sympy_symbols
|
| 69 |
+
self._lambda = lambdify(sympy_symbols, eqn)
|
| 70 |
+
|
| 71 |
+
def __repr__(self):
|
| 72 |
+
return f"PySRFunction(X=>{self._sympy})"
|
| 73 |
+
|
| 74 |
+
def __call__(self, X):
|
| 75 |
+
return self._lambda(*X.T)
|
| 76 |
+
|
| 77 |
def pysr(X, y, weights=None,
|
| 78 |
binary_operators=None,
|
| 79 |
unary_operators=None,
|
|
|
|
| 787 |
if output_jax_format:
|
| 788 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 789 |
jax_format.append({'callable': func, 'parameters': params})
|
| 790 |
+
|
| 791 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
| 792 |
curMSE = output.loc[i, 'MSE']
|
| 793 |
curComplexity = output.loc[i, 'Complexity']
|
| 794 |
|
test/test.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import unittest
|
| 2 |
import numpy as np
|
| 3 |
-
from pysr import pysr, get_hof, best, best_tex, best_callable
|
| 4 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
| 5 |
import sympy
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
|
| 8 |
class TestPipeline(unittest.TestCase):
|
|
@@ -27,12 +28,36 @@ class TestPipeline(unittest.TestCase):
|
|
| 27 |
y = self.X[:, [0, 1]]**2
|
| 28 |
equations = pysr(self.X, y,
|
| 29 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 30 |
-
extra_sympy_mappings={'
|
| 31 |
-
**self.default_test_kwargs
|
|
|
|
| 32 |
print(equations)
|
| 33 |
self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
|
| 34 |
self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def test_empty_operators_single_input(self):
|
| 37 |
X = np.random.randn(100, 1)
|
| 38 |
y = X[:, 0] + 3.0
|
|
@@ -40,7 +65,6 @@ class TestPipeline(unittest.TestCase):
|
|
| 40 |
unary_operators=[], binary_operators=["plus"],
|
| 41 |
**self.default_test_kwargs)
|
| 42 |
|
| 43 |
-
print(equations)
|
| 44 |
self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
|
| 45 |
|
| 46 |
class TestBest(unittest.TestCase):
|
|
|
|
| 1 |
import unittest
|
| 2 |
import numpy as np
|
| 3 |
+
from pysr import pysr, get_hof, best, best_tex, best_callable, best_row
|
| 4 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
| 5 |
import sympy
|
| 6 |
+
from sympy import lambdify
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
class TestPipeline(unittest.TestCase):
|
|
|
|
| 28 |
y = self.X[:, [0, 1]]**2
|
| 29 |
equations = pysr(self.X, y,
|
| 30 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 31 |
+
extra_sympy_mappings={'sq': lambda x: x**2},
|
| 32 |
+
**self.default_test_kwargs,
|
| 33 |
+
procs=0)
|
| 34 |
print(equations)
|
| 35 |
self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
|
| 36 |
self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
|
| 37 |
|
| 38 |
+
def test_multioutput_weighted_with_callable(self):
|
| 39 |
+
y = self.X[:, [0, 1]]**2
|
| 40 |
+
w = np.random.rand(*y.shape)
|
| 41 |
+
w[w < 0.5] = 0.0
|
| 42 |
+
w[w >= 0.5] = 1.0
|
| 43 |
+
|
| 44 |
+
# Double equation when weights are 0:
|
| 45 |
+
y += (1-w) * y
|
| 46 |
+
# Thus, pysr needs to use the weights to find the right equation!
|
| 47 |
+
|
| 48 |
+
equations = pysr(self.X, y, weights=w,
|
| 49 |
+
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 50 |
+
extra_sympy_mappings={'sq': lambda x: x**2},
|
| 51 |
+
**self.default_test_kwargs,
|
| 52 |
+
procs=0)
|
| 53 |
+
|
| 54 |
+
np.testing.assert_almost_equal(
|
| 55 |
+
best_callable()[0](self.X),
|
| 56 |
+
self.X[:, 0]**2)
|
| 57 |
+
np.testing.assert_almost_equal(
|
| 58 |
+
best_callable()[1](self.X),
|
| 59 |
+
self.X[:, 1]**2)
|
| 60 |
+
|
| 61 |
def test_empty_operators_single_input(self):
|
| 62 |
X = np.random.randn(100, 1)
|
| 63 |
y = X[:, 0] + 3.0
|
|
|
|
| 65 |
unary_operators=[], binary_operators=["plus"],
|
| 66 |
**self.default_test_kwargs)
|
| 67 |
|
|
|
|
| 68 |
self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
|
| 69 |
|
| 70 |
class TestBest(unittest.TestCase):
|