Spaces:
Running
Running
Commit
·
b8a97f1
1
Parent(s):
b53e7fa
Use .pkl instead of .csv.pkl
Browse files- pysr/sr.py +28 -10
- test/test.py +20 -1
pysr/sr.py
CHANGED
|
@@ -930,7 +930,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 930 |
This should only be used internally by PySRRegressor."""
|
| 931 |
# Save model state:
|
| 932 |
self.show_pickle_warnings_ = False
|
| 933 |
-
with open(
|
| 934 |
pkl.dump(self, f)
|
| 935 |
self.show_pickle_warnings_ = True
|
| 936 |
|
|
@@ -1636,14 +1636,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1636 |
|
| 1637 |
# Initially, just save model parameters, so that
|
| 1638 |
# it can be loaded from an early exit:
|
| 1639 |
-
self.
|
|
|
|
| 1640 |
|
| 1641 |
# Perform the search:
|
| 1642 |
self._run(X, y, mutated_params, weights=weights, seed=seed)
|
| 1643 |
|
| 1644 |
# Then, after fit, we save again, so the pickle file contains
|
| 1645 |
# the equations:
|
| 1646 |
-
self.
|
|
|
|
| 1647 |
|
| 1648 |
return self
|
| 1649 |
|
|
@@ -2077,6 +2079,17 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
|
|
| 2077 |
return selector.get_support(indices=True)
|
| 2078 |
|
| 2079 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2080 |
def load(
|
| 2081 |
equation_file,
|
| 2082 |
*,
|
|
@@ -2094,7 +2107,8 @@ def load(
|
|
| 2094 |
Parameters
|
| 2095 |
----------
|
| 2096 |
equation_file : str
|
| 2097 |
-
Path to a csv file containing equations
|
|
|
|
| 2098 |
|
| 2099 |
binary_operators : list[str], default=["+", "-", "*", "/"]
|
| 2100 |
The same binary operators used when creating the model.
|
|
@@ -2123,14 +2137,19 @@ def load(
|
|
| 2123 |
model : PySRRegressor
|
| 2124 |
The model with fitted equations.
|
| 2125 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2126 |
# Try to load model from <equation_file>.pkl
|
| 2127 |
-
print(f"Checking if {
|
| 2128 |
-
if os.path.exists(
|
| 2129 |
-
print(f"Loading model from {
|
| 2130 |
assert binary_operators is None
|
| 2131 |
assert unary_operators is None
|
| 2132 |
assert n_features_in is None
|
| 2133 |
-
with open(
|
| 2134 |
model = pkl.load(f)
|
| 2135 |
# Update any parameters if necessary, such as
|
| 2136 |
# extra_sympy_mappings:
|
|
@@ -2142,8 +2161,7 @@ def load(
|
|
| 2142 |
|
| 2143 |
# Else, we re-create it.
|
| 2144 |
print(
|
| 2145 |
-
f"{equation_file}
|
| 2146 |
-
"so we must create the model from scratch."
|
| 2147 |
)
|
| 2148 |
assert binary_operators is not None
|
| 2149 |
assert unary_operators is not None
|
|
|
|
| 930 |
This should only be used internally by PySRRegressor."""
|
| 931 |
# Save model state:
|
| 932 |
self.show_pickle_warnings_ = False
|
| 933 |
+
with open(_csv_filename_to_pkl_filename(self.equation_file_), "wb") as f:
|
| 934 |
pkl.dump(self, f)
|
| 935 |
self.show_pickle_warnings_ = True
|
| 936 |
|
|
|
|
| 1636 |
|
| 1637 |
# Initially, just save model parameters, so that
|
| 1638 |
# it can be loaded from an early exit:
|
| 1639 |
+
if not self.temp_equation_file:
|
| 1640 |
+
self._checkpoint()
|
| 1641 |
|
| 1642 |
# Perform the search:
|
| 1643 |
self._run(X, y, mutated_params, weights=weights, seed=seed)
|
| 1644 |
|
| 1645 |
# Then, after fit, we save again, so the pickle file contains
|
| 1646 |
# the equations:
|
| 1647 |
+
if not self.temp_equation_file:
|
| 1648 |
+
self._checkpoint()
|
| 1649 |
|
| 1650 |
return self
|
| 1651 |
|
|
|
|
| 2079 |
return selector.get_support(indices=True)
|
| 2080 |
|
| 2081 |
|
| 2082 |
+
def _csv_filename_to_pkl_filename(csv_filename) -> str:
|
| 2083 |
+
# Assume that the csv filename is of the form "foo.csv"
|
| 2084 |
+
dirname = str(os.path.dirname(csv_filename))
|
| 2085 |
+
basename = str(os.path.basename(csv_filename))
|
| 2086 |
+
base = str(os.path.splitext(basename)[0])
|
| 2087 |
+
|
| 2088 |
+
pkl_basename = base + ".pkl"
|
| 2089 |
+
|
| 2090 |
+
return os.path.join(dirname, pkl_basename)
|
| 2091 |
+
|
| 2092 |
+
|
| 2093 |
def load(
|
| 2094 |
equation_file,
|
| 2095 |
*,
|
|
|
|
| 2107 |
Parameters
|
| 2108 |
----------
|
| 2109 |
equation_file : str
|
| 2110 |
+
Path to a csv file containing equations, or a pickle file
|
| 2111 |
+
containing the model.
|
| 2112 |
|
| 2113 |
binary_operators : list[str], default=["+", "-", "*", "/"]
|
| 2114 |
The same binary operators used when creating the model.
|
|
|
|
| 2137 |
model : PySRRegressor
|
| 2138 |
The model with fitted equations.
|
| 2139 |
"""
|
| 2140 |
+
if os.path.splitext(equation_file)[1] != ".pkl":
|
| 2141 |
+
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
|
| 2142 |
+
else:
|
| 2143 |
+
pkl_filename = equation_file
|
| 2144 |
+
|
| 2145 |
# Try to load model from <equation_file>.pkl
|
| 2146 |
+
print(f"Checking if {pkl_filename} exists...")
|
| 2147 |
+
if os.path.exists(pkl_filename):
|
| 2148 |
+
print(f"Loading model from {pkl_filename}")
|
| 2149 |
assert binary_operators is None
|
| 2150 |
assert unary_operators is None
|
| 2151 |
assert n_features_in is None
|
| 2152 |
+
with open(pkl_filename, "rb") as f:
|
| 2153 |
model = pkl.load(f)
|
| 2154 |
# Update any parameters if necessary, such as
|
| 2155 |
# extra_sympy_mappings:
|
|
|
|
| 2161 |
|
| 2162 |
# Else, we re-create it.
|
| 2163 |
print(
|
| 2164 |
+
f"{equation_file} does not exist, " "so we must create the model from scratch."
|
|
|
|
| 2165 |
)
|
| 2166 |
assert binary_operators is not None
|
| 2167 |
assert unary_operators is not None
|
test/test.py
CHANGED
|
@@ -5,7 +5,11 @@ import unittest
|
|
| 5 |
import numpy as np
|
| 6 |
from sklearn import model_selection
|
| 7 |
from pysr import PySRRegressor, load
|
| 8 |
-
from pysr.sr import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from sklearn.utils.estimator_checks import check_estimator
|
| 10 |
import sympy
|
| 11 |
import pandas as pd
|
|
@@ -341,6 +345,7 @@ class TestPipeline(unittest.TestCase):
|
|
| 341 |
if os.path.exists(file_to_delete):
|
| 342 |
os.remove(file_to_delete)
|
| 343 |
|
|
|
|
| 344 |
model3 = load(
|
| 345 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
| 346 |
)
|
|
@@ -430,6 +435,20 @@ class TestFeatureSelection(unittest.TestCase):
|
|
| 430 |
class TestMiscellaneous(unittest.TestCase):
|
| 431 |
"""Test miscellaneous functions."""
|
| 432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
def test_deprecation(self):
|
| 434 |
"""Ensure that deprecation works as expected.
|
| 435 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
from sklearn import model_selection
|
| 7 |
from pysr import PySRRegressor, load
|
| 8 |
+
from pysr.sr import (
|
| 9 |
+
run_feature_selection,
|
| 10 |
+
_handle_feature_selection,
|
| 11 |
+
_csv_filename_to_pkl_filename,
|
| 12 |
+
)
|
| 13 |
from sklearn.utils.estimator_checks import check_estimator
|
| 14 |
import sympy
|
| 15 |
import pandas as pd
|
|
|
|
| 345 |
if os.path.exists(file_to_delete):
|
| 346 |
os.remove(file_to_delete)
|
| 347 |
|
| 348 |
+
pickle_file = rand_dir / "equations.pkl"
|
| 349 |
model3 = load(
|
| 350 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
| 351 |
)
|
|
|
|
| 435 |
class TestMiscellaneous(unittest.TestCase):
|
| 436 |
"""Test miscellaneous functions."""
|
| 437 |
|
| 438 |
+
def test_csv_to_pkl_conversion(self):
    """Test that csv filename to pkl filename works as expected.

    The file name contains several dots, so the check confirms that only
    the trailing ``.csv`` is replaced. The helper is called once with a
    ``Path`` and once with a ``str``; both results are compared against
    the string form of the expected path.
    """
    # mkdtemp() leaves its directory behind; the context manager removes
    # the scratch directory once the assertions have run.
    with tempfile.TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)
        equation_file = tmpdir / "equations.389479384.28378374.csv"
        expected_pkl_file = tmpdir / "equations.389479384.28378374.pkl"

        # First, test inputting the paths:
        test_pkl_file = _csv_filename_to_pkl_filename(equation_file)
        self.assertEqual(test_pkl_file, str(expected_pkl_file))

        # Next, test inputting the strings.
        test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
        self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
| 451 |
+
|
| 452 |
def test_deprecation(self):
|
| 453 |
"""Ensure that deprecation works as expected.
|
| 454 |
|