Spaces:
Running
Running
Merge remote-tracking branch 'origin/master' into var-complexity
Browse files- .github/workflows/CI.yml +2 -2
- .github/workflows/CI_Windows.yml +1 -1
- .github/workflows/CI_mac.yml +1 -1
- .github/workflows/docker_deploy.yml +2 -2
- .pre-commit-config.yaml +1 -1
- examples/pysr_demo.ipynb +1 -1
- pysr/sr.py +21 -5
- pysr/test/test.py +118 -91
.github/workflows/CI.yml
CHANGED
|
@@ -52,7 +52,7 @@ jobs:
|
|
| 52 |
with:
|
| 53 |
version: ${{ matrix.julia-version }}
|
| 54 |
- name: "Cache Julia"
|
| 55 |
-
uses: julia-actions/cache@
|
| 56 |
with:
|
| 57 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 58 |
cache-packages: false
|
|
@@ -144,7 +144,7 @@ jobs:
|
|
| 144 |
activate-environment: pysr-test
|
| 145 |
environment-file: environment.yml
|
| 146 |
- name: "Cache Julia"
|
| 147 |
-
uses: julia-actions/cache@
|
| 148 |
with:
|
| 149 |
cache-name: ${{ matrix.os }}-conda-${{ matrix.python-version }}
|
| 150 |
cache-packages: false
|
|
|
|
| 52 |
with:
|
| 53 |
version: ${{ matrix.julia-version }}
|
| 54 |
- name: "Cache Julia"
|
| 55 |
+
uses: julia-actions/cache@v2
|
| 56 |
with:
|
| 57 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 58 |
cache-packages: false
|
|
|
|
| 144 |
activate-environment: pysr-test
|
| 145 |
environment-file: environment.yml
|
| 146 |
- name: "Cache Julia"
|
| 147 |
+
uses: julia-actions/cache@v2
|
| 148 |
with:
|
| 149 |
cache-name: ${{ matrix.os }}-conda-${{ matrix.python-version }}
|
| 150 |
cache-packages: false
|
.github/workflows/CI_Windows.yml
CHANGED
|
@@ -40,7 +40,7 @@ jobs:
|
|
| 40 |
with:
|
| 41 |
version: ${{ matrix.julia-version }}
|
| 42 |
- name: "Cache Julia"
|
| 43 |
-
uses: julia-actions/cache@
|
| 44 |
with:
|
| 45 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 46 |
cache-packages: false
|
|
|
|
| 40 |
with:
|
| 41 |
version: ${{ matrix.julia-version }}
|
| 42 |
- name: "Cache Julia"
|
| 43 |
+
uses: julia-actions/cache@v2
|
| 44 |
with:
|
| 45 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 46 |
cache-packages: false
|
.github/workflows/CI_mac.yml
CHANGED
|
@@ -40,7 +40,7 @@ jobs:
|
|
| 40 |
with:
|
| 41 |
version: ${{ matrix.julia-version }}
|
| 42 |
- name: "Cache Julia"
|
| 43 |
-
uses: julia-actions/cache@
|
| 44 |
with:
|
| 45 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 46 |
cache-packages: false
|
|
|
|
| 40 |
with:
|
| 41 |
version: ${{ matrix.julia-version }}
|
| 42 |
- name: "Cache Julia"
|
| 43 |
+
uses: julia-actions/cache@v2
|
| 44 |
with:
|
| 45 |
cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
|
| 46 |
cache-packages: false
|
.github/workflows/docker_deploy.yml
CHANGED
|
@@ -24,13 +24,13 @@ jobs:
|
|
| 24 |
- name: Checkout
|
| 25 |
uses: actions/checkout@v4
|
| 26 |
- name: Login to Docker Hub
|
| 27 |
-
uses: docker/login-action@
|
| 28 |
if: github.event_name != 'pull_request'
|
| 29 |
with:
|
| 30 |
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
| 31 |
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
| 32 |
- name: Login to GitHub registry
|
| 33 |
-
uses: docker/login-action@
|
| 34 |
if: github.event_name != 'pull_request'
|
| 35 |
with:
|
| 36 |
registry: ghcr.io
|
|
|
|
| 24 |
- name: Checkout
|
| 25 |
uses: actions/checkout@v4
|
| 26 |
- name: Login to Docker Hub
|
| 27 |
+
uses: docker/login-action@v3
|
| 28 |
if: github.event_name != 'pull_request'
|
| 29 |
with:
|
| 30 |
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
| 31 |
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
| 32 |
- name: Login to GitHub registry
|
| 33 |
+
uses: docker/login-action@v3
|
| 34 |
if: github.event_name != 'pull_request'
|
| 35 |
with:
|
| 36 |
registry: ghcr.io
|
.pre-commit-config.yaml
CHANGED
|
@@ -9,7 +9,7 @@ repos:
|
|
| 9 |
- id: check-added-large-files
|
| 10 |
# General formatting
|
| 11 |
- repo: https://github.com/psf/black
|
| 12 |
-
rev: 24.4.
|
| 13 |
hooks:
|
| 14 |
- id: black
|
| 15 |
- id: black-jupyter
|
|
|
|
| 9 |
- id: check-added-large-files
|
| 10 |
# General formatting
|
| 11 |
- repo: https://github.com/psf/black
|
| 12 |
+
rev: 24.4.2
|
| 13 |
hooks:
|
| 14 |
- id: black
|
| 15 |
- id: black-jupyter
|
examples/pysr_demo.ipynb
CHANGED
|
@@ -396,7 +396,7 @@
|
|
| 396 |
"id": "wbWHyOjl2_kX"
|
| 397 |
},
|
| 398 |
"source": [
|
| 399 |
-
"Since `quart` is arguably more complex than the other operators, you can also give it a different complexity, using, e.g., `complexity_of_operators={\"quart\": 2}` to give it a complexity of 2 (instead of the default
|
| 400 |
"\n",
|
| 401 |
"\n",
|
| 402 |
"One can also add a binary operator, with, e.g., `\"myoperator(x, y) = x^2 * y\"`. All Julia operators that work on scalar 32-bit floating point values are available.\n",
|
|
|
|
| 396 |
"id": "wbWHyOjl2_kX"
|
| 397 |
},
|
| 398 |
"source": [
|
| 399 |
+
"Since `quart` is arguably more complex than the other operators, you can also give it a different complexity, using, e.g., `complexity_of_operators={\"quart\": 2}` to give it a complexity of 2 (instead of the default 1). You can also define custom complexities for variables and constants (`complexity_of_variables` and `complexity_of_constants`, respectively - both take a single number).\n",
|
| 400 |
"\n",
|
| 401 |
"\n",
|
| 402 |
"One can also add a binary operator, with, e.g., `\"myoperator(x, y) = x^2 * y\"`. All Julia operators that work on scalar 32-bit floating point values are available.\n",
|
pysr/sr.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
"""Define the PySRRegressor scikit-learn interface."""
|
| 2 |
|
| 3 |
import copy
|
|
|
|
|
|
|
| 4 |
import os
|
| 5 |
import pickle as pkl
|
| 6 |
import re
|
|
@@ -912,15 +914,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 912 |
updated_kwarg_name = DEPRECATED_KWARGS[k]
|
| 913 |
setattr(self, updated_kwarg_name, v)
|
| 914 |
warnings.warn(
|
| 915 |
-
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
| 916 |
"Please use that instead.",
|
| 917 |
FutureWarning,
|
| 918 |
)
|
| 919 |
# Handle kwargs that have been moved to the fit method
|
| 920 |
elif k in ["weights", "variable_names", "Xresampled"]:
|
| 921 |
warnings.warn(
|
| 922 |
-
f"{k} is a data
|
| 923 |
-
f"Ignoring parameter; please pass {k} during the call to fit instead.",
|
| 924 |
FutureWarning,
|
| 925 |
)
|
| 926 |
elif k == "julia_project":
|
|
@@ -937,9 +939,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 937 |
FutureWarning,
|
| 938 |
)
|
| 939 |
else:
|
| 940 |
-
|
| 941 |
-
|
|
|
|
| 942 |
)
|
|
|
|
|
|
|
|
|
|
| 943 |
|
| 944 |
@classmethod
|
| 945 |
def from_file(
|
|
@@ -2545,6 +2551,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2545 |
return with_preamble(table_string)
|
| 2546 |
|
| 2547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2548 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
| 2549 |
"""Select an expression and return its index."""
|
| 2550 |
if model_selection == "accuracy":
|
|
|
|
| 1 |
"""Define the PySRRegressor scikit-learn interface."""
|
| 2 |
|
| 3 |
import copy
|
| 4 |
+
import difflib
|
| 5 |
+
import inspect
|
| 6 |
import os
|
| 7 |
import pickle as pkl
|
| 8 |
import re
|
|
|
|
| 914 |
updated_kwarg_name = DEPRECATED_KWARGS[k]
|
| 915 |
setattr(self, updated_kwarg_name, v)
|
| 916 |
warnings.warn(
|
| 917 |
+
f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
|
| 918 |
"Please use that instead.",
|
| 919 |
FutureWarning,
|
| 920 |
)
|
| 921 |
# Handle kwargs that have been moved to the fit method
|
| 922 |
elif k in ["weights", "variable_names", "Xresampled"]:
|
| 923 |
warnings.warn(
|
| 924 |
+
f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
|
| 925 |
+
f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
|
| 926 |
FutureWarning,
|
| 927 |
)
|
| 928 |
elif k == "julia_project":
|
|
|
|
| 939 |
FutureWarning,
|
| 940 |
)
|
| 941 |
else:
|
| 942 |
+
suggested_keywords = _suggest_keywords(PySRRegressor, k)
|
| 943 |
+
err_msg = (
|
| 944 |
+
f"`{k}` is not a valid keyword argument for PySRRegressor."
|
| 945 |
)
|
| 946 |
+
if len(suggested_keywords) > 0:
|
| 947 |
+
err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
|
| 948 |
+
raise TypeError(err_msg)
|
| 949 |
|
| 950 |
@classmethod
|
| 951 |
def from_file(
|
|
|
|
| 2551 |
return with_preamble(table_string)
|
| 2552 |
|
| 2553 |
|
| 2554 |
+
def _suggest_keywords(cls, k: str) -> List[str]:
|
| 2555 |
+
valid_keywords = [
|
| 2556 |
+
param
|
| 2557 |
+
for param in inspect.signature(cls.__init__).parameters
|
| 2558 |
+
if param not in ["self", "kwargs"]
|
| 2559 |
+
]
|
| 2560 |
+
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
|
| 2561 |
+
return suggestions
|
| 2562 |
+
|
| 2563 |
+
|
| 2564 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
| 2565 |
"""Select an expression and return its index."""
|
| 2566 |
if model_selection == "accuracy":
|
pysr/test/test.py
CHANGED
|
@@ -15,9 +15,8 @@ from pysr import PySRRegressor, install, jl
|
|
| 15 |
from pysr.export_latex import sympy2latex
|
| 16 |
from pysr.feature_selection import _handle_feature_selection, run_feature_selection
|
| 17 |
from pysr.julia_helpers import init_julia
|
| 18 |
-
from pysr.sr import _check_assertions, _process_constraints, idx_model_selection
|
| 19 |
from pysr.utils import _csv_filename_to_pkl_filename
|
| 20 |
-
|
| 21 |
from .params import (
|
| 22 |
DEFAULT_NCYCLES,
|
| 23 |
DEFAULT_NITERATIONS,
|
|
@@ -596,6 +595,105 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 596 |
test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
|
| 597 |
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
| 598 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
def test_deprecation(self):
|
| 600 |
"""Ensure that deprecation works as expected.
|
| 601 |
|
|
@@ -738,100 +836,28 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 738 |
model.get_best()
|
| 739 |
print("Failed", opt["kwargs"])
|
| 740 |
|
| 741 |
-
def
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
temp_equation_file=True,
|
| 746 |
-
procs=0,
|
| 747 |
-
multithreading=False,
|
| 748 |
)
|
| 749 |
-
nout = 3
|
| 750 |
-
X = np.random.randn(100, 2)
|
| 751 |
-
y = np.random.randn(100, nout)
|
| 752 |
-
model.fit(X, y)
|
| 753 |
-
contents = model.equation_file_contents_.copy()
|
| 754 |
-
|
| 755 |
-
y_predictions = model.predict(X)
|
| 756 |
-
|
| 757 |
-
equation_file_base = model.equation_file_
|
| 758 |
-
for i in range(1, nout + 1):
|
| 759 |
-
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
| 760 |
-
|
| 761 |
-
with tempfile.NamedTemporaryFile() as pickle_file:
|
| 762 |
-
pkl.dump(model, pickle_file)
|
| 763 |
-
pickle_file.seek(0)
|
| 764 |
-
model2 = pkl.load(pickle_file)
|
| 765 |
-
|
| 766 |
-
contents2 = model2.equation_file_contents_
|
| 767 |
-
cols_to_check = ["equation", "loss", "complexity"]
|
| 768 |
-
for frame1, frame2 in zip(contents, contents2):
|
| 769 |
-
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
| 770 |
|
| 771 |
-
|
| 772 |
-
|
|
|
|
| 773 |
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
verbosity=0,
|
| 781 |
-
progress=False,
|
| 782 |
-
random_state=0,
|
| 783 |
-
deterministic=True, # Deterministic as tests require this.
|
| 784 |
-
procs=0,
|
| 785 |
-
multithreading=False,
|
| 786 |
-
warm_start=False,
|
| 787 |
-
temp_equation_file=True,
|
| 788 |
-
) # Return early.
|
| 789 |
-
|
| 790 |
-
check_generator = check_estimator(model, generate_only=True)
|
| 791 |
-
exception_messages = []
|
| 792 |
-
for _, check in check_generator:
|
| 793 |
-
if check.func.__name__ == "check_complex_data":
|
| 794 |
-
# We can use complex data, so avoid this check.
|
| 795 |
-
continue
|
| 796 |
-
try:
|
| 797 |
-
with warnings.catch_warnings():
|
| 798 |
-
warnings.simplefilter("ignore")
|
| 799 |
-
check(model)
|
| 800 |
-
print("Passed", check.func.__name__)
|
| 801 |
-
except Exception:
|
| 802 |
-
error_message = str(traceback.format_exc())
|
| 803 |
-
exception_messages.append(
|
| 804 |
-
f"{check.func.__name__}:\n" + error_message + "\n"
|
| 805 |
-
)
|
| 806 |
-
print("Failed", check.func.__name__, "with:")
|
| 807 |
-
# Add a leading tab to error message, which
|
| 808 |
-
# might be multi-line:
|
| 809 |
-
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 810 |
-
# If any checks failed don't let the test pass.
|
| 811 |
-
self.assertEqual(len(exception_messages), 0)
|
| 812 |
-
|
| 813 |
-
def test_param_groupings(self):
|
| 814 |
-
"""Test that param_groupings are complete"""
|
| 815 |
-
param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
|
| 816 |
-
if not param_groupings_file.exists():
|
| 817 |
-
return
|
| 818 |
-
|
| 819 |
-
# Read the file, discarding lines ending in ":",
|
| 820 |
-
# and removing leading "\s*-\s*":
|
| 821 |
-
params = []
|
| 822 |
-
with open(param_groupings_file, "r") as f:
|
| 823 |
-
for line in f.readlines():
|
| 824 |
-
if line.strip().endswith(":"):
|
| 825 |
-
continue
|
| 826 |
-
if line.strip().startswith("-"):
|
| 827 |
-
params.append(line.strip()[1:].strip())
|
| 828 |
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
|
| 833 |
-
|
| 834 |
-
self.assertSetEqual(set(params), set(regressor_params))
|
| 835 |
|
| 836 |
|
| 837 |
TRUE_PREAMBLE = "\n".join(
|
|
@@ -1187,6 +1213,7 @@ def runtests(just_tests=False):
|
|
| 1187 |
TestBest,
|
| 1188 |
TestFeatureSelection,
|
| 1189 |
TestMiscellaneous,
|
|
|
|
| 1190 |
TestLaTeXTable,
|
| 1191 |
TestDimensionalConstraints,
|
| 1192 |
]
|
|
|
|
| 15 |
from pysr.export_latex import sympy2latex
|
| 16 |
from pysr.feature_selection import _handle_feature_selection, run_feature_selection
|
| 17 |
from pysr.julia_helpers import init_julia
|
| 18 |
+
from pysr.sr import _check_assertions, _process_constraints, _suggest_keywords, idx_model_selection
|
| 19 |
from pysr.utils import _csv_filename_to_pkl_filename
|
|
|
|
| 20 |
from .params import (
|
| 21 |
DEFAULT_NCYCLES,
|
| 22 |
DEFAULT_NITERATIONS,
|
|
|
|
| 595 |
test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
|
| 596 |
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
| 597 |
|
| 598 |
+
def test_pickle_with_temp_equation_file(self):
|
| 599 |
+
"""If we have a temporary equation file, unpickle the estimator."""
|
| 600 |
+
model = PySRRegressor(
|
| 601 |
+
populations=int(1 + DEFAULT_POPULATIONS / 5),
|
| 602 |
+
temp_equation_file=True,
|
| 603 |
+
procs=0,
|
| 604 |
+
multithreading=False,
|
| 605 |
+
)
|
| 606 |
+
nout = 3
|
| 607 |
+
X = np.random.randn(100, 2)
|
| 608 |
+
y = np.random.randn(100, nout)
|
| 609 |
+
model.fit(X, y)
|
| 610 |
+
contents = model.equation_file_contents_.copy()
|
| 611 |
+
|
| 612 |
+
y_predictions = model.predict(X)
|
| 613 |
+
|
| 614 |
+
equation_file_base = model.equation_file_
|
| 615 |
+
for i in range(1, nout + 1):
|
| 616 |
+
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
| 617 |
+
|
| 618 |
+
with tempfile.NamedTemporaryFile() as pickle_file:
|
| 619 |
+
pkl.dump(model, pickle_file)
|
| 620 |
+
pickle_file.seek(0)
|
| 621 |
+
model2 = pkl.load(pickle_file)
|
| 622 |
+
|
| 623 |
+
contents2 = model2.equation_file_contents_
|
| 624 |
+
cols_to_check = ["equation", "loss", "complexity"]
|
| 625 |
+
for frame1, frame2 in zip(contents, contents2):
|
| 626 |
+
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
| 627 |
+
|
| 628 |
+
y_predictions2 = model2.predict(X)
|
| 629 |
+
np.testing.assert_array_equal(y_predictions, y_predictions2)
|
| 630 |
+
|
| 631 |
+
def test_scikit_learn_compatibility(self):
|
| 632 |
+
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 633 |
+
model = PySRRegressor(
|
| 634 |
+
niterations=int(1 + DEFAULT_NITERATIONS / 10),
|
| 635 |
+
populations=int(1 + DEFAULT_POPULATIONS / 3),
|
| 636 |
+
ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
|
| 637 |
+
verbosity=0,
|
| 638 |
+
progress=False,
|
| 639 |
+
random_state=0,
|
| 640 |
+
deterministic=True, # Deterministic as tests require this.
|
| 641 |
+
procs=0,
|
| 642 |
+
multithreading=False,
|
| 643 |
+
warm_start=False,
|
| 644 |
+
temp_equation_file=True,
|
| 645 |
+
) # Return early.
|
| 646 |
+
|
| 647 |
+
check_generator = check_estimator(model, generate_only=True)
|
| 648 |
+
exception_messages = []
|
| 649 |
+
for _, check in check_generator:
|
| 650 |
+
if check.func.__name__ == "check_complex_data":
|
| 651 |
+
# We can use complex data, so avoid this check.
|
| 652 |
+
continue
|
| 653 |
+
try:
|
| 654 |
+
with warnings.catch_warnings():
|
| 655 |
+
warnings.simplefilter("ignore")
|
| 656 |
+
check(model)
|
| 657 |
+
print("Passed", check.func.__name__)
|
| 658 |
+
except Exception:
|
| 659 |
+
error_message = str(traceback.format_exc())
|
| 660 |
+
exception_messages.append(
|
| 661 |
+
f"{check.func.__name__}:\n" + error_message + "\n"
|
| 662 |
+
)
|
| 663 |
+
print("Failed", check.func.__name__, "with:")
|
| 664 |
+
# Add a leading tab to error message, which
|
| 665 |
+
# might be multi-line:
|
| 666 |
+
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 667 |
+
# If any checks failed don't let the test pass.
|
| 668 |
+
self.assertEqual(len(exception_messages), 0)
|
| 669 |
+
|
| 670 |
+
def test_param_groupings(self):
|
| 671 |
+
"""Test that param_groupings are complete"""
|
| 672 |
+
param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
|
| 673 |
+
if not param_groupings_file.exists():
|
| 674 |
+
return
|
| 675 |
+
|
| 676 |
+
# Read the file, discarding lines ending in ":",
|
| 677 |
+
# and removing leading "\s*-\s*":
|
| 678 |
+
params = []
|
| 679 |
+
with open(param_groupings_file, "r") as f:
|
| 680 |
+
for line in f.readlines():
|
| 681 |
+
if line.strip().endswith(":"):
|
| 682 |
+
continue
|
| 683 |
+
if line.strip().startswith("-"):
|
| 684 |
+
params.append(line.strip()[1:].strip())
|
| 685 |
+
|
| 686 |
+
regressor_params = [
|
| 687 |
+
p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
|
| 688 |
+
]
|
| 689 |
+
|
| 690 |
+
# Check the sets are equal:
|
| 691 |
+
self.assertSetEqual(set(params), set(regressor_params))
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
class TestHelpMessages(unittest.TestCase):
|
| 695 |
+
"""Test user help messages."""
|
| 696 |
+
|
| 697 |
def test_deprecation(self):
|
| 698 |
"""Ensure that deprecation works as expected.
|
| 699 |
|
|
|
|
| 836 |
model.get_best()
|
| 837 |
print("Failed", opt["kwargs"])
|
| 838 |
|
| 839 |
+
def test_suggest_keywords(self):
|
| 840 |
+
# Easy
|
| 841 |
+
self.assertEqual(
|
| 842 |
+
_suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
|
|
|
|
|
|
|
|
|
|
| 843 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
|
| 845 |
+
# More complex, and with error
|
| 846 |
+
with self.assertRaises(TypeError) as cm:
|
| 847 |
+
model = PySRRegressor(ncyclesperiterationn=5)
|
| 848 |
|
| 849 |
+
self.assertIn(
|
| 850 |
+
"`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
|
| 851 |
+
)
|
| 852 |
+
self.assertIn("Did you mean", str(cm.exception))
|
| 853 |
+
self.assertIn("`ncycles_per_iteration`, ", str(cm.exception))
|
| 854 |
+
self.assertIn("`niterations`", str(cm.exception))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
|
| 856 |
+
# Farther matches (this might need to be changed)
|
| 857 |
+
with self.assertRaises(TypeError) as cm:
|
| 858 |
+
model = PySRRegressor(operators=["+", "-"])
|
| 859 |
|
| 860 |
+
self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
|
|
|
|
| 861 |
|
| 862 |
|
| 863 |
TRUE_PREAMBLE = "\n".join(
|
|
|
|
| 1213 |
TestBest,
|
| 1214 |
TestFeatureSelection,
|
| 1215 |
TestMiscellaneous,
|
| 1216 |
+
TestHelpMessages,
|
| 1217 |
TestLaTeXTable,
|
| 1218 |
TestDimensionalConstraints,
|
| 1219 |
]
|