Spaces:
Sleeping
Sleeping
Commit
·
86dd9ce
1
Parent(s):
9068541
Fix output_torch_format option for pysr
Browse files- pysr/sr.py +2 -2
pysr/sr.py
CHANGED
|
@@ -800,8 +800,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 800 |
=======
|
| 801 |
if output_torch_format:
|
| 802 |
from .export_torch import sympy2torch
|
| 803 |
-
|
| 804 |
-
torch_format.append(
|
| 805 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 806 |
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
| 807 |
curMSE = output.loc[i, 'MSE']
|
|
|
|
| 800 |
=======
|
| 801 |
if output_torch_format:
|
| 802 |
from .export_torch import sympy2torch
|
| 803 |
+
module = sympy2torch(eqn, sympy_symbols)
|
| 804 |
+
torch_format.append(module)
|
| 805 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 806 |
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
| 807 |
curMSE = output.loc[i, 'MSE']
|