Spaces:
Sleeping
Sleeping
Merge pull request #427 from MilesCranmer/refactor-exports
Browse files- pysr/export_latex.py +7 -7
- pysr/export_numpy.py +6 -3
- pysr/export_sympy.py +72 -0
- pysr/sr.py +26 -84
- pysr/test/test.py +9 -10
pysr/export_latex.py
CHANGED
|
@@ -19,7 +19,7 @@ class PreciseLatexPrinter(LatexPrinter):
|
|
| 19 |
return super()._print_Float(reduced_float)
|
| 20 |
|
| 21 |
|
| 22 |
-
def
|
| 23 |
"""Convert sympy expression to LaTeX with custom precision."""
|
| 24 |
settings["full_prec"] = full_prec
|
| 25 |
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
|
@@ -56,7 +56,7 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
|
|
| 56 |
return top_latex_table, bottom_latex_table
|
| 57 |
|
| 58 |
|
| 59 |
-
def
|
| 60 |
equations: pd.DataFrame,
|
| 61 |
indices: List[int] = None,
|
| 62 |
precision: int = 3,
|
|
@@ -74,16 +74,16 @@ def generate_single_table(
|
|
| 74 |
indices = range(len(equations))
|
| 75 |
|
| 76 |
for i in indices:
|
| 77 |
-
latex_equation =
|
| 78 |
equations.iloc[i]["sympy_format"],
|
| 79 |
prec=precision,
|
| 80 |
)
|
| 81 |
complexity = str(equations.iloc[i]["complexity"])
|
| 82 |
-
loss =
|
| 83 |
sympy.Float(equations.iloc[i]["loss"]),
|
| 84 |
prec=precision,
|
| 85 |
)
|
| 86 |
-
score =
|
| 87 |
sympy.Float(equations.iloc[i]["score"]),
|
| 88 |
prec=precision,
|
| 89 |
)
|
|
@@ -124,7 +124,7 @@ def generate_single_table(
|
|
| 124 |
return "\n".join([latex_top, *latex_table_content, latex_bottom])
|
| 125 |
|
| 126 |
|
| 127 |
-
def
|
| 128 |
equations: List[pd.DataFrame],
|
| 129 |
indices: List[List[int]] = None,
|
| 130 |
precision: int = 3,
|
|
@@ -135,7 +135,7 @@ def generate_multiple_tables(
|
|
| 135 |
# TODO: Let user specify custom output variable
|
| 136 |
|
| 137 |
latex_tables = [
|
| 138 |
-
|
| 139 |
equations[i],
|
| 140 |
(None if not indices else indices[i]),
|
| 141 |
precision=precision,
|
|
|
|
| 19 |
return super()._print_Float(reduced_float)
|
| 20 |
|
| 21 |
|
| 22 |
+
def sympy2latex(expr, prec=3, full_prec=True, **settings):
|
| 23 |
"""Convert sympy expression to LaTeX with custom precision."""
|
| 24 |
settings["full_prec"] = full_prec
|
| 25 |
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
|
|
|
| 56 |
return top_latex_table, bottom_latex_table
|
| 57 |
|
| 58 |
|
| 59 |
+
def sympy2latextable(
|
| 60 |
equations: pd.DataFrame,
|
| 61 |
indices: List[int] = None,
|
| 62 |
precision: int = 3,
|
|
|
|
| 74 |
indices = range(len(equations))
|
| 75 |
|
| 76 |
for i in indices:
|
| 77 |
+
latex_equation = sympy2latex(
|
| 78 |
equations.iloc[i]["sympy_format"],
|
| 79 |
prec=precision,
|
| 80 |
)
|
| 81 |
complexity = str(equations.iloc[i]["complexity"])
|
| 82 |
+
loss = sympy2latex(
|
| 83 |
sympy.Float(equations.iloc[i]["loss"]),
|
| 84 |
prec=precision,
|
| 85 |
)
|
| 86 |
+
score = sympy2latex(
|
| 87 |
sympy.Float(equations.iloc[i]["score"]),
|
| 88 |
prec=precision,
|
| 89 |
)
|
|
|
|
| 124 |
return "\n".join([latex_top, *latex_table_content, latex_bottom])
|
| 125 |
|
| 126 |
|
| 127 |
+
def sympy2multilatextable(
|
| 128 |
equations: List[pd.DataFrame],
|
| 129 |
indices: List[List[int]] = None,
|
| 130 |
precision: int = 3,
|
|
|
|
| 135 |
# TODO: Let user specify custom output variable
|
| 136 |
|
| 137 |
latex_tables = [
|
| 138 |
+
sympy2latextable(
|
| 139 |
equations[i],
|
| 140 |
(None if not indices else indices[i]),
|
| 141 |
precision=precision,
|
pysr/export_numpy.py
CHANGED
|
@@ -6,14 +6,17 @@ import pandas as pd
|
|
| 6 |
from sympy import lambdify
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
class CallableEquation:
|
| 10 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
| 11 |
|
| 12 |
-
def __init__(self,
|
| 13 |
self._sympy = eqn
|
| 14 |
self._sympy_symbols = sympy_symbols
|
| 15 |
self._selection = selection
|
| 16 |
-
self._variable_names = variable_names
|
| 17 |
|
| 18 |
def __repr__(self):
|
| 19 |
return f"PySRFunction(X=>{self._sympy})"
|
|
@@ -23,7 +26,7 @@ class CallableEquation:
|
|
| 23 |
if isinstance(X, pd.DataFrame):
|
| 24 |
# Lambda function takes as argument:
|
| 25 |
return self._lambda(
|
| 26 |
-
**{k: X[k].values for k in self.
|
| 27 |
) * np.ones(expected_shape)
|
| 28 |
if self._selection is not None:
|
| 29 |
if X.shape[1] != len(self._selection):
|
|
|
|
| 6 |
from sympy import lambdify
|
| 7 |
|
| 8 |
|
| 9 |
+
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
|
| 10 |
+
return CallableEquation(eqn, sympy_symbols, selection=selection)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
class CallableEquation:
|
| 14 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
| 15 |
|
| 16 |
+
def __init__(self, eqn, sympy_symbols, selection=None):
|
| 17 |
self._sympy = eqn
|
| 18 |
self._sympy_symbols = sympy_symbols
|
| 19 |
self._selection = selection
|
|
|
|
| 20 |
|
| 21 |
def __repr__(self):
|
| 22 |
return f"PySRFunction(X=>{self._sympy})"
|
|
|
|
| 26 |
if isinstance(X, pd.DataFrame):
|
| 27 |
# Lambda function takes as argument:
|
| 28 |
return self._lambda(
|
| 29 |
+
**{k: X[k].values for k in map(str, self._sympy_symbols)}
|
| 30 |
) * np.ones(expected_shape)
|
| 31 |
if self._selection is not None:
|
| 32 |
if X.shape[1] != len(self._selection):
|
pysr/export_sympy.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Define utilities to export to sympy"""
|
| 2 |
+
from typing import Callable, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import sympy
|
| 5 |
+
from sympy import sympify
|
| 6 |
+
|
| 7 |
+
sympy_mappings = {
|
| 8 |
+
"div": lambda x, y: x / y,
|
| 9 |
+
"mult": lambda x, y: x * y,
|
| 10 |
+
"sqrt": lambda x: sympy.sqrt(x),
|
| 11 |
+
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
| 12 |
+
"square": lambda x: x**2,
|
| 13 |
+
"cube": lambda x: x**3,
|
| 14 |
+
"plus": lambda x, y: x + y,
|
| 15 |
+
"sub": lambda x, y: x - y,
|
| 16 |
+
"neg": lambda x: -x,
|
| 17 |
+
"pow": lambda x, y: x**y,
|
| 18 |
+
"pow_abs": lambda x, y: abs(x) ** y,
|
| 19 |
+
"cos": sympy.cos,
|
| 20 |
+
"sin": sympy.sin,
|
| 21 |
+
"tan": sympy.tan,
|
| 22 |
+
"cosh": sympy.cosh,
|
| 23 |
+
"sinh": sympy.sinh,
|
| 24 |
+
"tanh": sympy.tanh,
|
| 25 |
+
"exp": sympy.exp,
|
| 26 |
+
"acos": sympy.acos,
|
| 27 |
+
"asin": sympy.asin,
|
| 28 |
+
"atan": sympy.atan,
|
| 29 |
+
"acosh": lambda x: sympy.acosh(x),
|
| 30 |
+
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
|
| 31 |
+
"asinh": sympy.asinh,
|
| 32 |
+
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
| 33 |
+
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
| 34 |
+
"abs": abs,
|
| 35 |
+
"mod": sympy.Mod,
|
| 36 |
+
"erf": sympy.erf,
|
| 37 |
+
"erfc": sympy.erfc,
|
| 38 |
+
"log": lambda x: sympy.log(x),
|
| 39 |
+
"log10": lambda x: sympy.log(x, 10),
|
| 40 |
+
"log2": lambda x: sympy.log(x, 2),
|
| 41 |
+
"log1p": lambda x: sympy.log(x + 1),
|
| 42 |
+
"log_abs": lambda x: sympy.log(abs(x)),
|
| 43 |
+
"log10_abs": lambda x: sympy.log(abs(x), 10),
|
| 44 |
+
"log2_abs": lambda x: sympy.log(abs(x), 2),
|
| 45 |
+
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
|
| 46 |
+
"floor": sympy.floor,
|
| 47 |
+
"ceil": sympy.ceiling,
|
| 48 |
+
"sign": sympy.sign,
|
| 49 |
+
"gamma": sympy.gamma,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def create_sympy_symbols(
|
| 54 |
+
feature_names_in: Optional[List[str]] = None,
|
| 55 |
+
) -> List[sympy.Symbol]:
|
| 56 |
+
return [sympy.Symbol(variable) for variable in feature_names_in]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def pysr2sympy(
|
| 60 |
+
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
|
| 61 |
+
) -> sympy.Expr:
|
| 62 |
+
local_sympy_mappings = {
|
| 63 |
+
**(extra_sympy_mappings if extra_sympy_mappings else {}),
|
| 64 |
+
**sympy_mappings,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
return sympify(equation, locals=local_sympy_mappings)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def assert_valid_sympy_symbol(var_name: str) -> None:
|
| 71 |
+
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
|
| 72 |
+
raise ValueError(f"Variable name {var_name} is already a function name.")
|
pysr/sr.py
CHANGED
|
@@ -14,15 +14,16 @@ from pathlib import Path
|
|
| 14 |
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
| 17 |
-
import sympy
|
| 18 |
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
|
| 19 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
| 20 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
| 21 |
-
from sympy import sympify
|
| 22 |
|
| 23 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 24 |
-
from .
|
| 25 |
-
from .
|
|
|
|
|
|
|
|
|
|
| 26 |
from .julia_helpers import (
|
| 27 |
_escape_filename,
|
| 28 |
_load_backend,
|
|
@@ -37,51 +38,6 @@ Main = None # TODO: Rename to more descriptive name like "julia_runtime"
|
|
| 37 |
|
| 38 |
already_ran = False
|
| 39 |
|
| 40 |
-
sympy_mappings = {
|
| 41 |
-
"div": lambda x, y: x / y,
|
| 42 |
-
"mult": lambda x, y: x * y,
|
| 43 |
-
"sqrt": lambda x: sympy.sqrt(x),
|
| 44 |
-
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
| 45 |
-
"square": lambda x: x**2,
|
| 46 |
-
"cube": lambda x: x**3,
|
| 47 |
-
"plus": lambda x, y: x + y,
|
| 48 |
-
"sub": lambda x, y: x - y,
|
| 49 |
-
"neg": lambda x: -x,
|
| 50 |
-
"pow": lambda x, y: x**y,
|
| 51 |
-
"pow_abs": lambda x, y: abs(x) ** y,
|
| 52 |
-
"cos": sympy.cos,
|
| 53 |
-
"sin": sympy.sin,
|
| 54 |
-
"tan": sympy.tan,
|
| 55 |
-
"cosh": sympy.cosh,
|
| 56 |
-
"sinh": sympy.sinh,
|
| 57 |
-
"tanh": sympy.tanh,
|
| 58 |
-
"exp": sympy.exp,
|
| 59 |
-
"acos": sympy.acos,
|
| 60 |
-
"asin": sympy.asin,
|
| 61 |
-
"atan": sympy.atan,
|
| 62 |
-
"acosh": lambda x: sympy.acosh(x),
|
| 63 |
-
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
|
| 64 |
-
"asinh": sympy.asinh,
|
| 65 |
-
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
| 66 |
-
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
| 67 |
-
"abs": abs,
|
| 68 |
-
"mod": sympy.Mod,
|
| 69 |
-
"erf": sympy.erf,
|
| 70 |
-
"erfc": sympy.erfc,
|
| 71 |
-
"log": lambda x: sympy.log(x),
|
| 72 |
-
"log10": lambda x: sympy.log(x, 10),
|
| 73 |
-
"log2": lambda x: sympy.log(x, 2),
|
| 74 |
-
"log1p": lambda x: sympy.log(x + 1),
|
| 75 |
-
"log_abs": lambda x: sympy.log(abs(x)),
|
| 76 |
-
"log10_abs": lambda x: sympy.log(abs(x), 10),
|
| 77 |
-
"log2_abs": lambda x: sympy.log(abs(x), 2),
|
| 78 |
-
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
|
| 79 |
-
"floor": sympy.floor,
|
| 80 |
-
"ceil": sympy.ceiling,
|
| 81 |
-
"sign": sympy.sign,
|
| 82 |
-
"gamma": sympy.gamma,
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
|
| 86 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
| 87 |
warnings.warn(
|
|
@@ -188,10 +144,6 @@ def _check_assertions(
|
|
| 188 |
assert len(variable_names) == X.shape[1]
|
| 189 |
# Check none of the variable names are function names:
|
| 190 |
for var_name in variable_names:
|
| 191 |
-
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
|
| 192 |
-
raise ValueError(
|
| 193 |
-
f"Variable name {var_name} is already a function name."
|
| 194 |
-
)
|
| 195 |
# Check if alphanumeric only:
|
| 196 |
if not re.match(r"^[ββββββ
ββββa-zA-Z0-9_]+$", var_name):
|
| 197 |
raise ValueError(
|
|
@@ -199,6 +151,7 @@ def _check_assertions(
|
|
| 199 |
"Only alphanumeric characters, numbers, "
|
| 200 |
"and underscores are allowed."
|
| 201 |
)
|
|
|
|
| 202 |
if X_units is not None and len(X_units) != X.shape[1]:
|
| 203 |
raise ValueError(
|
| 204 |
"The number of units in `X_units` must equal the number of features in `X`."
|
|
@@ -2116,10 +2069,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2116 |
if self.nout_ > 1:
|
| 2117 |
output = []
|
| 2118 |
for s in sympy_representation:
|
| 2119 |
-
latex =
|
| 2120 |
output.append(latex)
|
| 2121 |
return output
|
| 2122 |
-
return
|
| 2123 |
|
| 2124 |
def jax(self, index=None):
|
| 2125 |
"""
|
|
@@ -2282,53 +2235,41 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2282 |
jax_format = []
|
| 2283 |
if self.output_torch_format:
|
| 2284 |
torch_format = []
|
| 2285 |
-
local_sympy_mappings = {
|
| 2286 |
-
**(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
|
| 2287 |
-
**sympy_mappings,
|
| 2288 |
-
}
|
| 2289 |
-
|
| 2290 |
-
sympy_symbols = [
|
| 2291 |
-
sympy.Symbol(variable) for variable in self.feature_names_in_
|
| 2292 |
-
]
|
| 2293 |
|
| 2294 |
for _, eqn_row in output.iterrows():
|
| 2295 |
-
eqn =
|
|
|
|
|
|
|
|
|
|
| 2296 |
sympy_format.append(eqn)
|
| 2297 |
|
| 2298 |
-
#
|
|
|
|
| 2299 |
lambda_format.append(
|
| 2300 |
-
|
| 2301 |
-
|
|
|
|
|
|
|
| 2302 |
)
|
| 2303 |
)
|
| 2304 |
|
| 2305 |
# JAX:
|
| 2306 |
if self.output_jax_format:
|
| 2307 |
-
from .export_jax import sympy2jax
|
| 2308 |
-
|
| 2309 |
func, params = sympy2jax(
|
| 2310 |
eqn,
|
| 2311 |
sympy_symbols,
|
| 2312 |
selection=self.selection_mask_,
|
| 2313 |
-
extra_jax_mappings=
|
| 2314 |
-
self.extra_jax_mappings if self.extra_jax_mappings else {}
|
| 2315 |
-
),
|
| 2316 |
)
|
| 2317 |
jax_format.append({"callable": func, "parameters": params})
|
| 2318 |
|
| 2319 |
# Torch:
|
| 2320 |
if self.output_torch_format:
|
| 2321 |
-
from .export_torch import sympy2torch
|
| 2322 |
-
|
| 2323 |
module = sympy2torch(
|
| 2324 |
eqn,
|
| 2325 |
sympy_symbols,
|
| 2326 |
selection=self.selection_mask_,
|
| 2327 |
-
extra_torch_mappings=
|
| 2328 |
-
self.extra_torch_mappings
|
| 2329 |
-
if self.extra_torch_mappings
|
| 2330 |
-
else {}
|
| 2331 |
-
),
|
| 2332 |
)
|
| 2333 |
torch_format.append(module)
|
| 2334 |
|
|
@@ -2410,17 +2351,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2410 |
assert isinstance(indices[0], list)
|
| 2411 |
assert len(indices) == self.nout_
|
| 2412 |
|
| 2413 |
-
|
|
|
|
|
|
|
| 2414 |
else:
|
| 2415 |
if indices is not None:
|
| 2416 |
assert isinstance(indices, list)
|
| 2417 |
assert isinstance(indices[0], int)
|
| 2418 |
|
| 2419 |
-
|
|
|
|
|
|
|
| 2420 |
|
| 2421 |
-
table_string = generator_fnc(
|
| 2422 |
-
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2423 |
-
)
|
| 2424 |
preamble_string = [
|
| 2425 |
r"\usepackage{breqn}",
|
| 2426 |
r"\usepackage{booktabs}",
|
|
|
|
| 14 |
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
|
|
|
| 17 |
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
|
| 18 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
| 19 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
|
|
|
| 20 |
|
| 21 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 22 |
+
from .export_jax import sympy2jax
|
| 23 |
+
from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
|
| 24 |
+
from .export_numpy import sympy2numpy
|
| 25 |
+
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
|
| 26 |
+
from .export_torch import sympy2torch
|
| 27 |
from .julia_helpers import (
|
| 28 |
_escape_filename,
|
| 29 |
_load_backend,
|
|
|
|
| 38 |
|
| 39 |
already_ran = False
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
| 43 |
warnings.warn(
|
|
|
|
| 144 |
assert len(variable_names) == X.shape[1]
|
| 145 |
# Check none of the variable names are function names:
|
| 146 |
for var_name in variable_names:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
# Check if alphanumeric only:
|
| 148 |
if not re.match(r"^[ββββββ
ββββa-zA-Z0-9_]+$", var_name):
|
| 149 |
raise ValueError(
|
|
|
|
| 151 |
"Only alphanumeric characters, numbers, "
|
| 152 |
"and underscores are allowed."
|
| 153 |
)
|
| 154 |
+
assert_valid_sympy_symbol(var_name)
|
| 155 |
if X_units is not None and len(X_units) != X.shape[1]:
|
| 156 |
raise ValueError(
|
| 157 |
"The number of units in `X_units` must equal the number of features in `X`."
|
|
|
|
| 2069 |
if self.nout_ > 1:
|
| 2070 |
output = []
|
| 2071 |
for s in sympy_representation:
|
| 2072 |
+
latex = sympy2latex(s, prec=precision)
|
| 2073 |
output.append(latex)
|
| 2074 |
return output
|
| 2075 |
+
return sympy2latex(sympy_representation, prec=precision)
|
| 2076 |
|
| 2077 |
def jax(self, index=None):
|
| 2078 |
"""
|
|
|
|
| 2235 |
jax_format = []
|
| 2236 |
if self.output_torch_format:
|
| 2237 |
torch_format = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2238 |
|
| 2239 |
for _, eqn_row in output.iterrows():
|
| 2240 |
+
eqn = pysr2sympy(
|
| 2241 |
+
eqn_row["equation"],
|
| 2242 |
+
extra_sympy_mappings=self.extra_sympy_mappings,
|
| 2243 |
+
)
|
| 2244 |
sympy_format.append(eqn)
|
| 2245 |
|
| 2246 |
+
# NumPy:
|
| 2247 |
+
sympy_symbols = create_sympy_symbols(self.feature_names_in_)
|
| 2248 |
lambda_format.append(
|
| 2249 |
+
sympy2numpy(
|
| 2250 |
+
eqn,
|
| 2251 |
+
sympy_symbols,
|
| 2252 |
+
selection=self.selection_mask_,
|
| 2253 |
)
|
| 2254 |
)
|
| 2255 |
|
| 2256 |
# JAX:
|
| 2257 |
if self.output_jax_format:
|
|
|
|
|
|
|
| 2258 |
func, params = sympy2jax(
|
| 2259 |
eqn,
|
| 2260 |
sympy_symbols,
|
| 2261 |
selection=self.selection_mask_,
|
| 2262 |
+
extra_jax_mappings=self.extra_jax_mappings,
|
|
|
|
|
|
|
| 2263 |
)
|
| 2264 |
jax_format.append({"callable": func, "parameters": params})
|
| 2265 |
|
| 2266 |
# Torch:
|
| 2267 |
if self.output_torch_format:
|
|
|
|
|
|
|
| 2268 |
module = sympy2torch(
|
| 2269 |
eqn,
|
| 2270 |
sympy_symbols,
|
| 2271 |
selection=self.selection_mask_,
|
| 2272 |
+
extra_torch_mappings=self.extra_torch_mappings,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2273 |
)
|
| 2274 |
torch_format.append(module)
|
| 2275 |
|
|
|
|
| 2351 |
assert isinstance(indices[0], list)
|
| 2352 |
assert len(indices) == self.nout_
|
| 2353 |
|
| 2354 |
+
table_string = sympy2multilatextable(
|
| 2355 |
+
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2356 |
+
)
|
| 2357 |
else:
|
| 2358 |
if indices is not None:
|
| 2359 |
assert isinstance(indices, list)
|
| 2360 |
assert isinstance(indices[0], int)
|
| 2361 |
|
| 2362 |
+
table_string = sympy2latextable(
|
| 2363 |
+
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2364 |
+
)
|
| 2365 |
|
|
|
|
|
|
|
|
|
|
| 2366 |
preamble_string = [
|
| 2367 |
r"\usepackage{breqn}",
|
| 2368 |
r"\usepackage{booktabs}",
|
pysr/test/test.py
CHANGED
|
@@ -10,11 +10,10 @@ from pathlib import Path
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
import sympy
|
| 13 |
-
from sklearn import model_selection
|
| 14 |
from sklearn.utils.estimator_checks import check_estimator
|
| 15 |
|
| 16 |
from .. import PySRRegressor, julia_helpers
|
| 17 |
-
from ..export_latex import
|
| 18 |
from ..sr import (
|
| 19 |
_check_assertions,
|
| 20 |
_csv_filename_to_pkl_filename,
|
|
@@ -884,23 +883,23 @@ class TestLaTeXTable(unittest.TestCase):
|
|
| 884 |
def test_latex_float_precision(self):
|
| 885 |
"""Test that we can print latex expressions with custom precision"""
|
| 886 |
expr = sympy.Float(4583.4485748, dps=50)
|
| 887 |
-
self.assertEqual(
|
| 888 |
-
self.assertEqual(
|
| 889 |
-
self.assertEqual(
|
| 890 |
-
self.assertEqual(
|
| 891 |
-
self.assertEqual(
|
| 892 |
|
| 893 |
# Multiple numbers:
|
| 894 |
x = sympy.Symbol("x")
|
| 895 |
expr = x * 3232.324857384 - 1.4857485e-10
|
| 896 |
self.assertEqual(
|
| 897 |
-
|
| 898 |
)
|
| 899 |
self.assertEqual(
|
| 900 |
-
|
| 901 |
)
|
| 902 |
self.assertEqual(
|
| 903 |
-
|
| 904 |
)
|
| 905 |
|
| 906 |
def test_latex_break_long_equation(self):
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
import sympy
|
|
|
|
| 13 |
from sklearn.utils.estimator_checks import check_estimator
|
| 14 |
|
| 15 |
from .. import PySRRegressor, julia_helpers
|
| 16 |
+
from ..export_latex import sympy2latex
|
| 17 |
from ..sr import (
|
| 18 |
_check_assertions,
|
| 19 |
_csv_filename_to_pkl_filename,
|
|
|
|
| 883 |
def test_latex_float_precision(self):
|
| 884 |
"""Test that we can print latex expressions with custom precision"""
|
| 885 |
expr = sympy.Float(4583.4485748, dps=50)
|
| 886 |
+
self.assertEqual(sympy2latex(expr, prec=6), r"4583.45")
|
| 887 |
+
self.assertEqual(sympy2latex(expr, prec=5), r"4583.4")
|
| 888 |
+
self.assertEqual(sympy2latex(expr, prec=4), r"4583.")
|
| 889 |
+
self.assertEqual(sympy2latex(expr, prec=3), r"4.58 \cdot 10^{3}")
|
| 890 |
+
self.assertEqual(sympy2latex(expr, prec=2), r"4.6 \cdot 10^{3}")
|
| 891 |
|
| 892 |
# Multiple numbers:
|
| 893 |
x = sympy.Symbol("x")
|
| 894 |
expr = x * 3232.324857384 - 1.4857485e-10
|
| 895 |
self.assertEqual(
|
| 896 |
+
sympy2latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
|
| 897 |
)
|
| 898 |
self.assertEqual(
|
| 899 |
+
sympy2latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
|
| 900 |
)
|
| 901 |
self.assertEqual(
|
| 902 |
+
sympy2latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
|
| 903 |
)
|
| 904 |
|
| 905 |
def test_latex_break_long_equation(self):
|