Spaces:
Sleeping
Sleeping
Commit
·
9a5df63
1
Parent(s):
0cb1353
Use custom sympy LatexPrinter for precision
Browse files- pysr/export_latex.py +20 -8
- pysr/sr.py +4 -8
- test/test.py +23 -0
pysr/export_latex.py
CHANGED
|
@@ -1,14 +1,26 @@
|
|
| 1 |
"""Functions to help export PySR equations to LaTeX."""
|
| 2 |
-
import
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
|
|
|
|
| 1 |
"""Functions to help export PySR equations to LaTeX."""
|
| 2 |
+
import sympy
|
| 3 |
+
from sympy.printing.latex import LatexPrinter
|
| 4 |
|
| 5 |
|
| 6 |
+
class PreciseLatexPrinter(LatexPrinter):
|
| 7 |
+
"""Modified SymPy printer with custom float precision."""
|
| 8 |
+
def __init__(self, settings=None, prec=3):
|
| 9 |
+
super().__init__(settings)
|
| 10 |
+
self.prec = prec
|
| 11 |
+
|
| 12 |
+
def _print_Float(self, expr):
|
| 13 |
+
# Reduce precision of float:
|
| 14 |
+
reduced_float = sympy.Float(expr, self.prec)
|
| 15 |
+
return super()._print_Float(reduced_float)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def to_latex(expr, prec=3, **settings):
|
| 19 |
+
"""Convert sympy expression to LaTeX with custom precision."""
|
| 20 |
+
if len(settings) == 0:
|
| 21 |
+
settings = None
|
| 22 |
+
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
| 23 |
+
return printer.doprint(expr)
|
| 24 |
|
| 25 |
|
| 26 |
def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
|
pysr/sr.py
CHANGED
|
@@ -28,7 +28,7 @@ from .julia_helpers import (
|
|
| 28 |
)
|
| 29 |
from .export_numpy import CallableEquation
|
| 30 |
from .export_latex import (
|
| 31 |
-
|
| 32 |
generate_top_of_latex_table,
|
| 33 |
generate_bottom_of_latex_table,
|
| 34 |
)
|
|
@@ -1752,14 +1752,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1752 |
if self.nout_ > 1:
|
| 1753 |
output = []
|
| 1754 |
for s in sympy_representation:
|
| 1755 |
-
|
| 1756 |
-
|
| 1757 |
-
raw_latex, precision
|
| 1758 |
-
)
|
| 1759 |
-
output.append(reduced_latex)
|
| 1760 |
return output
|
| 1761 |
-
|
| 1762 |
-
return set_precision_of_constants_in_string(raw_latex, precision)
|
| 1763 |
|
| 1764 |
def jax(self, index=None):
|
| 1765 |
"""
|
|
|
|
| 28 |
)
|
| 29 |
from .export_numpy import CallableEquation
|
| 30 |
from .export_latex import (
|
| 31 |
+
to_latex,
|
| 32 |
generate_top_of_latex_table,
|
| 33 |
generate_bottom_of_latex_table,
|
| 34 |
)
|
|
|
|
| 1752 |
if self.nout_ > 1:
|
| 1753 |
output = []
|
| 1754 |
for s in sympy_representation:
|
| 1755 |
+
latex = to_latex(s, prec=precision)
|
| 1756 |
+
output.append(latex)
|
|
|
|
|
|
|
|
|
|
| 1757 |
return output
|
| 1758 |
+
return to_latex(sympy_representation, prec=precision)
|
|
|
|
| 1759 |
|
| 1760 |
def jax(self, index=None):
|
| 1761 |
"""
|
test/test.py
CHANGED
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
| 6 |
from sklearn import model_selection
|
| 7 |
from pysr import PySRRegressor
|
| 8 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
|
|
|
| 9 |
from sklearn.utils.estimator_checks import check_estimator
|
| 10 |
import sympy
|
| 11 |
import pandas as pd
|
|
@@ -573,3 +574,25 @@ class TestLaTeXTable(unittest.TestCase):
|
|
| 573 |
"""
|
| 574 |
true_latex_table_str = self.create_true_latex(middle_part)
|
| 575 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from sklearn import model_selection
|
| 7 |
from pysr import PySRRegressor
|
| 8 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
| 9 |
+
from pysr.export_latex import to_latex
|
| 10 |
from sklearn.utils.estimator_checks import check_estimator
|
| 11 |
import sympy
|
| 12 |
import pandas as pd
|
|
|
|
| 574 |
"""
|
| 575 |
true_latex_table_str = self.create_true_latex(middle_part)
|
| 576 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 577 |
+
|
| 578 |
+
def test_latex_float_precision(self):
|
| 579 |
+
"""Test that we can print latex expressions with custom precision"""
|
| 580 |
+
expr = sympy.Float(4583.4485748, dps=50)
|
| 581 |
+
self.assertEqual(to_latex(expr, prec=6), r"4583.45")
|
| 582 |
+
self.assertEqual(to_latex(expr, prec=5), r"4583.4")
|
| 583 |
+
self.assertEqual(to_latex(expr, prec=4), r"4583.0")
|
| 584 |
+
self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
|
| 585 |
+
self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
|
| 586 |
+
|
| 587 |
+
# Multiple numbers:
|
| 588 |
+
x = sympy.Symbol("x")
|
| 589 |
+
expr = x * 3232.324857384 - 1.4857485e-10
|
| 590 |
+
self.assertEqual(
|
| 591 |
+
to_latex(expr, prec=2), "3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
|
| 592 |
+
)
|
| 593 |
+
self.assertEqual(
|
| 594 |
+
to_latex(expr, prec=3), "3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
|
| 595 |
+
)
|
| 596 |
+
self.assertEqual(
|
| 597 |
+
to_latex(expr, prec=8), "3232.3249 x - 1.4857485 \cdot 10^{-10}"
|
| 598 |
+
)
|