Spaces:
Running
Running
Merge pull request #46 from MilesCranmer/multi-output
Browse files
- Project.toml +1 -1
- pysr/sr.py +123 -61
- test/test.py +4 -3
Project.toml
CHANGED
|
@@ -2,5 +2,5 @@
|
|
| 2 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
| 3 |
|
| 4 |
[compat]
|
| 5 |
-
SymbolicRegression = "0.
|
| 6 |
julia = "1.5"
|
|
|
|
| 2 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
| 3 |
|
| 4 |
[compat]
|
| 5 |
+
SymbolicRegression = "0.6.0"
|
| 6 |
julia = "1.5"
|
pysr/sr.py
CHANGED
|
@@ -19,6 +19,8 @@ global_equation_file = 'hall_of_fame.csv'
|
|
| 19 |
global_n_features = None
|
| 20 |
global_variable_names = []
|
| 21 |
global_extra_sympy_mappings = {}
|
|
|
|
|
|
|
| 22 |
|
| 23 |
sympy_mappings = {
|
| 24 |
'div': lambda x, y : x/y,
|
|
@@ -276,6 +278,16 @@ def pysr(X=None, y=None, weights=None,
|
|
| 276 |
if X is None:
|
| 277 |
X, y = _using_test_input(X, test, y)
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
kwargs = dict(X=X, y=y, weights=weights,
|
| 280 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
| 281 |
batching=batching, binary_operators=binary_operators,
|
|
@@ -309,7 +321,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 309 |
constraints=constraints,
|
| 310 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 311 |
julia_project=julia_project, loss=loss,
|
| 312 |
-
output_jax_format=output_jax_format
|
|
|
|
| 313 |
|
| 314 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
| 315 |
|
|
@@ -358,15 +371,20 @@ def pysr(X=None, y=None, weights=None,
|
|
| 358 |
|
| 359 |
|
| 360 |
|
| 361 |
-
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
|
|
|
| 362 |
global global_n_features
|
| 363 |
global global_equation_file
|
| 364 |
global global_variable_names
|
| 365 |
global global_extra_sympy_mappings
|
|
|
|
|
|
|
| 366 |
global_n_features = X.shape[1]
|
| 367 |
global_equation_file = equation_file
|
| 368 |
global_variable_names = variable_names
|
| 369 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
|
|
| 370 |
|
| 371 |
|
| 372 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
@@ -393,9 +411,7 @@ def _cmd_runner(command, **kwargs):
|
|
| 393 |
.replace('\\r', '\r')
|
| 394 |
.encode(sys.stdout.encoding, errors='replace'))
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
|
| 400 |
process.stdout.close()
|
| 401 |
process.wait()
|
|
@@ -438,17 +454,35 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
|
|
| 438 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
| 439 |
|
| 440 |
|
| 441 |
-
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
|
|
|
|
| 442 |
def_datasets = """using DelimitedFiles"""
|
| 443 |
np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
| 445 |
if weights is not None:
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
| 447 |
def_datasets += f"""
|
| 448 |
-
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
|
|
|
|
| 450 |
if weights is not None:
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
|
| 453 |
return def_datasets
|
| 454 |
|
|
@@ -656,10 +690,10 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
|
|
| 656 |
# Check for potential errors before they happen
|
| 657 |
assert len(unary_operators) + len(binary_operators) > 0
|
| 658 |
assert len(X.shape) == 2
|
| 659 |
-
assert len(y.shape)
|
| 660 |
assert X.shape[0] == y.shape[0]
|
| 661 |
if weights is not None:
|
| 662 |
-
assert
|
| 663 |
assert X.shape[0] == weights.shape[0]
|
| 664 |
if use_custom_variable_names:
|
| 665 |
assert len(variable_names) == X.shape[1]
|
|
@@ -693,7 +727,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
| 693 |
return selector.get_support(indices=True)
|
| 694 |
|
| 695 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 696 |
-
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
|
| 697 |
"""Get the equations from a hall of fame file. If no arguments
|
| 698 |
entered, the ones used previously from a call to PySR will be used."""
|
| 699 |
|
|
@@ -701,99 +736,126 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 701 |
global global_equation_file
|
| 702 |
global global_variable_names
|
| 703 |
global global_extra_sympy_mappings
|
|
|
|
|
|
|
| 704 |
|
| 705 |
if equation_file is None: equation_file = global_equation_file
|
| 706 |
if n_features is None: n_features = global_n_features
|
| 707 |
if variable_names is None: variable_names = global_variable_names
|
| 708 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
|
|
|
|
|
|
| 709 |
|
| 710 |
global_equation_file = equation_file
|
| 711 |
global_n_features = n_features
|
| 712 |
global_variable_names = variable_names
|
| 713 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
|
|
| 714 |
|
| 715 |
try:
|
| 716 |
-
|
|
|
|
|
|
|
|
|
|
| 717 |
except FileNotFoundError:
|
| 718 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
| 719 |
|
| 720 |
-
|
| 721 |
-
lastMSE = None
|
| 722 |
-
lastComplexity = 0
|
| 723 |
-
sympy_format = []
|
| 724 |
-
lambda_format = []
|
| 725 |
-
if output_jax_format:
|
| 726 |
-
jax_format = []
|
| 727 |
-
use_custom_variable_names = (len(variable_names) != 0)
|
| 728 |
-
local_sympy_mappings = {
|
| 729 |
-
**extra_sympy_mappings,
|
| 730 |
-
**sympy_mappings
|
| 731 |
-
}
|
| 732 |
|
| 733 |
-
|
| 734 |
-
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
| 735 |
-
else:
|
| 736 |
-
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
| 737 |
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
|
|
|
|
|
|
| 741 |
if output_jax_format:
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
if lastMSE is None:
|
| 749 |
-
cur_score = 0.0
|
| 750 |
-
else:
|
| 751 |
-
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
| 752 |
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 756 |
|
| 757 |
-
|
| 758 |
-
output['sympy_format'] = sympy_format
|
| 759 |
-
output['lambda_format'] = lambda_format
|
| 760 |
-
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
| 761 |
-
if output_jax_format:
|
| 762 |
-
output_cols += ['jax_format']
|
| 763 |
-
output['jax_format'] = jax_format
|
| 764 |
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
| 766 |
|
| 767 |
def best_row(equations=None):
|
| 768 |
"""Return the best row of a hall of fame file using the score column.
|
| 769 |
By default this uses the last equation file.
|
| 770 |
"""
|
| 771 |
if equations is None: equations = get_hof()
|
| 772 |
-
|
| 773 |
-
|
|
|
|
|
|
|
| 774 |
|
| 775 |
def best_tex(equations=None):
|
| 776 |
"""Return the equation with the best score, in latex format
|
| 777 |
By default this uses the last equation file.
|
| 778 |
"""
|
| 779 |
if equations is None: equations = get_hof()
|
| 780 |
-
|
| 781 |
-
|
|
|
|
|
|
|
| 782 |
|
| 783 |
def best(equations=None):
|
| 784 |
"""Return the equation with the best score, in sympy format.
|
| 785 |
By default this uses the last equation file.
|
| 786 |
"""
|
| 787 |
if equations is None: equations = get_hof()
|
| 788 |
-
|
| 789 |
-
|
|
|
|
|
|
|
| 790 |
|
| 791 |
def best_callable(equations=None):
|
| 792 |
"""Return the equation with the best score, in callable format.
|
| 793 |
By default this uses the last equation file.
|
| 794 |
"""
|
| 795 |
if equations is None: equations = get_hof()
|
| 796 |
-
|
|
|
|
|
|
|
|
|
|
| 797 |
|
| 798 |
def _escape_filename(filename):
|
| 799 |
"""Turns a file into a string representation with correctly escaped backslashes"""
|
|
|
|
| 19 |
global_n_features = None
|
| 20 |
global_variable_names = []
|
| 21 |
global_extra_sympy_mappings = {}
|
| 22 |
+
global_multioutput = False
|
| 23 |
+
global_nout = 1
|
| 24 |
|
| 25 |
sympy_mappings = {
|
| 26 |
'div': lambda x, y : x/y,
|
|
|
|
| 278 |
if X is None:
|
| 279 |
X, y = _using_test_input(X, test, y)
|
| 280 |
|
| 281 |
+
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
| 282 |
+
multioutput = False
|
| 283 |
+
nout = 1
|
| 284 |
+
y = y.reshape(-1)
|
| 285 |
+
elif len(y.shape) == 2:
|
| 286 |
+
multioutput = True
|
| 287 |
+
nout = y.shape[1]
|
| 288 |
+
else:
|
| 289 |
+
raise NotImplementedError("y shape not supported!")
|
| 290 |
+
|
| 291 |
kwargs = dict(X=X, y=y, weights=weights,
|
| 292 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
| 293 |
batching=batching, binary_operators=binary_operators,
|
|
|
|
| 321 |
constraints=constraints,
|
| 322 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 323 |
julia_project=julia_project, loss=loss,
|
| 324 |
+
output_jax_format=output_jax_format,
|
| 325 |
+
multioutput=multioutput, nout=nout)
|
| 326 |
|
| 327 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
| 328 |
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
|
| 374 |
+
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
| 375 |
+
multioutput, nout, **kwargs):
|
| 376 |
global global_n_features
|
| 377 |
global global_equation_file
|
| 378 |
global global_variable_names
|
| 379 |
global global_extra_sympy_mappings
|
| 380 |
+
global global_multioutput
|
| 381 |
+
global global_nout
|
| 382 |
global_n_features = X.shape[1]
|
| 383 |
global_equation_file = equation_file
|
| 384 |
global_variable_names = variable_names
|
| 385 |
global_extra_sympy_mappings = extra_sympy_mappings
|
| 386 |
+
global_multioutput = multioutput
|
| 387 |
+
global_nout = nout
|
| 388 |
|
| 389 |
|
| 390 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
|
|
| 411 |
.replace('\\r', '\r')
|
| 412 |
.encode(sys.stdout.encoding, errors='replace'))
|
| 413 |
|
| 414 |
+
sys.stdout.buffer.write(decoded_line)
|
|
|
|
|
|
|
| 415 |
|
| 416 |
process.stdout.close()
|
| 417 |
process.wait()
|
|
|
|
| 454 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
| 455 |
|
| 456 |
|
| 457 |
+
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
                             multioutput, **kwargs):
    """Write the datasets to disk and return the Julia code that loads them.

    ``X``, ``y`` and (when given) ``weights`` are saved as comma-delimited
    float32 text files. The returned string is Julia source that reads those
    files back with ``readdlm`` into the variables ``X``, ``y`` and ``weights``.

    Parameters
    ----------
    X : numpy array of input features, saved as-is.
    X_filename, y_filename, weights_filename : destination paths.
    weights : numpy array or None; nothing is written or loaded when None.
    y : numpy array of targets.
    multioutput : bool. When true, ``y``/``weights`` are saved as given and
        loaded transposed; otherwise they are reshaped to a single column and
        the Julia side takes that first column as a vector.
    **kwargs : accepted and ignored.

    Returns
    -------
    str : Julia statements defining the datasets.
    """
    def_datasets = """using DelimitedFiles"""

    np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
    if multioutput:
        np.savetxt(y_filename, y.astype(np.float32), delimiter=',')
    else:
        # Force a single column so the Julia side can index `[:, 1]`.
        np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
    if weights is not None:
        if multioutput:
            np.savetxt(weights_filename, weights.astype(np.float32), delimiter=',')
        else:
            np.savetxt(weights_filename, weights.reshape(-1, 1).astype(np.float32), delimiter=',')

    # X is written one sample per row and transposed when read back in Julia.
    def_datasets += f"""
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""

    if multioutput:
        def_datasets += f"""
y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
    else:
        def_datasets += f"""
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""

    if weights is not None:
        if multioutput:
            def_datasets += f"""
weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')))"""
        else:
            def_datasets += f"""
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
    return def_datasets
|
| 488 |
|
|
|
|
| 690 |
# Check for potential errors before they happen
|
| 691 |
assert len(unary_operators) + len(binary_operators) > 0
|
| 692 |
assert len(X.shape) == 2
|
| 693 |
+
assert len(y.shape) in [1, 2]
|
| 694 |
assert X.shape[0] == y.shape[0]
|
| 695 |
if weights is not None:
|
| 696 |
+
assert weights.shape == y.shape
|
| 697 |
assert X.shape[0] == weights.shape[0]
|
| 698 |
if use_custom_variable_names:
|
| 699 |
assert len(variable_names) == X.shape[1]
|
|
|
|
| 727 |
return selector.get_support(indices=True)
|
| 728 |
|
| 729 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 730 |
+
extra_sympy_mappings=None, output_jax_format=False,
|
| 731 |
+
multioutput=None, nout=None, **kwargs):
|
| 732 |
"""Get the equations from a hall of fame file. If no arguments
|
| 733 |
entered, the ones used previously from a call to PySR will be used."""
|
| 734 |
|
|
|
|
| 736 |
global global_equation_file
|
| 737 |
global global_variable_names
|
| 738 |
global global_extra_sympy_mappings
|
| 739 |
+
global global_multioutput
|
| 740 |
+
global global_nout
|
| 741 |
|
| 742 |
if equation_file is None: equation_file = global_equation_file
|
| 743 |
if n_features is None: n_features = global_n_features
|
| 744 |
if variable_names is None: variable_names = global_variable_names
|
| 745 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
| 746 |
+
if multioutput is None: multioutput = global_multioutput
|
| 747 |
+
if nout is None: nout = global_nout
|
| 748 |
|
| 749 |
global_equation_file = equation_file
|
| 750 |
global_n_features = n_features
|
| 751 |
global_variable_names = variable_names
|
| 752 |
global_extra_sympy_mappings = extra_sympy_mappings
|
| 753 |
+
global_multioutput = multioutput
|
| 754 |
+
global_nout = nout
|
| 755 |
|
| 756 |
try:
|
| 757 |
+
if multioutput:
|
| 758 |
+
all_outputs = [pd.read_csv(f'out{i}_' + str(equation_file) + '.bkup', sep="|") for i in range(1, nout+1)]
|
| 759 |
+
else:
|
| 760 |
+
all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
|
| 761 |
except FileNotFoundError:
|
| 762 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
| 763 |
|
| 764 |
+
ret_outputs = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
|
| 766 |
+
for output in all_outputs:
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
+
scores = []
|
| 769 |
+
lastMSE = None
|
| 770 |
+
lastComplexity = 0
|
| 771 |
+
sympy_format = []
|
| 772 |
+
lambda_format = []
|
| 773 |
if output_jax_format:
|
| 774 |
+
jax_format = []
|
| 775 |
+
use_custom_variable_names = (len(variable_names) != 0)
|
| 776 |
+
local_sympy_mappings = {
|
| 777 |
+
**extra_sympy_mappings,
|
| 778 |
+
**sympy_mappings
|
| 779 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
|
| 781 |
+
if use_custom_variable_names:
|
| 782 |
+
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
| 783 |
+
else:
|
| 784 |
+
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
| 785 |
+
|
| 786 |
+
for i in range(len(output)):
|
| 787 |
+
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 788 |
+
sympy_format.append(eqn)
|
| 789 |
+
if output_jax_format:
|
| 790 |
+
func, params = sympy2jax(eqn, sympy_symbols)
|
| 791 |
+
jax_format.append({'callable': func, 'parameters': params})
|
| 792 |
+
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 793 |
+
curMSE = output.loc[i, 'MSE']
|
| 794 |
+
curComplexity = output.loc[i, 'Complexity']
|
| 795 |
+
|
| 796 |
+
if lastMSE is None:
|
| 797 |
+
cur_score = 0.0
|
| 798 |
+
else:
|
| 799 |
+
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
| 800 |
+
|
| 801 |
+
scores.append(cur_score)
|
| 802 |
+
lastMSE = curMSE
|
| 803 |
+
lastComplexity = curComplexity
|
| 804 |
+
|
| 805 |
+
output['score'] = np.array(scores)
|
| 806 |
+
output['sympy_format'] = sympy_format
|
| 807 |
+
output['lambda_format'] = lambda_format
|
| 808 |
+
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
| 809 |
+
if output_jax_format:
|
| 810 |
+
output_cols += ['jax_format']
|
| 811 |
+
output['jax_format'] = jax_format
|
| 812 |
|
| 813 |
+
ret_outputs.append(output[output_cols])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
|
| 815 |
+
if multioutput:
|
| 816 |
+
return ret_outputs
|
| 817 |
+
else:
|
| 818 |
+
return ret_outputs[0]
|
| 819 |
|
| 820 |
def best_row(equations=None):
    """Return the best row of a hall of fame file using the score column.

    By default this uses the last equation file.

    Parameters
    ----------
    equations : DataFrame, list of DataFrames, or None
        Hall-of-fame table(s) with a ``'score'`` column. A list holds one
        table per output. When None, the result of ``get_hof()`` is used.

    Returns
    -------
    The row with the highest score, or a list of such rows (one per table)
    when ``equations`` is a list.
    """
    if equations is None:
        equations = get_hof()
    if isinstance(equations, list):
        # One table per output: pick the best row of each.
        return [eq.iloc[np.argmax(eq['score'])] for eq in equations]
    return equations.iloc[np.argmax(equations['score'])]
|
| 829 |
|
| 830 |
def best_tex(equations=None):
    """Return the equation with the best score, in latex format.

    By default this uses the last equation file.

    Parameters
    ----------
    equations : DataFrame, list of DataFrames, or None
        Hall-of-fame table(s); a list holds one table per output. When None,
        the result of ``get_hof()`` is used.

    Returns
    -------
    The LaTeX string of the simplified best equation, or a list of such
    strings (one per table) when ``equations`` is a list.
    """
    if equations is None:
        equations = get_hof()
    if isinstance(equations, list):
        return [sympy.latex(best_row(eq)['sympy_format'].simplify()) for eq in equations]
    return sympy.latex(best_row(equations)['sympy_format'].simplify())
|
| 839 |
|
| 840 |
def best(equations=None):
    """Return the equation with the best score, in sympy format.

    By default this uses the last equation file.

    Parameters
    ----------
    equations : DataFrame, list of DataFrames, or None
        Hall-of-fame table(s); a list holds one table per output. When None,
        the result of ``get_hof()`` is used.

    Returns
    -------
    The simplified ``'sympy_format'`` entry of the best row, or a list of
    them (one per table) when ``equations`` is a list.
    """
    if equations is None:
        equations = get_hof()
    if isinstance(equations, list):
        return [best_row(eq)['sympy_format'].simplify() for eq in equations]
    return best_row(equations)['sympy_format'].simplify()
|
| 849 |
|
| 850 |
def best_callable(equations=None):
    """Return the equation with the best score, in callable format.

    By default this uses the last equation file.

    Parameters
    ----------
    equations : DataFrame, list of DataFrames, or None
        Hall-of-fame table(s); a list holds one table per output. When None,
        the result of ``get_hof()`` is used.

    Returns
    -------
    The ``'lambda_format'`` entry of the best row, or a list of them (one
    per table) when ``equations`` is a list.
    """
    if equations is None:
        equations = get_hof()
    if isinstance(equations, list):
        return [best_row(eq)['lambda_format'] for eq in equations]
    return best_row(equations)['lambda_format']
|
| 859 |
|
| 860 |
def _escape_filename(filename):
|
| 861 |
"""Turns a file into a string representation with correctly escaped backslashes"""
|
test/test.py
CHANGED
|
@@ -17,14 +17,15 @@ equations = pysr(X, y, **default_test_kwargs)
|
|
| 17 |
print(equations)
|
| 18 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
| 19 |
|
| 20 |
-
print("Test 2 - test custom operator")
|
| 21 |
-
y = X[:, 0]**2
|
| 22 |
equations = pysr(X, y,
|
| 23 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 24 |
extra_sympy_mappings={'square': lambda x: x**2},
|
| 25 |
**default_test_kwargs)
|
| 26 |
print(equations)
|
| 27 |
-
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
|
| 28 |
|
| 29 |
X = np.random.randn(100, 1)
|
| 30 |
y = X[:, 0] + 3.0
|
|
|
|
| 17 |
print(equations)
|
| 18 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
| 19 |
|
| 20 |
+
print("Test 2 - test custom operator, and multiple outputs")
|
| 21 |
+
y = X[:, [0, 1]]**2
|
| 22 |
equations = pysr(X, y,
|
| 23 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 24 |
extra_sympy_mappings={'square': lambda x: x**2},
|
| 25 |
**default_test_kwargs)
|
| 26 |
print(equations)
|
| 27 |
+
assert equations[0].iloc[-1]['MSE'] < 1e-4
|
| 28 |
+
assert equations[1].iloc[-1]['MSE'] < 1e-4
|
| 29 |
|
| 30 |
X = np.random.randn(100, 1)
|
| 31 |
y = X[:, 0] + 3.0
|