Spaces:
Running
Running
Commit
·
0dfd8e3
1
Parent(s):
a9a1691
Refactored out paths and others
Browse files- pysr/sr.py +31 -22
pysr/sr.py
CHANGED
|
@@ -207,16 +207,7 @@ def pysr(X=None, y=None, weights=None,
|
|
| 207 |
if len(X.shape) == 1:
|
| 208 |
X = X[:, None]
|
| 209 |
|
| 210 |
-
|
| 211 |
-
assert len(unary_operators) + len(binary_operators) > 0
|
| 212 |
-
assert len(X.shape) == 2
|
| 213 |
-
assert len(y.shape) == 1
|
| 214 |
-
assert X.shape[0] == y.shape[0]
|
| 215 |
-
if weights is not None:
|
| 216 |
-
assert len(weights.shape) == 1
|
| 217 |
-
assert X.shape[0] == weights.shape[0]
|
| 218 |
-
if use_custom_variable_names:
|
| 219 |
-
assert len(variable_names) == X.shape[1]
|
| 220 |
|
| 221 |
if select_k_features is not None:
|
| 222 |
selection = run_feature_selection(X, y, select_k_features)
|
|
@@ -248,18 +239,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 248 |
y = eval(eval_str)
|
| 249 |
print("Running on", eval_str)
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
pkg_filename = pkg_directory / "sr.jl"
|
| 254 |
-
operator_filename = pkg_directory / "operators.jl"
|
| 255 |
-
|
| 256 |
-
tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
|
| 257 |
-
hyperparam_filename = tmpdir / f'hyperparams.jl'
|
| 258 |
-
dataset_filename = tmpdir / f'dataset.jl'
|
| 259 |
-
runfile_filename = tmpdir / f'runfile.jl'
|
| 260 |
-
X_filename = tmpdir / "X.csv"
|
| 261 |
-
y_filename = tmpdir / "y.csv"
|
| 262 |
-
weights_filename = tmpdir / "weights.csv"
|
| 263 |
|
| 264 |
def_hyperparams = ""
|
| 265 |
|
|
@@ -463,6 +444,34 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
|
|
| 463 |
return get_hof()
|
| 464 |
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
def raise_depreciation_errors(limitPowComplexity, threads):
|
| 467 |
if threads is not None:
|
| 468 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
|
|
|
| 207 |
if len(X.shape) == 1:
|
| 208 |
X = X[:, None]
|
| 209 |
|
| 210 |
+
check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
if select_k_features is not None:
|
| 213 |
selection = run_feature_selection(X, y, select_k_features)
|
|
|
|
| 239 |
y = eval(eval_str)
|
| 240 |
print("Running on", eval_str)
|
| 241 |
|
| 242 |
+
X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename = set_paths(
|
| 243 |
+
tempdir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
def_hyperparams = ""
|
| 246 |
|
|
|
|
| 444 |
return get_hof()
|
| 445 |
|
| 446 |
|
| 447 |
+
def set_paths(tempdir):
|
| 448 |
+
# System-independent paths
|
| 449 |
+
pkg_directory = Path(__file__).parents[1] / 'julia'
|
| 450 |
+
pkg_filename = pkg_directory / "sr.jl"
|
| 451 |
+
operator_filename = pkg_directory / "operators.jl"
|
| 452 |
+
tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
|
| 453 |
+
hyperparam_filename = tmpdir / f'hyperparams.jl'
|
| 454 |
+
dataset_filename = tmpdir / f'dataset.jl'
|
| 455 |
+
runfile_filename = tmpdir / f'runfile.jl'
|
| 456 |
+
X_filename = tmpdir / "X.csv"
|
| 457 |
+
y_filename = tmpdir / "y.csv"
|
| 458 |
+
weights_filename = tmpdir / "weights.csv"
|
| 459 |
+
return X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
|
| 463 |
+
# Check for potential errors before they happen
|
| 464 |
+
assert len(unary_operators) + len(binary_operators) > 0
|
| 465 |
+
assert len(X.shape) == 2
|
| 466 |
+
assert len(y.shape) == 1
|
| 467 |
+
assert X.shape[0] == y.shape[0]
|
| 468 |
+
if weights is not None:
|
| 469 |
+
assert len(weights.shape) == 1
|
| 470 |
+
assert X.shape[0] == weights.shape[0]
|
| 471 |
+
if use_custom_variable_names:
|
| 472 |
+
assert len(variable_names) == X.shape[1]
|
| 473 |
+
|
| 474 |
+
|
| 475 |
def raise_depreciation_errors(limitPowComplexity, threads):
|
| 476 |
if threads is not None:
|
| 477 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|