Spaces:
Running
Running
Merge pull request #620 from MilesCranmer/autocorrect-kwarg
Browse files
- pysr/sr.py +21 -5
- pysr/test/test.py +123 -90
pysr/sr.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
"""Define the PySRRegressor scikit-learn interface."""
|
| 2 |
|
| 3 |
import copy
|
|
|
|
|
|
|
| 4 |
import os
|
| 5 |
import pickle as pkl
|
| 6 |
import re
|
|
@@ -900,15 +902,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 900 |
updated_kwarg_name = DEPRECATED_KWARGS[k]
|
| 901 |
setattr(self, updated_kwarg_name, v)
|
| 902 |
warnings.warn(
|
| 903 |
-
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
| 904 |
"Please use that instead.",
|
| 905 |
FutureWarning,
|
| 906 |
)
|
| 907 |
# Handle kwargs that have been moved to the fit method
|
| 908 |
elif k in ["weights", "variable_names", "Xresampled"]:
|
| 909 |
warnings.warn(
|
| 910 |
-
f"{k} is a data
|
| 911 |
-
f"Ignoring parameter; please pass {k} during the call to fit instead.",
|
| 912 |
FutureWarning,
|
| 913 |
)
|
| 914 |
elif k == "julia_project":
|
|
@@ -925,9 +927,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 925 |
FutureWarning,
|
| 926 |
)
|
| 927 |
else:
|
| 928 |
-
|
| 929 |
-
|
|
|
|
| 930 |
)
|
|
|
|
|
|
|
|
|
|
| 931 |
|
| 932 |
@classmethod
|
| 933 |
def from_file(
|
|
@@ -2459,6 +2465,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2459 |
return with_preamble(table_string)
|
| 2460 |
|
| 2461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2462 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
| 2463 |
"""Select an expression and return its index."""
|
| 2464 |
if model_selection == "accuracy":
|
|
|
|
| 1 |
"""Define the PySRRegressor scikit-learn interface."""
|
| 2 |
|
| 3 |
import copy
|
| 4 |
+
import difflib
|
| 5 |
+
import inspect
|
| 6 |
import os
|
| 7 |
import pickle as pkl
|
| 8 |
import re
|
|
|
|
| 902 |
updated_kwarg_name = DEPRECATED_KWARGS[k]
|
| 903 |
setattr(self, updated_kwarg_name, v)
|
| 904 |
warnings.warn(
|
| 905 |
+
f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
|
| 906 |
"Please use that instead.",
|
| 907 |
FutureWarning,
|
| 908 |
)
|
| 909 |
# Handle kwargs that have been moved to the fit method
|
| 910 |
elif k in ["weights", "variable_names", "Xresampled"]:
|
| 911 |
warnings.warn(
|
| 912 |
+
f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
|
| 913 |
+
f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
|
| 914 |
FutureWarning,
|
| 915 |
)
|
| 916 |
elif k == "julia_project":
|
|
|
|
| 927 |
FutureWarning,
|
| 928 |
)
|
| 929 |
else:
|
| 930 |
+
suggested_keywords = _suggest_keywords(PySRRegressor, k)
|
| 931 |
+
err_msg = (
|
| 932 |
+
f"`{k}` is not a valid keyword argument for PySRRegressor."
|
| 933 |
)
|
| 934 |
+
if len(suggested_keywords) > 0:
|
| 935 |
+
err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
|
| 936 |
+
raise TypeError(err_msg)
|
| 937 |
|
| 938 |
@classmethod
|
| 939 |
def from_file(
|
|
|
|
| 2465 |
return with_preamble(table_string)
|
| 2466 |
|
| 2467 |
|
| 2468 |
+
def _suggest_keywords(cls, k: str) -> List[str]:
|
| 2469 |
+
valid_keywords = [
|
| 2470 |
+
param
|
| 2471 |
+
for param in inspect.signature(cls.__init__).parameters
|
| 2472 |
+
if param not in ["self", "kwargs"]
|
| 2473 |
+
]
|
| 2474 |
+
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
|
| 2475 |
+
return suggestions
|
| 2476 |
+
|
| 2477 |
+
|
| 2478 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
| 2479 |
"""Select an expression and return its index."""
|
| 2480 |
if model_selection == "accuracy":
|
pysr/test/test.py
CHANGED
|
@@ -15,7 +15,12 @@ from .. import PySRRegressor, install, jl
|
|
| 15 |
from ..export_latex import sympy2latex
|
| 16 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
| 17 |
from ..julia_helpers import init_julia
|
| 18 |
-
from ..sr import _check_assertions, _process_constraints, idx_model_selection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from ..utils import _csv_filename_to_pkl_filename
|
| 20 |
from .params import (
|
| 21 |
DEFAULT_NCYCLES,
|
|
@@ -573,6 +578,105 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 573 |
test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
|
| 574 |
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
| 575 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
def test_deprecation(self):
|
| 577 |
"""Ensure that deprecation works as expected.
|
| 578 |
|
|
@@ -715,100 +819,28 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 715 |
model.get_best()
|
| 716 |
print("Failed", opt["kwargs"])
|
| 717 |
|
| 718 |
-
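def test_pickle_with_temp_equation_file(self):
|
| 719 |
-
"""If we have a temporary equation file, unpickle the estimator."""
|
| 720 |
-
model = PySRRegressor(
|
| 721 |
-
populations=int(1 + DEFAULT_POPULATIONS / 5),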
|
| 722 |
-
temp_equation_file=True,
|
| 723 |
-
procs=0,
|
| 724 |
-
multithreading=False,
|
| 725 |
)
|
| 726 |
-
nout = 3
|
| 727 |
-
X = np.random.randn(100, 2)
|
| 728 |
-
y = np.random.randn(100, nout)
|
| 729 |
-
model.fit(X, y)
|
| 730 |
-
contents = model.equation_file_contents_.copy()
|
| 731 |
-
|
| 732 |
-
y_predictions = model.predict(X)
|
| 733 |
-
|
| 734 |
-
equation_file_base = model.equation_file_
|
| 735 |
-
for i in range(1, nout + 1):
|
| 736 |
-
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
| 737 |
-
|
| 738 |
-
with tempfile.NamedTemporaryFile() as pickle_file:
|
| 739 |
-
pkl.dump(model, pickle_file)
|
| 740 |
-
pickle_file.seek(0)
|
| 741 |
-
model2 = pkl.load(pickle_file)
|
| 742 |
-
|
| 743 |
-
contents2 = model2.equation_file_contents_
|
| 744 |
-
cols_to_check = ["equation", "loss", "complexity"]
|
| 745 |
-
for frame1, frame2 in zip(contents, contents2):
|
| 746 |
-
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
| 747 |
|
| 748 |
-
y_predictions2 = model2.predict(X)
|
| 749 |
-
np.testing.assert_array_equal(y_predictions, y_predictions2)
|
|
|
|
| 750 |
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
verbosity=0,
|
| 758 |
-
progress=False,
|
| 759 |
-
random_state=0,
|
| 760 |
-
deterministic=True, # Deterministic as tests require this.
|
| 761 |
-
procs=0,
|
| 762 |
-
multithreading=False,
|
| 763 |
-
warm_start=False,
|
| 764 |
-
temp_equation_file=True,
|
| 765 |
-
) # Return early.
|
| 766 |
-
|
| 767 |
-
check_generator = check_estimator(model, generate_only=True)
|
| 768 |
-
exception_messages = []
|
| 769 |
-
for _, check in check_generator:
|
| 770 |
-
if check.func.__name__ == "check_complex_data":
|
| 771 |
-
# We can use complex data, so avoid this check.
|
| 772 |
-
continue
|
| 773 |
-
try:
|
| 774 |
-
with warnings.catch_warnings():
|
| 775 |
-
warnings.simplefilter("ignore")
|
| 776 |
-
check(model)
|
| 777 |
-
print("Passed", check.func.__name__)
|
| 778 |
-
except Exception:
|
| 779 |
-
error_message = str(traceback.format_exc())
|
| 780 |
-
exception_messages.append(
|
| 781 |
-
f"{check.func.__name__}:\n" + error_message + "\n"
|
| 782 |
-
)
|
| 783 |
-
print("Failed", check.func.__name__, "with:")
|
| 784 |
-
# Add a leading tab to error message, which
|
| 785 |
-
# might be multi-line:
|
| 786 |
-
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 787 |
-
# If any checks failed don't let the test pass.
|
| 788 |
-
self.assertEqual(len(exception_messages), 0)
|
| 789 |
-
|
| 790 |
-
def test_param_groupings(self):
|
| 791 |
-
"""Test that param_groupings are complete"""
|
| 792 |
-
param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
|
| 793 |
-
if not param_groupings_file.exists():
|
| 794 |
-
return
|
| 795 |
-
|
| 796 |
-
# Read the file, discarding lines ending in ":",
|
| 797 |
-
# and removing leading "\s*-\s*":
|
| 798 |
-
params = []
|
| 799 |
-
with open(param_groupings_file, "r") as f:
|
| 800 |
-
for line in f.readlines():
|
| 801 |
-
if line.strip().endswith(":"):
|
| 802 |
-
continue
|
| 803 |
-
if line.strip().startswith("-"):
|
| 804 |
-
params.append(line.strip()[1:].strip())
|
| 805 |
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
|
| 810 |
-
|
| 811 |
-
self.assertSetEqual(set(params), set(regressor_params))
|
| 812 |
|
| 813 |
|
| 814 |
TRUE_PREAMBLE = "\n".join(
|
|
@@ -1158,6 +1190,7 @@ def runtests(just_tests=False):
|
|
| 1158 |
TestBest,
|
| 1159 |
TestFeatureSelection,
|
| 1160 |
TestMiscellaneous,
|
|
|
|
| 1161 |
TestLaTeXTable,
|
| 1162 |
TestDimensionalConstraints,
|
| 1163 |
]
|
|
|
|
| 15 |
from ..export_latex import sympy2latex
|
| 16 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
| 17 |
from ..julia_helpers import init_julia
|
| 18 |
+
from ..sr import (
|
| 19 |
+
_check_assertions,
|
| 20 |
+
_process_constraints,
|
| 21 |
+
_suggest_keywords,
|
| 22 |
+
idx_model_selection,
|
| 23 |
+
)
|
| 24 |
from ..utils import _csv_filename_to_pkl_filename
|
| 25 |
from .params import (
|
| 26 |
DEFAULT_NCYCLES,
|
|
|
|
| 578 |
test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
|
| 579 |
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
| 580 |
|
| 581 |
+
def test_pickle_with_temp_equation_file(self):
|
| 582 |
+
"""If we have a temporary equation file, unpickle the estimator."""
|
| 583 |
+
model = PySRRegressor(
|
| 584 |
+
populations=int(1 + DEFAULT_POPULATIONS / 5),
|
| 585 |
+
temp_equation_file=True,
|
| 586 |
+
procs=0,
|
| 587 |
+
multithreading=False,
|
| 588 |
+
)
|
| 589 |
+
nout = 3
|
| 590 |
+
X = np.random.randn(100, 2)
|
| 591 |
+
y = np.random.randn(100, nout)
|
| 592 |
+
model.fit(X, y)
|
| 593 |
+
contents = model.equation_file_contents_.copy()
|
| 594 |
+
|
| 595 |
+
y_predictions = model.predict(X)
|
| 596 |
+
|
| 597 |
+
equation_file_base = model.equation_file_
|
| 598 |
+
for i in range(1, nout + 1):
|
| 599 |
+
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
| 600 |
+
|
| 601 |
+
with tempfile.NamedTemporaryFile() as pickle_file:
|
| 602 |
+
pkl.dump(model, pickle_file)
|
| 603 |
+
pickle_file.seek(0)
|
| 604 |
+
model2 = pkl.load(pickle_file)
|
| 605 |
+
|
| 606 |
+
contents2 = model2.equation_file_contents_
|
| 607 |
+
cols_to_check = ["equation", "loss", "complexity"]
|
| 608 |
+
for frame1, frame2 in zip(contents, contents2):
|
| 609 |
+
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
| 610 |
+
|
| 611 |
+
y_predictions2 = model2.predict(X)
|
| 612 |
+
np.testing.assert_array_equal(y_predictions, y_predictions2)
|
| 613 |
+
|
| 614 |
+
def test_scikit_learn_compatibility(self):
|
| 615 |
+
"""Test PySRRegressor compatibility with scikit-learn."""
|
| 616 |
+
model = PySRRegressor(
|
| 617 |
+
niterations=int(1 + DEFAULT_NITERATIONS / 10),
|
| 618 |
+
populations=int(1 + DEFAULT_POPULATIONS / 3),
|
| 619 |
+
ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
|
| 620 |
+
verbosity=0,
|
| 621 |
+
progress=False,
|
| 622 |
+
random_state=0,
|
| 623 |
+
deterministic=True, # Deterministic as tests require this.
|
| 624 |
+
procs=0,
|
| 625 |
+
multithreading=False,
|
| 626 |
+
warm_start=False,
|
| 627 |
+
temp_equation_file=True,
|
| 628 |
+
) # Return early.
|
| 629 |
+
|
| 630 |
+
check_generator = check_estimator(model, generate_only=True)
|
| 631 |
+
exception_messages = []
|
| 632 |
+
for _, check in check_generator:
|
| 633 |
+
if check.func.__name__ == "check_complex_data":
|
| 634 |
+
# We can use complex data, so avoid this check.
|
| 635 |
+
continue
|
| 636 |
+
try:
|
| 637 |
+
with warnings.catch_warnings():
|
| 638 |
+
warnings.simplefilter("ignore")
|
| 639 |
+
check(model)
|
| 640 |
+
print("Passed", check.func.__name__)
|
| 641 |
+
except Exception:
|
| 642 |
+
error_message = str(traceback.format_exc())
|
| 643 |
+
exception_messages.append(
|
| 644 |
+
f"{check.func.__name__}:\n" + error_message + "\n"
|
| 645 |
+
)
|
| 646 |
+
print("Failed", check.func.__name__, "with:")
|
| 647 |
+
# Add a leading tab to error message, which
|
| 648 |
+
# might be multi-line:
|
| 649 |
+
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 650 |
+
# If any checks failed don't let the test pass.
|
| 651 |
+
self.assertEqual(len(exception_messages), 0)
|
| 652 |
+
|
| 653 |
+
def test_param_groupings(self):
|
| 654 |
+
"""Test that param_groupings are complete"""
|
| 655 |
+
param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
|
| 656 |
+
if not param_groupings_file.exists():
|
| 657 |
+
return
|
| 658 |
+
|
| 659 |
+
# Read the file, discarding lines ending in ":",
|
| 660 |
+
# and removing leading "\s*-\s*":
|
| 661 |
+
params = []
|
| 662 |
+
with open(param_groupings_file, "r") as f:
|
| 663 |
+
for line in f.readlines():
|
| 664 |
+
if line.strip().endswith(":"):
|
| 665 |
+
continue
|
| 666 |
+
if line.strip().startswith("-"):
|
| 667 |
+
params.append(line.strip()[1:].strip())
|
| 668 |
+
|
| 669 |
+
regressor_params = [
|
| 670 |
+
p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
|
| 671 |
+
]
|
| 672 |
+
|
| 673 |
+
# Check the sets are equal:
|
| 674 |
+
self.assertSetEqual(set(params), set(regressor_params))
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class TestHelpMessages(unittest.TestCase):
|
| 678 |
+
"""Test user help messages."""
|
| 679 |
+
|
| 680 |
def test_deprecation(self):
|
| 681 |
"""Ensure that deprecation works as expected.
|
| 682 |
|
|
|
|
| 819 |
model.get_best()
|
| 820 |
print("Failed", opt["kwargs"])
|
| 821 |
|
| 822 |
+
def test_suggest_keywords(self):
|
| 823 |
+
# Easy
|
| 824 |
+
self.assertEqual(
|
| 825 |
+
_suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
|
|
|
|
|
|
|
|
|
|
| 826 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
|
| 828 |
+
# More complex, and with error
|
| 829 |
+
with self.assertRaises(TypeError) as cm:
|
| 830 |
+
model = PySRRegressor(ncyclesperiterationn=5)
|
| 831 |
|
| 832 |
+
self.assertIn(
|
| 833 |
+
"`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
|
| 834 |
+
)
|
| 835 |
+
self.assertIn("Did you mean", str(cm.exception))
|
| 836 |
+
self.assertIn("`ncycles_per_iteration`, ", str(cm.exception))
|
| 837 |
+
self.assertIn("`niterations`", str(cm.exception))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
|
| 839 |
+
# Farther matches (this might need to be changed)
|
| 840 |
+
with self.assertRaises(TypeError) as cm:
|
| 841 |
+
model = PySRRegressor(operators=["+", "-"])
|
| 842 |
|
| 843 |
+
self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
|
|
|
|
| 844 |
|
| 845 |
|
| 846 |
TRUE_PREAMBLE = "\n".join(
|
|
|
|
| 1190 |
TestBest,
|
| 1191 |
TestFeatureSelection,
|
| 1192 |
TestMiscellaneous,
|
| 1193 |
+
TestHelpMessages,
|
| 1194 |
TestLaTeXTable,
|
| 1195 |
TestDimensionalConstraints,
|
| 1196 |
]
|