Spaces:
Running
Running
Commit
·
4d915b2
1
Parent(s):
09b1cf7
Change lambda_format to same format as torch/jax
Browse files- pysr/sr.py +2 -1
pysr/sr.py
CHANGED
|
@@ -795,7 +795,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 795 |
if output_jax_format:
|
| 796 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 797 |
jax_format.append({'callable': func, 'parameters': params})
|
| 798 |
-
|
|
|
|
| 799 |
curMSE = output.loc[i, 'MSE']
|
| 800 |
curComplexity = output.loc[i, 'Complexity']
|
| 801 |
|
|
|
|
| 795 |
if output_jax_format:
|
| 796 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 797 |
jax_format.append({'callable': func, 'parameters': params})
|
| 798 |
+
tmp_lambda = lambdify(sympy_symbols, eqn)
|
| 799 |
+
lambda_format.append(lambda X: tmp_lambda(*X.T))
|
| 800 |
curMSE = output.loc[i, 'MSE']
|
| 801 |
curComplexity = output.loc[i, 'Complexity']
|
| 802 |
|