Spaces:
Running
Running
Merge pull request #156 from MilesCranmer/latex-table
Browse files- pysr/export_latex.py +153 -0
- pysr/sr.py +65 -3
- test/test.py +234 -19
pysr/export_latex.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Functions to help export PySR equations to LaTeX."""
|
| 2 |
+
import sympy
|
| 3 |
+
from sympy.printing.latex import LatexPrinter
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import List
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PreciseLatexPrinter(LatexPrinter):
|
| 10 |
+
"""Modified SymPy printer with custom float precision."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, settings=None, prec=3):
|
| 13 |
+
super().__init__(settings)
|
| 14 |
+
self.prec = prec
|
| 15 |
+
|
| 16 |
+
def _print_Float(self, expr):
|
| 17 |
+
# Reduce precision of float:
|
| 18 |
+
reduced_float = sympy.Float(expr, self.prec)
|
| 19 |
+
return super()._print_Float(reduced_float)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def to_latex(expr, prec=3, full_prec=True, **settings):
|
| 23 |
+
"""Convert sympy expression to LaTeX with custom precision."""
|
| 24 |
+
settings["full_prec"] = full_prec
|
| 25 |
+
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
| 26 |
+
return printer.doprint(expr)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def generate_table_environment(columns=["equation", "complexity", "loss"]):
|
| 30 |
+
margins = "c" * len(columns)
|
| 31 |
+
column_map = {
|
| 32 |
+
"complexity": "Complexity",
|
| 33 |
+
"loss": "Loss",
|
| 34 |
+
"equation": "Equation",
|
| 35 |
+
"score": "Score",
|
| 36 |
+
}
|
| 37 |
+
columns = [column_map[col] for col in columns]
|
| 38 |
+
top_pieces = [
|
| 39 |
+
r"\begin{table}[h]",
|
| 40 |
+
r"\begin{center}",
|
| 41 |
+
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
| 42 |
+
r"\toprule",
|
| 43 |
+
" & ".join(columns) + r" \\",
|
| 44 |
+
r"\midrule",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
bottom_pieces = [
|
| 48 |
+
r"\bottomrule",
|
| 49 |
+
r"\end{tabular}",
|
| 50 |
+
r"\end{center}",
|
| 51 |
+
r"\end{table}",
|
| 52 |
+
]
|
| 53 |
+
top_latex_table = "\n".join(top_pieces)
|
| 54 |
+
bottom_latex_table = "\n".join(bottom_pieces)
|
| 55 |
+
|
| 56 |
+
return top_latex_table, bottom_latex_table
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def generate_single_table(
|
| 60 |
+
equations: pd.DataFrame,
|
| 61 |
+
indices: List[int] = None,
|
| 62 |
+
precision: int = 3,
|
| 63 |
+
columns=["equation", "complexity", "loss", "score"],
|
| 64 |
+
max_equation_length: int = 50,
|
| 65 |
+
output_variable_name: str = "y",
|
| 66 |
+
):
|
| 67 |
+
"""Generate a booktabs-style LaTeX table for a single set of equations."""
|
| 68 |
+
assert isinstance(equations, pd.DataFrame)
|
| 69 |
+
|
| 70 |
+
latex_top, latex_bottom = generate_table_environment(columns)
|
| 71 |
+
latex_table_content = []
|
| 72 |
+
|
| 73 |
+
if indices is None:
|
| 74 |
+
indices = range(len(equations))
|
| 75 |
+
|
| 76 |
+
for i in indices:
|
| 77 |
+
latex_equation = to_latex(
|
| 78 |
+
equations.iloc[i]["sympy_format"],
|
| 79 |
+
prec=precision,
|
| 80 |
+
)
|
| 81 |
+
complexity = str(equations.iloc[i]["complexity"])
|
| 82 |
+
loss = to_latex(
|
| 83 |
+
sympy.Float(equations.iloc[i]["loss"]),
|
| 84 |
+
prec=precision,
|
| 85 |
+
)
|
| 86 |
+
score = to_latex(
|
| 87 |
+
sympy.Float(equations.iloc[i]["score"]),
|
| 88 |
+
prec=precision,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
row_pieces = []
|
| 92 |
+
for col in columns:
|
| 93 |
+
if col == "equation":
|
| 94 |
+
if len(latex_equation) < max_equation_length:
|
| 95 |
+
row_pieces.append(
|
| 96 |
+
"$" + output_variable_name + " = " + latex_equation + "$"
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
|
| 100 |
+
broken_latex_equation = " ".join(
|
| 101 |
+
[
|
| 102 |
+
r"\begin{minipage}{0.8\linewidth}",
|
| 103 |
+
r"\vspace{-1em}",
|
| 104 |
+
r"\begin{dmath*}",
|
| 105 |
+
output_variable_name + " = " + latex_equation,
|
| 106 |
+
r"\end{dmath*}",
|
| 107 |
+
r"\end{minipage}",
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
row_pieces.append(broken_latex_equation)
|
| 111 |
+
|
| 112 |
+
elif col == "complexity":
|
| 113 |
+
row_pieces.append("$" + complexity + "$")
|
| 114 |
+
elif col == "loss":
|
| 115 |
+
row_pieces.append("$" + loss + "$")
|
| 116 |
+
elif col == "score":
|
| 117 |
+
row_pieces.append("$" + score + "$")
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError(f"Unknown column: {col}")
|
| 120 |
+
|
| 121 |
+
latex_table_content.append(
|
| 122 |
+
" & ".join(row_pieces) + r" \\",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
return "\n".join([latex_top, *latex_table_content, latex_bottom])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def generate_multiple_tables(
|
| 129 |
+
equations: List[pd.DataFrame],
|
| 130 |
+
indices: List[List[int]] = None,
|
| 131 |
+
precision: int = 3,
|
| 132 |
+
columns=["equation", "complexity", "loss", "score"],
|
| 133 |
+
output_variable_names: str = None,
|
| 134 |
+
):
|
| 135 |
+
"""Generate multiple latex tables for a list of equation sets."""
|
| 136 |
+
# TODO: Let user specify custom output variable
|
| 137 |
+
|
| 138 |
+
latex_tables = [
|
| 139 |
+
generate_single_table(
|
| 140 |
+
equations[i],
|
| 141 |
+
(None if not indices else indices[i]),
|
| 142 |
+
precision=precision,
|
| 143 |
+
columns=columns,
|
| 144 |
+
output_variable_name=(
|
| 145 |
+
"y_{" + str(i) + "}"
|
| 146 |
+
if output_variable_names is None
|
| 147 |
+
else output_variable_names[i]
|
| 148 |
+
),
|
| 149 |
+
)
|
| 150 |
+
for i in range(len(equations))
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
return "\n\n".join(latex_tables)
|
pysr/sr.py
CHANGED
|
@@ -29,6 +29,7 @@ from .julia_helpers import (
|
|
| 29 |
import_error_string,
|
| 30 |
)
|
| 31 |
from .export_numpy import CallableEquation
|
|
|
|
| 32 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 33 |
|
| 34 |
|
|
@@ -1875,7 +1876,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1875 |
return [eq["sympy_format"] for eq in best_equation]
|
| 1876 |
return best_equation["sympy_format"]
|
| 1877 |
|
| 1878 |
-
def latex(self, index=None):
|
| 1879 |
"""
|
| 1880 |
Return latex representation of the equation(s) chosen by `model_selection`.
|
| 1881 |
|
|
@@ -1887,6 +1888,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1887 |
the `model_selection` parameter. If there are multiple output
|
| 1888 |
features, then pass a list of indices with the order the same
|
| 1889 |
as the output feature.
|
|
|
|
|
|
|
|
|
|
| 1890 |
|
| 1891 |
Returns
|
| 1892 |
-------
|
|
@@ -1896,8 +1900,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1896 |
self.refresh()
|
| 1897 |
sympy_representation = self.sympy(index=index)
|
| 1898 |
if self.nout_ > 1:
|
| 1899 |
-
|
| 1900 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1901 |
|
| 1902 |
def jax(self, index=None):
|
| 1903 |
"""
|
|
@@ -2147,6 +2155,60 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2147 |
return ret_outputs
|
| 2148 |
return ret_outputs[0]
|
| 2149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2150 |
|
| 2151 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
|
| 2152 |
"""
|
|
|
|
| 29 |
import_error_string,
|
| 30 |
)
|
| 31 |
from .export_numpy import CallableEquation
|
| 32 |
+
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
|
| 33 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 34 |
|
| 35 |
|
|
|
|
| 1876 |
return [eq["sympy_format"] for eq in best_equation]
|
| 1877 |
return best_equation["sympy_format"]
|
| 1878 |
|
| 1879 |
+
def latex(self, index=None, precision=3):
|
| 1880 |
"""
|
| 1881 |
Return latex representation of the equation(s) chosen by `model_selection`.
|
| 1882 |
|
|
|
|
| 1888 |
the `model_selection` parameter. If there are multiple output
|
| 1889 |
features, then pass a list of indices with the order the same
|
| 1890 |
as the output feature.
|
| 1891 |
+
precision : int, default=3
|
| 1892 |
+
The number of significant figures shown in the LaTeX
|
| 1893 |
+
representation.
|
| 1894 |
|
| 1895 |
Returns
|
| 1896 |
-------
|
|
|
|
| 1900 |
self.refresh()
|
| 1901 |
sympy_representation = self.sympy(index=index)
|
| 1902 |
if self.nout_ > 1:
|
| 1903 |
+
output = []
|
| 1904 |
+
for s in sympy_representation:
|
| 1905 |
+
latex = to_latex(s, prec=precision)
|
| 1906 |
+
output.append(latex)
|
| 1907 |
+
return output
|
| 1908 |
+
return to_latex(sympy_representation, prec=precision)
|
| 1909 |
|
| 1910 |
def jax(self, index=None):
|
| 1911 |
"""
|
|
|
|
| 2155 |
return ret_outputs
|
| 2156 |
return ret_outputs[0]
|
| 2157 |
|
| 2158 |
+
def latex_table(
|
| 2159 |
+
self,
|
| 2160 |
+
indices=None,
|
| 2161 |
+
precision=3,
|
| 2162 |
+
columns=["equation", "complexity", "loss", "score"],
|
| 2163 |
+
):
|
| 2164 |
+
"""Create a LaTeX/booktabs table for all, or some, of the equations.
|
| 2165 |
+
|
| 2166 |
+
Parameters
|
| 2167 |
+
----------
|
| 2168 |
+
indices : list[int] | list[list[int]], default=None
|
| 2169 |
+
If you wish to select a particular subset of equations from
|
| 2170 |
+
`self.equations_`, give the row numbers here. By default,
|
| 2171 |
+
all equations will be used. If there are multiple output
|
| 2172 |
+
features, then pass a list of lists.
|
| 2173 |
+
precision : int, default=3
|
| 2174 |
+
The number of significant figures shown in the LaTeX
|
| 2175 |
+
representations.
|
| 2176 |
+
columns : list[str], default=["equation", "complexity", "loss", "score"]
|
| 2177 |
+
Which columns to include in the table.
|
| 2178 |
+
|
| 2179 |
+
Returns
|
| 2180 |
+
-------
|
| 2181 |
+
latex_table_str : str
|
| 2182 |
+
A string that will render a table in LaTeX of the equations.
|
| 2183 |
+
"""
|
| 2184 |
+
self.refresh()
|
| 2185 |
+
|
| 2186 |
+
if self.nout_ > 1:
|
| 2187 |
+
if indices is not None:
|
| 2188 |
+
assert isinstance(indices, list)
|
| 2189 |
+
assert isinstance(indices[0], list)
|
| 2190 |
+
assert isinstance(len(indices), self.nout_)
|
| 2191 |
+
|
| 2192 |
+
generator_fnc = generate_multiple_tables
|
| 2193 |
+
else:
|
| 2194 |
+
if indices is not None:
|
| 2195 |
+
assert isinstance(indices, list)
|
| 2196 |
+
assert isinstance(indices[0], int)
|
| 2197 |
+
|
| 2198 |
+
generator_fnc = generate_single_table
|
| 2199 |
+
|
| 2200 |
+
table_string = generator_fnc(
|
| 2201 |
+
self.equations_, indices=indices, precision=precision, columns=columns
|
| 2202 |
+
)
|
| 2203 |
+
preamble_string = [
|
| 2204 |
+
r"\usepackage{breqn}",
|
| 2205 |
+
r"\usepackage{booktabs}",
|
| 2206 |
+
"",
|
| 2207 |
+
"...",
|
| 2208 |
+
"",
|
| 2209 |
+
]
|
| 2210 |
+
return "\n".join(preamble_string + [table_string])
|
| 2211 |
+
|
| 2212 |
|
| 2213 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
|
| 2214 |
"""
|
test/test.py
CHANGED
|
@@ -11,6 +11,7 @@ from pysr.sr import (
|
|
| 11 |
_csv_filename_to_pkl_filename,
|
| 12 |
idx_model_selection,
|
| 13 |
)
|
|
|
|
| 14 |
from sklearn.utils.estimator_checks import check_estimator
|
| 15 |
import sympy
|
| 16 |
import pandas as pd
|
|
@@ -353,19 +354,49 @@ class TestPipeline(unittest.TestCase):
|
|
| 353 |
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
|
| 354 |
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
class TestBest(unittest.TestCase):
|
| 357 |
def setUp(self):
|
| 358 |
self.rstate = np.random.RandomState(0)
|
| 359 |
self.X = self.rstate.randn(10, 2)
|
| 360 |
self.y = np.cos(self.X[:, 0]) ** 2
|
| 361 |
-
self.model = PySRRegressor(
|
| 362 |
-
progress=False,
|
| 363 |
-
niterations=1,
|
| 364 |
-
extra_sympy_mappings={},
|
| 365 |
-
output_jax_format=False,
|
| 366 |
-
model_selection="accuracy",
|
| 367 |
-
equation_file="equation_file.csv",
|
| 368 |
-
)
|
| 369 |
equations = pd.DataFrame(
|
| 370 |
{
|
| 371 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
|
@@ -373,17 +404,7 @@ class TestBest(unittest.TestCase):
|
|
| 373 |
"complexity": [1, 2, 3],
|
| 374 |
}
|
| 375 |
)
|
| 376 |
-
|
| 377 |
-
# Set up internal parameters as if it had been fitted:
|
| 378 |
-
self.model.equation_file_ = "equation_file.csv"
|
| 379 |
-
self.model.nout_ = 1
|
| 380 |
-
self.model.selection_mask_ = None
|
| 381 |
-
self.model.feature_names_in_ = np.array(["x0", "x1"], dtype=object)
|
| 382 |
-
equations["complexity loss equation".split(" ")].to_csv(
|
| 383 |
-
"equation_file.csv.bkup"
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
self.model.refresh()
|
| 387 |
self.equations_ = self.model.equations_
|
| 388 |
|
| 389 |
def test_best(self):
|
|
@@ -585,3 +606,197 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 585 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 586 |
# If any checks failed don't let the test pass.
|
| 587 |
self.assertEqual(len(exception_messages), 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
_csv_filename_to_pkl_filename,
|
| 12 |
idx_model_selection,
|
| 13 |
)
|
| 14 |
+
from pysr.export_latex import to_latex
|
| 15 |
from sklearn.utils.estimator_checks import check_estimator
|
| 16 |
import sympy
|
| 17 |
import pandas as pd
|
|
|
|
| 354 |
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
|
| 355 |
|
| 356 |
|
| 357 |
+
def manually_create_model(equations, feature_names=None):
|
| 358 |
+
if feature_names is None:
|
| 359 |
+
feature_names = ["x0", "x1"]
|
| 360 |
+
|
| 361 |
+
model = PySRRegressor(
|
| 362 |
+
progress=False,
|
| 363 |
+
niterations=1,
|
| 364 |
+
extra_sympy_mappings={},
|
| 365 |
+
output_jax_format=False,
|
| 366 |
+
model_selection="accuracy",
|
| 367 |
+
equation_file="equation_file.csv",
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Set up internal parameters as if it had been fitted:
|
| 371 |
+
if isinstance(equations, list):
|
| 372 |
+
# Multi-output.
|
| 373 |
+
model.equation_file_ = "equation_file.csv"
|
| 374 |
+
model.nout_ = len(equations)
|
| 375 |
+
model.selection_mask_ = None
|
| 376 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
| 377 |
+
for i in range(model.nout_):
|
| 378 |
+
equations[i]["complexity loss equation".split(" ")].to_csv(
|
| 379 |
+
f"equation_file.csv.out{i+1}.bkup"
|
| 380 |
+
)
|
| 381 |
+
else:
|
| 382 |
+
model.equation_file_ = "equation_file.csv"
|
| 383 |
+
model.nout_ = 1
|
| 384 |
+
model.selection_mask_ = None
|
| 385 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
| 386 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
| 387 |
+
"equation_file.csv.bkup"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
model.refresh()
|
| 391 |
+
|
| 392 |
+
return model
|
| 393 |
+
|
| 394 |
+
|
| 395 |
class TestBest(unittest.TestCase):
|
| 396 |
def setUp(self):
|
| 397 |
self.rstate = np.random.RandomState(0)
|
| 398 |
self.X = self.rstate.randn(10, 2)
|
| 399 |
self.y = np.cos(self.X[:, 0]) ** 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
equations = pd.DataFrame(
|
| 401 |
{
|
| 402 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
|
|
|
| 404 |
"complexity": [1, 2, 3],
|
| 405 |
}
|
| 406 |
)
|
| 407 |
+
self.model = manually_create_model(equations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
self.equations_ = self.model.equations_
|
| 409 |
|
| 410 |
def test_best(self):
|
|
|
|
| 606 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 607 |
# If any checks failed don't let the test pass.
|
| 608 |
self.assertEqual(len(exception_messages), 0)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
TRUE_PREAMBLE = "\n".join(
|
| 612 |
+
[
|
| 613 |
+
r"\usepackage{breqn}",
|
| 614 |
+
r"\usepackage{booktabs}",
|
| 615 |
+
"",
|
| 616 |
+
"...",
|
| 617 |
+
"",
|
| 618 |
+
]
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
class TestLaTeXTable(unittest.TestCase):
|
| 623 |
+
def setUp(self):
|
| 624 |
+
equations = pd.DataFrame(
|
| 625 |
+
dict(
|
| 626 |
+
equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
|
| 627 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
| 628 |
+
complexity=[1, 2, 8],
|
| 629 |
+
)
|
| 630 |
+
)
|
| 631 |
+
self.model = manually_create_model(equations)
|
| 632 |
+
self.maxDiff = None
|
| 633 |
+
|
| 634 |
+
def create_true_latex(self, middle_part, include_score=False):
|
| 635 |
+
if include_score:
|
| 636 |
+
true_latex_table_str = r"""
|
| 637 |
+
\begin{table}[h]
|
| 638 |
+
\begin{center}
|
| 639 |
+
\begin{tabular}{@{}cccc@{}}
|
| 640 |
+
\toprule
|
| 641 |
+
Equation & Complexity & Loss & Score \\
|
| 642 |
+
\midrule"""
|
| 643 |
+
else:
|
| 644 |
+
true_latex_table_str = r"""
|
| 645 |
+
\begin{table}[h]
|
| 646 |
+
\begin{center}
|
| 647 |
+
\begin{tabular}{@{}ccc@{}}
|
| 648 |
+
\toprule
|
| 649 |
+
Equation & Complexity & Loss \\
|
| 650 |
+
\midrule"""
|
| 651 |
+
true_latex_table_str += middle_part
|
| 652 |
+
true_latex_table_str += r"""\bottomrule
|
| 653 |
+
\end{tabular}
|
| 654 |
+
\end{center}
|
| 655 |
+
\end{table}
|
| 656 |
+
"""
|
| 657 |
+
# First, remove empty lines:
|
| 658 |
+
true_latex_table_str = "\n".join(
|
| 659 |
+
[line.strip() for line in true_latex_table_str.split("\n") if len(line) > 0]
|
| 660 |
+
)
|
| 661 |
+
return true_latex_table_str.strip()
|
| 662 |
+
|
| 663 |
+
def test_simple_table(self):
|
| 664 |
+
latex_table_str = self.model.latex_table(
|
| 665 |
+
columns=["equation", "complexity", "loss"]
|
| 666 |
+
)
|
| 667 |
+
middle_part = r"""
|
| 668 |
+
$y = x_{0}$ & $1$ & $1.05$ \\
|
| 669 |
+
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
|
| 670 |
+
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
| 671 |
+
"""
|
| 672 |
+
true_latex_table_str = (
|
| 673 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
| 674 |
+
)
|
| 675 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 676 |
+
|
| 677 |
+
def test_other_precision(self):
|
| 678 |
+
latex_table_str = self.model.latex_table(
|
| 679 |
+
precision=5, columns=["equation", "complexity", "loss"]
|
| 680 |
+
)
|
| 681 |
+
middle_part = r"""
|
| 682 |
+
$y = x_{0}$ & $1$ & $1.0520$ \\
|
| 683 |
+
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
|
| 684 |
+
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
|
| 685 |
+
"""
|
| 686 |
+
true_latex_table_str = (
|
| 687 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
| 688 |
+
)
|
| 689 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 690 |
+
|
| 691 |
+
def test_include_score(self):
|
| 692 |
+
latex_table_str = self.model.latex_table()
|
| 693 |
+
middle_part = r"""
|
| 694 |
+
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
| 695 |
+
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
| 696 |
+
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
|
| 697 |
+
"""
|
| 698 |
+
true_latex_table_str = (
|
| 699 |
+
TRUE_PREAMBLE
|
| 700 |
+
+ "\n"
|
| 701 |
+
+ self.create_true_latex(middle_part, include_score=True)
|
| 702 |
+
)
|
| 703 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 704 |
+
|
| 705 |
+
def test_last_equation(self):
|
| 706 |
+
latex_table_str = self.model.latex_table(
|
| 707 |
+
indices=[2], columns=["equation", "complexity", "loss"]
|
| 708 |
+
)
|
| 709 |
+
middle_part = r"""
|
| 710 |
+
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
| 711 |
+
"""
|
| 712 |
+
true_latex_table_str = (
|
| 713 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
| 714 |
+
)
|
| 715 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 716 |
+
|
| 717 |
+
def test_multi_output(self):
|
| 718 |
+
equations1 = pd.DataFrame(
|
| 719 |
+
dict(
|
| 720 |
+
equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
|
| 721 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
| 722 |
+
complexity=[1, 2, 8],
|
| 723 |
+
)
|
| 724 |
+
)
|
| 725 |
+
equations2 = pd.DataFrame(
|
| 726 |
+
dict(
|
| 727 |
+
equation=["x1", "cos(x1)", "x0 * x0 * x1"],
|
| 728 |
+
loss=[1.32, 0.052, 2e-15],
|
| 729 |
+
complexity=[1, 2, 5],
|
| 730 |
+
)
|
| 731 |
+
)
|
| 732 |
+
equations = [equations1, equations2]
|
| 733 |
+
model = manually_create_model(equations)
|
| 734 |
+
middle_part_1 = r"""
|
| 735 |
+
$y_{0} = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
| 736 |
+
$y_{0} = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
| 737 |
+
$y_{0} = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
|
| 738 |
+
"""
|
| 739 |
+
middle_part_2 = r"""
|
| 740 |
+
$y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
|
| 741 |
+
$y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
|
| 742 |
+
$y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
|
| 743 |
+
"""
|
| 744 |
+
true_latex_table_str = "\n\n".join(
|
| 745 |
+
self.create_true_latex(part, include_score=True)
|
| 746 |
+
for part in [middle_part_1, middle_part_2]
|
| 747 |
+
)
|
| 748 |
+
true_latex_table_str = TRUE_PREAMBLE + "\n" + true_latex_table_str
|
| 749 |
+
latex_table_str = model.latex_table()
|
| 750 |
+
|
| 751 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 752 |
+
|
| 753 |
+
def test_latex_float_precision(self):
|
| 754 |
+
"""Test that we can print latex expressions with custom precision"""
|
| 755 |
+
expr = sympy.Float(4583.4485748, dps=50)
|
| 756 |
+
self.assertEqual(to_latex(expr, prec=6), r"4583.45")
|
| 757 |
+
self.assertEqual(to_latex(expr, prec=5), r"4583.4")
|
| 758 |
+
self.assertEqual(to_latex(expr, prec=4), r"4583.")
|
| 759 |
+
self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
|
| 760 |
+
self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
|
| 761 |
+
|
| 762 |
+
# Multiple numbers:
|
| 763 |
+
x = sympy.Symbol("x")
|
| 764 |
+
expr = x * 3232.324857384 - 1.4857485e-10
|
| 765 |
+
self.assertEqual(
|
| 766 |
+
to_latex(expr, prec=2), "3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
|
| 767 |
+
)
|
| 768 |
+
self.assertEqual(
|
| 769 |
+
to_latex(expr, prec=3), "3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
|
| 770 |
+
)
|
| 771 |
+
self.assertEqual(
|
| 772 |
+
to_latex(expr, prec=8), "3232.3249 x - 1.4857485 \cdot 10^{-10}"
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
def test_latex_break_long_equation(self):
|
| 776 |
+
"""Test that we can break a long equation inside the table"""
|
| 777 |
+
long_equation = """
|
| 778 |
+
- cos(x1 * x0) + 3.2 * x0 - 1.2 * x1 + x1 * x1 * x1 + x0 * x0 * x0
|
| 779 |
+
+ 5.2 * sin(0.3256 * sin(x2) - 2.6 * x0) + x0 * x0 * x0 * x0 * x0
|
| 780 |
+
+ cos(cos(x1 * x0) + 3.2 * x0 - 1.2 * x1 + x1 * x1 * x1 + x0 * x0 * x0)
|
| 781 |
+
"""
|
| 782 |
+
long_equation = "".join(long_equation.split("\n")).strip()
|
| 783 |
+
equations = pd.DataFrame(
|
| 784 |
+
dict(
|
| 785 |
+
equation=["x0", "cos(x0)", long_equation],
|
| 786 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
| 787 |
+
complexity=[1, 2, 30],
|
| 788 |
+
)
|
| 789 |
+
)
|
| 790 |
+
model = manually_create_model(equations)
|
| 791 |
+
latex_table_str = model.latex_table()
|
| 792 |
+
middle_part = r"""
|
| 793 |
+
$y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
| 794 |
+
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
| 795 |
+
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
|
| 796 |
+
"""
|
| 797 |
+
true_latex_table_str = (
|
| 798 |
+
TRUE_PREAMBLE
|
| 799 |
+
+ "\n"
|
| 800 |
+
+ self.create_true_latex(middle_part, include_score=True)
|
| 801 |
+
)
|
| 802 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|