Spaces:
Running
Running
Commit
·
d0788ef
1
Parent(s):
44216ab
Fix syntax error in JAX converter
Browse files- pysr/export.py +1 -1
- pysr/sr.py +2 -2
pysr/export.py
CHANGED
|
@@ -62,7 +62,7 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
| 62 |
parameters.append(float(expr))
|
| 63 |
return f"parameters[{len(parameters) - 1}]"
|
| 64 |
elif issubclass(expr.func, sympy.Integer):
|
| 65 |
-
return "{int(expr)}"
|
| 66 |
elif issubclass(expr.func, sympy.Symbol):
|
| 67 |
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 68 |
else:
|
|
|
|
| 62 |
parameters.append(float(expr))
|
| 63 |
return f"parameters[{len(parameters) - 1}]"
|
| 64 |
elif issubclass(expr.func, sympy.Integer):
|
| 65 |
+
return f"{int(expr)}"
|
| 66 |
elif issubclass(expr.func, sympy.Symbol):
|
| 67 |
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 68 |
else:
|
pysr/sr.py
CHANGED
|
@@ -686,7 +686,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 686 |
sympy_format.append(eqn)
|
| 687 |
if output_jax_format:
|
| 688 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 689 |
-
jax_format.append({'callable': func, 'parameters':
|
| 690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 691 |
curMSE = output.loc[i, 'MSE']
|
| 692 |
curComplexity = output.loc[i, 'Complexity']
|
|
@@ -705,7 +705,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 705 |
output['lambda_format'] = lambda_format
|
| 706 |
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
| 707 |
if output_jax_format:
|
| 708 |
-
output_cols += 'jax_format'
|
| 709 |
output['jax_format'] = jax_format
|
| 710 |
|
| 711 |
return output[output_cols]
|
|
|
|
| 686 |
sympy_format.append(eqn)
|
| 687 |
if output_jax_format:
|
| 688 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 689 |
+
jax_format.append({'callable': func, 'parameters': params})
|
| 690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 691 |
curMSE = output.loc[i, 'MSE']
|
| 692 |
curComplexity = output.loc[i, 'Complexity']
|
|
|
|
| 705 |
output['lambda_format'] = lambda_format
|
| 706 |
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
| 707 |
if output_jax_format:
|
| 708 |
+
output_cols += ['jax_format']
|
| 709 |
output['jax_format'] = jax_format
|
| 710 |
|
| 711 |
return output[output_cols]
|