Spaces:
Running
Running
Merge pull request #670 from MilesCranmer/issue666
Browse files- pysr/__init__.py +0 -1
- pysr/export_jax.py +1 -1
- pysr/export_latex.py +2 -2
- pysr/export_numpy.py +1 -1
- pysr/export_sympy.py +1 -1
- pysr/export_torch.py +1 -1
- pysr/test/test.py +12 -3
- pysr/test/test_jax.py +2 -2
- pysr/test/test_torch.py +1 -1
pysr/__init__.py
CHANGED
|
@@ -18,7 +18,6 @@ __all__ = [
|
|
| 18 |
"sklearn_monkeypatch",
|
| 19 |
"sympy2jax",
|
| 20 |
"sympy2torch",
|
| 21 |
-
"Problem",
|
| 22 |
"install",
|
| 23 |
"PySRRegressor",
|
| 24 |
"best",
|
|
|
|
| 18 |
"sklearn_monkeypatch",
|
| 19 |
"sympy2jax",
|
| 20 |
"sympy2torch",
|
|
|
|
| 21 |
"install",
|
| 22 |
"PySRRegressor",
|
| 23 |
"best",
|
pysr/export_jax.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import numpy as np # noqa: F401
|
| 2 |
-
import sympy
|
| 3 |
|
| 4 |
# Special since need to reduce arguments.
|
| 5 |
MUL = 0
|
|
|
|
| 1 |
import numpy as np # noqa: F401
|
| 2 |
+
import sympy # type: ignore
|
| 3 |
|
| 4 |
# Special since need to reduce arguments.
|
| 5 |
MUL = 0
|
pysr/export_latex.py
CHANGED
|
@@ -3,8 +3,8 @@
|
|
| 3 |
from typing import List, Optional, Tuple
|
| 4 |
|
| 5 |
import pandas as pd
|
| 6 |
-
import sympy
|
| 7 |
-
from sympy.printing.latex import LatexPrinter
|
| 8 |
|
| 9 |
|
| 10 |
class PreciseLatexPrinter(LatexPrinter):
|
|
|
|
| 3 |
from typing import List, Optional, Tuple
|
| 4 |
|
| 5 |
import pandas as pd
|
| 6 |
+
import sympy # type: ignore
|
| 7 |
+
from sympy.printing.latex import LatexPrinter # type: ignore
|
| 8 |
|
| 9 |
|
| 10 |
class PreciseLatexPrinter(LatexPrinter):
|
pysr/export_numpy.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import List, Union
|
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
from numpy.typing import NDArray
|
| 9 |
-
from sympy import Expr, Symbol, lambdify
|
| 10 |
|
| 11 |
|
| 12 |
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
from numpy.typing import NDArray
|
| 9 |
+
from sympy import Expr, Symbol, lambdify # type: ignore
|
| 10 |
|
| 11 |
|
| 12 |
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
|
pysr/export_sympy.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
from typing import Callable, Dict, List, Optional
|
| 4 |
|
| 5 |
-
import sympy
|
| 6 |
from sympy import sympify
|
| 7 |
|
| 8 |
from .utils import ArrayLike
|
|
|
|
| 2 |
|
| 3 |
from typing import Callable, Dict, List, Optional
|
| 4 |
|
| 5 |
+
import sympy # type: ignore
|
| 6 |
from sympy import sympify
|
| 7 |
|
| 8 |
from .utils import ArrayLike
|
pysr/export_torch.py
CHANGED
|
@@ -4,7 +4,7 @@ import collections as co
|
|
| 4 |
import functools as ft
|
| 5 |
|
| 6 |
import numpy as np # noqa: F401
|
| 7 |
-
import sympy
|
| 8 |
|
| 9 |
|
| 10 |
def _reduce(fn):
|
|
|
|
| 4 |
import functools as ft
|
| 5 |
|
| 6 |
import numpy as np # noqa: F401
|
| 7 |
+
import sympy # type: ignore
|
| 8 |
|
| 9 |
|
| 10 |
def _reduce(fn):
|
pysr/test/test.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle as pkl
|
| 3 |
import tempfile
|
|
@@ -8,7 +9,7 @@ from pathlib import Path
|
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
-
import sympy
|
| 12 |
from sklearn.utils.estimator_checks import check_estimator
|
| 13 |
|
| 14 |
from pysr import PySRRegressor, install, jl
|
|
@@ -892,7 +893,7 @@ class TestHelpMessages(unittest.TestCase):
|
|
| 892 |
|
| 893 |
# More complex, and with error
|
| 894 |
with self.assertRaises(TypeError) as cm:
|
| 895 |
-
|
| 896 |
|
| 897 |
self.assertIn(
|
| 898 |
"`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
|
|
@@ -903,10 +904,18 @@ class TestHelpMessages(unittest.TestCase):
|
|
| 903 |
|
| 904 |
# Farther matches (this might need to be changed)
|
| 905 |
with self.assertRaises(TypeError) as cm:
|
| 906 |
-
|
| 907 |
|
| 908 |
self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
|
| 909 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 910 |
|
| 911 |
TRUE_PREAMBLE = "\n".join(
|
| 912 |
[
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
import os
|
| 3 |
import pickle as pkl
|
| 4 |
import tempfile
|
|
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
+
import sympy # type: ignore
|
| 13 |
from sklearn.utils.estimator_checks import check_estimator
|
| 14 |
|
| 15 |
from pysr import PySRRegressor, install, jl
|
|
|
|
| 893 |
|
| 894 |
# More complex, and with error
|
| 895 |
with self.assertRaises(TypeError) as cm:
|
| 896 |
+
PySRRegressor(ncyclesperiterationn=5)
|
| 897 |
|
| 898 |
self.assertIn(
|
| 899 |
"`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
|
|
|
|
| 904 |
|
| 905 |
# Farther matches (this might need to be changed)
|
| 906 |
with self.assertRaises(TypeError) as cm:
|
| 907 |
+
PySRRegressor(operators=["+", "-"])
|
| 908 |
|
| 909 |
self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
|
| 910 |
|
| 911 |
+
def test_issue_666(self):
|
| 912 |
+
# Try the equivalent of `from pysr import *`
|
| 913 |
+
pysr_module = importlib.import_module("pysr")
|
| 914 |
+
names_to_import = pysr_module.__all__
|
| 915 |
+
|
| 916 |
+
for name in names_to_import:
|
| 917 |
+
getattr(pysr_module, name)
|
| 918 |
+
|
| 919 |
|
| 920 |
TRUE_PREAMBLE = "\n".join(
|
| 921 |
[
|
pysr/test/test_jax.py
CHANGED
|
@@ -3,7 +3,7 @@ from functools import partial
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
| 6 |
-
import sympy
|
| 7 |
|
| 8 |
import pysr
|
| 9 |
from pysr import PySRRegressor, sympy2jax
|
|
@@ -102,7 +102,7 @@ class TestJAX(unittest.TestCase):
|
|
| 102 |
)
|
| 103 |
|
| 104 |
def test_issue_656(self):
|
| 105 |
-
import sympy
|
| 106 |
|
| 107 |
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
| 108 |
f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
|
|
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
| 6 |
+
import sympy # type: ignore
|
| 7 |
|
| 8 |
import pysr
|
| 9 |
from pysr import PySRRegressor, sympy2jax
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
def test_issue_656(self):
|
| 105 |
+
import sympy # type: ignore
|
| 106 |
|
| 107 |
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
| 108 |
f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
|
pysr/test/test_torch.py
CHANGED
|
@@ -2,7 +2,7 @@ import unittest
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
-
import sympy
|
| 6 |
|
| 7 |
import pysr
|
| 8 |
from pysr import PySRRegressor, sympy2torch
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
+
import sympy # type: ignore
|
| 6 |
|
| 7 |
import pysr
|
| 8 |
from pysr import PySRRegressor, sympy2torch
|