Spaces:
Running
Running
refactor: improved type inference in return values
Browse files- pysr/sr.py +24 -14
pysr/sr.py
CHANGED
|
@@ -2006,11 +2006,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2006 |
X = self._validate_data(X, reset=False)
|
| 2007 |
|
| 2008 |
try:
|
| 2009 |
-
if
|
|
|
|
| 2010 |
return np.stack(
|
| 2011 |
[eq["lambda_format"](X) for eq in best_equation], axis=1
|
| 2012 |
)
|
| 2013 |
-
|
|
|
|
| 2014 |
except Exception as error:
|
| 2015 |
raise ValueError(
|
| 2016 |
"Failed to evaluate the expression. "
|
|
@@ -2040,9 +2042,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2040 |
"""
|
| 2041 |
self.refresh()
|
| 2042 |
best_equation = self.get_best(index=index)
|
| 2043 |
-
if
|
|
|
|
| 2044 |
return [eq["sympy_format"] for eq in best_equation]
|
| 2045 |
-
|
|
|
|
| 2046 |
|
| 2047 |
def latex(self, index=None, precision=3):
|
| 2048 |
"""
|
|
@@ -2102,9 +2106,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2102 |
self.set_params(output_jax_format=True)
|
| 2103 |
self.refresh()
|
| 2104 |
best_equation = self.get_best(index=index)
|
| 2105 |
-
if
|
|
|
|
| 2106 |
return [eq["jax_format"] for eq in best_equation]
|
| 2107 |
-
|
|
|
|
| 2108 |
|
| 2109 |
def pytorch(self, index=None):
|
| 2110 |
"""
|
|
@@ -2132,9 +2138,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2132 |
self.set_params(output_torch_format=True)
|
| 2133 |
self.refresh()
|
| 2134 |
best_equation = self.get_best(index=index)
|
| 2135 |
-
if
|
|
|
|
|
|
|
| 2136 |
return [eq["torch_format"] for eq in best_equation]
|
| 2137 |
-
return best_equation["torch_format"]
|
| 2138 |
|
| 2139 |
def _read_equation_file(self):
|
| 2140 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
|
@@ -2233,10 +2240,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2233 |
lastComplexity = 0
|
| 2234 |
sympy_format = []
|
| 2235 |
lambda_format = []
|
| 2236 |
-
|
| 2237 |
-
|
| 2238 |
-
if self.output_torch_format:
|
| 2239 |
-
torch_format = []
|
| 2240 |
|
| 2241 |
for _, eqn_row in output.iterrows():
|
| 2242 |
eqn = pysr2sympy(
|
|
@@ -2348,7 +2353,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2348 |
"""
|
| 2349 |
self.refresh()
|
| 2350 |
|
| 2351 |
-
if self.
|
| 2352 |
if indices is not None:
|
| 2353 |
assert isinstance(indices, list)
|
| 2354 |
assert isinstance(indices[0], list)
|
|
@@ -2357,7 +2362,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2357 |
table_string = sympy2multilatextable(
|
| 2358 |
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2359 |
)
|
| 2360 |
-
|
| 2361 |
if indices is not None:
|
| 2362 |
assert isinstance(indices, list)
|
| 2363 |
assert isinstance(indices[0], int)
|
|
@@ -2365,6 +2370,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2365 |
table_string = sympy2latextable(
|
| 2366 |
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2367 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2368 |
|
| 2369 |
preamble_string = [
|
| 2370 |
r"\usepackage{breqn}",
|
|
|
|
| 2006 |
X = self._validate_data(X, reset=False)
|
| 2007 |
|
| 2008 |
try:
|
| 2009 |
+
if isinstance(best_equation, list):
|
| 2010 |
+
assert self.nout_ > 1
|
| 2011 |
return np.stack(
|
| 2012 |
[eq["lambda_format"](X) for eq in best_equation], axis=1
|
| 2013 |
)
|
| 2014 |
+
else:
|
| 2015 |
+
return best_equation["lambda_format"](X)
|
| 2016 |
except Exception as error:
|
| 2017 |
raise ValueError(
|
| 2018 |
"Failed to evaluate the expression. "
|
|
|
|
| 2042 |
"""
|
| 2043 |
self.refresh()
|
| 2044 |
best_equation = self.get_best(index=index)
|
| 2045 |
+
if isinstance(best_equation, list):
|
| 2046 |
+
assert self.nout_ > 1
|
| 2047 |
return [eq["sympy_format"] for eq in best_equation]
|
| 2048 |
+
else:
|
| 2049 |
+
return best_equation["sympy_format"]
|
| 2050 |
|
| 2051 |
def latex(self, index=None, precision=3):
|
| 2052 |
"""
|
|
|
|
| 2106 |
self.set_params(output_jax_format=True)
|
| 2107 |
self.refresh()
|
| 2108 |
best_equation = self.get_best(index=index)
|
| 2109 |
+
if isinstance(best_equation, list):
|
| 2110 |
+
assert self.nout_ > 1
|
| 2111 |
return [eq["jax_format"] for eq in best_equation]
|
| 2112 |
+
else:
|
| 2113 |
+
return best_equation["jax_format"]
|
| 2114 |
|
| 2115 |
def pytorch(self, index=None):
|
| 2116 |
"""
|
|
|
|
| 2138 |
self.set_params(output_torch_format=True)
|
| 2139 |
self.refresh()
|
| 2140 |
best_equation = self.get_best(index=index)
|
| 2141 |
+
if isinstance(best_equation, pd.Series):
|
| 2142 |
+
return best_equation["torch_format"]
|
| 2143 |
+
else:
|
| 2144 |
return [eq["torch_format"] for eq in best_equation]
|
|
|
|
| 2145 |
|
| 2146 |
def _read_equation_file(self):
|
| 2147 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
|
|
|
| 2240 |
lastComplexity = 0
|
| 2241 |
sympy_format = []
|
| 2242 |
lambda_format = []
|
| 2243 |
+
jax_format = []
|
| 2244 |
+
torch_format = []
|
|
|
|
|
|
|
| 2245 |
|
| 2246 |
for _, eqn_row in output.iterrows():
|
| 2247 |
eqn = pysr2sympy(
|
|
|
|
| 2353 |
"""
|
| 2354 |
self.refresh()
|
| 2355 |
|
| 2356 |
+
if isinstance(self.equations_, list):
|
| 2357 |
if indices is not None:
|
| 2358 |
assert isinstance(indices, list)
|
| 2359 |
assert isinstance(indices[0], list)
|
|
|
|
| 2362 |
table_string = sympy2multilatextable(
|
| 2363 |
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2364 |
)
|
| 2365 |
+
elif isinstance(self.equations_, pd.DataFrame):
|
| 2366 |
if indices is not None:
|
| 2367 |
assert isinstance(indices, list)
|
| 2368 |
assert isinstance(indices[0], int)
|
|
|
|
| 2370 |
table_string = sympy2latextable(
|
| 2371 |
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2372 |
)
|
| 2373 |
+
else:
|
| 2374 |
+
raise ValueError(
|
| 2375 |
+
"Invalid type for equations_ to pass to `latex_table`. "
|
| 2376 |
+
"Expected a DataFrame or a list of DataFrames."
|
| 2377 |
+
)
|
| 2378 |
|
| 2379 |
preamble_string = [
|
| 2380 |
r"\usepackage{breqn}",
|