Spaces:
Running
Running
Commit
·
ec8124e
1
Parent(s):
af14165
Get PySRRegressor working with multi-output
Browse files
- pysr/sr.py +74 -30
- test/test.py +5 -9
pysr/sr.py
CHANGED
|
@@ -665,27 +665,46 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 665 |
if self.equations is None:
|
| 666 |
return "PySRRegressor.equations = None"
|
| 667 |
|
|
|
|
|
|
|
| 668 |
equations = self.equations
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
chosen_row = -1
|
| 672 |
-
elif self.model_selection == "best":
|
| 673 |
-
chosen_row = equations["score"].idxmax()
|
| 674 |
else:
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 685 |
)
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
return output
|
| 690 |
|
| 691 |
def set_params(self, **params):
|
|
@@ -708,13 +727,19 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 708 |
|
| 709 |
def get_best(self):
|
| 710 |
if self.equations is None:
|
| 711 |
-
|
| 712 |
if self.model_selection == "accuracy":
|
|
|
|
|
|
|
| 713 |
return self.equations.iloc[-1]
|
| 714 |
elif self.model_selection == "best":
|
| 715 |
-
|
|
|
|
|
|
|
| 716 |
else:
|
| 717 |
-
raise NotImplementedError
|
|
|
|
|
|
|
| 718 |
|
| 719 |
def fit(self, X, y, weights=None, variable_names=None):
|
| 720 |
"""Search for equations to fit the dataset.
|
|
@@ -747,26 +772,40 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 747 |
|
| 748 |
def predict(self, X):
|
| 749 |
self.refresh()
|
| 750 |
-
|
| 751 |
-
|
|
|
|
|
|
|
| 752 |
|
| 753 |
def sympy(self):
|
| 754 |
self.refresh()
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
| 756 |
|
| 757 |
def latex(self):
|
| 758 |
self.refresh()
|
| 759 |
-
|
|
|
|
|
|
|
|
|
|
| 760 |
|
| 761 |
def jax(self):
|
| 762 |
self.set_params(output_jax_format=True)
|
| 763 |
self.refresh()
|
| 764 |
-
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
| 766 |
def pytorch(self):
|
| 767 |
self.set_params(output_torch_format=True)
|
| 768 |
self.refresh()
|
| 769 |
-
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
def _run(self, X, y, weights, variable_names):
|
| 772 |
global already_ran
|
|
@@ -846,11 +885,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 846 |
|
| 847 |
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
| 848 |
self.multioutput = False
|
| 849 |
-
nout = 1
|
| 850 |
y = y.reshape(-1)
|
| 851 |
elif len(y.shape) == 2:
|
| 852 |
self.multioutput = True
|
| 853 |
-
nout = y.shape[1]
|
| 854 |
else:
|
| 855 |
raise NotImplementedError("y shape not supported!")
|
| 856 |
|
|
@@ -1182,3 +1221,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1182 |
if self.multioutput:
|
| 1183 |
return ret_outputs
|
| 1184 |
return ret_outputs[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
if self.equations is None:
|
| 666 |
return "PySRRegressor.equations = None"
|
| 667 |
|
| 668 |
+
output = "PySRRegressor.equations = [\n"
|
| 669 |
+
|
| 670 |
equations = self.equations
|
| 671 |
+
if not isinstance(equations, list):
|
| 672 |
+
all_equations = [equations]
|
|
|
|
|
|
|
|
|
|
| 673 |
else:
|
| 674 |
+
all_equations = equations
|
| 675 |
+
|
| 676 |
+
for i, equations in enumerate(all_equations):
|
| 677 |
+
selected = ["" for _ in range(len(equations))]
|
| 678 |
+
if self.model_selection == "accuracy":
|
| 679 |
+
chosen_row = -1
|
| 680 |
+
elif self.model_selection == "best":
|
| 681 |
+
chosen_row = equations["score"].idxmax()
|
| 682 |
+
else:
|
| 683 |
+
raise NotImplementedError
|
| 684 |
+
selected[chosen_row] = ">>>>"
|
| 685 |
+
repr_equations = pd.DataFrame(
|
| 686 |
+
dict(
|
| 687 |
+
pick=selected,
|
| 688 |
+
score=equations["score"],
|
| 689 |
+
equation=equations["equation"],
|
| 690 |
+
loss=equations["loss"],
|
| 691 |
+
complexity=equations["complexity"],
|
| 692 |
+
)
|
| 693 |
)
|
| 694 |
+
|
| 695 |
+
if len(all_equations) > 1:
|
| 696 |
+
output += "[\n"
|
| 697 |
+
|
| 698 |
+
for line in repr_equations.__repr__().split("\n"):
|
| 699 |
+
output += "\t" + line + "\n"
|
| 700 |
+
|
| 701 |
+
if len(all_equations) > 1:
|
| 702 |
+
output += "]"
|
| 703 |
+
|
| 704 |
+
if i < len(all_equations) - 1:
|
| 705 |
+
output += ", "
|
| 706 |
+
|
| 707 |
+
output += "]"
|
| 708 |
return output
|
| 709 |
|
| 710 |
def set_params(self, **params):
|
|
|
|
| 727 |
|
| 728 |
def get_best(self):
    """Return the equation row(s) chosen by ``self.model_selection``.

    ``"accuracy"`` selects the last row of each equations table, while
    ``"best"`` selects the row whose ``score`` entry is the largest.

    Returns:
        A single row when ``self.equations`` is one table, or a list with
        one row per table when ``self.equations`` is a list of tables.

    Raises:
        ValueError: If ``self.equations`` is ``None``.
        NotImplementedError: If ``self.model_selection`` is neither
            ``"accuracy"`` nor ``"best"``.
    """
    if self.equations is None:
        raise ValueError("No equations have been generated yet.")

    if self.model_selection == "accuracy":

        def pick(table):
            return table.iloc[-1]

    elif self.model_selection == "best":

        def pick(table):
            # idxmax() returns an index *label*, so the lookup has to be
            # label-based (.loc). Positional .iloc only happens to agree
            # when the table carries a default 0..n-1 index.
            return table.loc[table["score"].idxmax()]

    else:
        raise NotImplementedError(
            f"{self.model_selection} is not a valid model selection strategy."
        )

    if isinstance(self.equations, list):
        return [pick(table) for table in self.equations]
    return pick(self.equations)
|
| 743 |
|
| 744 |
def fit(self, X, y, weights=None, variable_names=None):
|
| 745 |
"""Search for equations to fit the dataset.
|
|
|
|
| 772 |
|
| 773 |
def predict(self, X):
    """Evaluate the selected equation(s) on ``X``.

    Refreshes the model, fetches the selection from ``get_best()`` and
    calls each selected row's ``"lambda_format"`` callable with ``X``.
    When ``self.multioutput`` is set, the per-output results are stacked
    along axis 1; otherwise the single result is returned as-is.
    """
    self.refresh()
    chosen = self.get_best()
    if not self.multioutput:
        return chosen["lambda_format"](X)
    per_output = [row["lambda_format"](X) for row in chosen]
    return np.stack(per_output, axis=1)
|
| 779 |
|
| 780 |
def sympy(self):
    """Return the ``"sympy_format"`` entry of the selected equation(s).

    Refreshes the model first. Yields a list (one entry per output) when
    ``self.multioutput`` is set, otherwise a single entry.
    """
    self.refresh()
    chosen = self.get_best()
    if not self.multioutput:
        return chosen["sympy_format"]
    return [row["sympy_format"] for row in chosen]
|
| 786 |
|
| 787 |
def latex(self):
    """Return the LaTeX rendering of the selected equation(s).

    The expression(s) come from ``self.sympy()`` and are converted with
    ``sympy.latex``. Yields a list (one string per output) when
    ``self.multioutput`` is set, otherwise a single result.
    """
    # self.sympy() refreshes the model itself before reading the selected
    # equation(s), so no separate refresh() call is needed here.
    sympy_representation = self.sympy()
    if self.multioutput:
        return [sympy.latex(expr) for expr in sympy_representation]
    return sympy.latex(sympy_representation)
|
| 793 |
|
| 794 |
def jax(self):
    """Return the ``"jax_format"`` entry of the selected equation(s).

    Turns on ``output_jax_format`` via ``set_params`` and refreshes the
    model before reading the selection. Yields a list (one entry per
    output) when ``self.multioutput`` is set, otherwise a single entry.
    """
    self.set_params(output_jax_format=True)
    self.refresh()
    chosen = self.get_best()
    if not self.multioutput:
        return chosen["jax_format"]
    return [row["jax_format"] for row in chosen]
|
| 801 |
+
|
| 802 |
def pytorch(self):
    """Return the ``"torch_format"`` entry of the selected equation(s).

    Turns on ``output_torch_format`` via ``set_params`` and refreshes the
    model before reading the selection. Yields a list (one entry per
    output) when ``self.multioutput`` is set, otherwise a single entry.
    """
    self.set_params(output_torch_format=True)
    self.refresh()
    chosen = self.get_best()
    if not self.multioutput:
        return chosen["torch_format"]
    return [row["torch_format"] for row in chosen]
|
| 809 |
|
| 810 |
def _run(self, X, y, weights, variable_names):
|
| 811 |
global already_ran
|
|
|
|
| 885 |
|
| 886 |
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
| 887 |
self.multioutput = False
|
| 888 |
+
self.nout = 1
|
| 889 |
y = y.reshape(-1)
|
| 890 |
elif len(y.shape) == 2:
|
| 891 |
self.multioutput = True
|
| 892 |
+
self.nout = y.shape[1]
|
| 893 |
else:
|
| 894 |
raise NotImplementedError("y shape not supported!")
|
| 895 |
|
|
|
|
| 1221 |
if self.multioutput:
|
| 1222 |
return ret_outputs
|
| 1223 |
return ret_outputs[0]
|
| 1224 |
+
|
| 1225 |
+
def score(self, X, y):
    """Scoring is not implemented; always raises ``NotImplementedError``.

    Both arguments are discarded without being inspected.
    """
    del X, y
    raise NotImplementedError
|
test/test.py
CHANGED
|
@@ -171,13 +171,13 @@ class TestBest(unittest.TestCase):
|
|
| 171 |
def setUp(self):
|
| 172 |
equations = pd.DataFrame(
|
| 173 |
{
|
| 174 |
-
"
|
| 175 |
-
"
|
| 176 |
-
"
|
| 177 |
}
|
| 178 |
)
|
| 179 |
|
| 180 |
-
equations["
|
| 181 |
"equation_file.csv.bkup", sep="|"
|
| 182 |
)
|
| 183 |
|
|
@@ -195,19 +195,15 @@ class TestBest(unittest.TestCase):
|
|
| 195 |
self.model.equations = self.equations
|
| 196 |
|
| 197 |
def test_best(self):
|
| 198 |
-
self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
|
| 199 |
-
self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
| 200 |
self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
| 201 |
|
| 202 |
def test_best_tex(self):
|
| 203 |
-
self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
|
| 204 |
-
self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
| 205 |
self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
| 206 |
|
| 207 |
def test_best_lambda(self):
|
| 208 |
X = np.random.randn(10, 2)
|
| 209 |
y = np.cos(X[:, 0]) ** 2
|
| 210 |
-
for f in [
|
| 211 |
np.testing.assert_almost_equal(f(X), y, decimal=4)
|
| 212 |
|
| 213 |
|
|
|
|
| 171 |
def setUp(self):
|
| 172 |
equations = pd.DataFrame(
|
| 173 |
{
|
| 174 |
+
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
| 175 |
+
"loss": [1.0, 0.1, 1e-5],
|
| 176 |
+
"complexity": [1, 2, 3],
|
| 177 |
}
|
| 178 |
)
|
| 179 |
|
| 180 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
| 181 |
"equation_file.csv.bkup", sep="|"
|
| 182 |
)
|
| 183 |
|
|
|
|
| 195 |
self.model.equations = self.equations
|
| 196 |
|
| 197 |
def test_best(self):
    """The model's sympy expression should equal cos(x0) squared."""
    x0 = sympy.Symbol("x0")
    expected = sympy.cos(x0) ** 2
    self.assertEqual(self.model.sympy(), expected)
|
| 199 |
|
| 200 |
def test_best_tex(self):
    """The model's LaTeX output should match the expected cos-squared string."""
    expected = "\\cos^{2}{\\left(x_{0} \\right)}"
    self.assertEqual(self.model.latex(), expected)
|
| 202 |
|
| 203 |
def test_best_lambda(self):
    """Both prediction paths should reproduce cos(x0)**2 to 4 decimals.

    Checks ``self.model.predict`` and the ``lambda_format`` callable stored
    in the last row of ``self.equations`` against random inputs.
    """
    inputs = np.random.randn(10, 2)
    expected = np.cos(inputs[:, 0]) ** 2
    candidates = (
        self.model.predict,
        self.equations.iloc[-1]["lambda_format"],
    )
    for predictor in candidates:
        np.testing.assert_almost_equal(predictor(inputs), expected, decimal=4)
|
| 208 |
|
| 209 |
|