Spaces:
Running
Running
Commit
·
d423f0c
1
Parent(s):
215a692
Refactor table env generator
Browse files- pysr/export_latex.py +7 -7
- pysr/sr.py +4 -9
pysr/export_latex.py
CHANGED
|
@@ -23,7 +23,7 @@ def to_latex(expr, prec=3, full_prec=True, **settings):
|
|
| 23 |
return printer.doprint(expr)
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
margins = "".join([("l" if col == "equation" else "c") for col in columns])
|
| 28 |
column_map = {
|
| 29 |
"complexity": "Complexity",
|
|
@@ -32,7 +32,7 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
|
|
| 32 |
"score": "Score",
|
| 33 |
}
|
| 34 |
columns = [column_map[col] for col in columns]
|
| 35 |
-
|
| 36 |
r"\begin{table}[h]",
|
| 37 |
r"\begin{center}",
|
| 38 |
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
|
@@ -40,14 +40,14 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
|
|
| 40 |
" & ".join(columns) + r" \\",
|
| 41 |
r"\midrule",
|
| 42 |
]
|
| 43 |
-
return "\n".join(latex_table_pieces)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
def generate_bottom_of_latex_table():
|
| 47 |
-
latex_table_pieces = [
|
| 48 |
r"\bottomrule",
|
| 49 |
r"\end{tabular}",
|
| 50 |
r"\end{center}",
|
| 51 |
r"\end{table}",
|
| 52 |
]
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return printer.doprint(expr)
|
| 24 |
|
| 25 |
|
| 26 |
+
def generate_table_environment(columns=["equation", "complexity", "loss"]):
|
| 27 |
margins = "".join([("l" if col == "equation" else "c") for col in columns])
|
| 28 |
column_map = {
|
| 29 |
"complexity": "Complexity",
|
|
|
|
| 32 |
"score": "Score",
|
| 33 |
}
|
| 34 |
columns = [column_map[col] for col in columns]
|
| 35 |
+
top_pieces = [
|
| 36 |
r"\begin{table}[h]",
|
| 37 |
r"\begin{center}",
|
| 38 |
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
|
|
|
| 40 |
" & ".join(columns) + r" \\",
|
| 41 |
r"\midrule",
|
| 42 |
]
|
|
|
|
| 43 |
|
| 44 |
+
bottom_pieces = [
|
|
|
|
|
|
|
| 45 |
r"\bottomrule",
|
| 46 |
r"\end{tabular}",
|
| 47 |
r"\end{center}",
|
| 48 |
r"\end{table}",
|
| 49 |
]
|
| 50 |
+
top_latex_table = "\n".join(top_pieces)
|
| 51 |
+
bottom_latex_table = "\n".join(bottom_pieces)
|
| 52 |
+
|
| 53 |
+
return top_latex_table, bottom_latex_table
|
pysr/sr.py
CHANGED
|
@@ -27,11 +27,7 @@ from .julia_helpers import (
|
|
| 27 |
import_error_string,
|
| 28 |
)
|
| 29 |
from .export_numpy import CallableEquation
|
| 30 |
-
from .export_latex import
|
| 31 |
-
to_latex,
|
| 32 |
-
generate_top_of_latex_table,
|
| 33 |
-
generate_bottom_of_latex_table,
|
| 34 |
-
)
|
| 35 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 36 |
|
| 37 |
|
|
@@ -2037,8 +2033,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2037 |
else:
|
| 2038 |
indices = list(range(len(self.equations_)))
|
| 2039 |
|
| 2040 |
-
|
| 2041 |
-
latex_table_bottom = generate_bottom_of_latex_table()
|
| 2042 |
|
| 2043 |
equations = self.equations_
|
| 2044 |
|
|
@@ -2092,9 +2087,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2092 |
all_latex_table_str.append(
|
| 2093 |
"\n".join(
|
| 2094 |
[
|
| 2095 |
-
|
| 2096 |
*latex_table_content,
|
| 2097 |
-
|
| 2098 |
]
|
| 2099 |
)
|
| 2100 |
)
|
|
|
|
| 27 |
import_error_string,
|
| 28 |
)
|
| 29 |
from .export_numpy import CallableEquation
|
| 30 |
+
from .export_latex import to_latex, generate_table_environment
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
| 32 |
|
| 33 |
|
|
|
|
| 2033 |
else:
|
| 2034 |
indices = list(range(len(self.equations_)))
|
| 2035 |
|
| 2036 |
+
latex_top, latex_bottom = generate_table_environment(columns)
|
|
|
|
| 2037 |
|
| 2038 |
equations = self.equations_
|
| 2039 |
|
|
|
|
| 2087 |
all_latex_table_str.append(
|
| 2088 |
"\n".join(
|
| 2089 |
[
|
| 2090 |
+
latex_top,
|
| 2091 |
*latex_table_content,
|
| 2092 |
+
latex_bottom,
|
| 2093 |
]
|
| 2094 |
)
|
| 2095 |
)
|