Spaces:
Running
Running
deepsource-autofix[bot]
commited on
Refactor unnecessary `else` / `elif` when `if` block has a `return` statement
Browse files- pysr/export_jax.py +9 -11
- pysr/sr.py +8 -15
pysr/export_jax.py
CHANGED
|
@@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
| 55 |
if issubclass(expr.func, sympy.Float):
|
| 56 |
parameters.append(float(expr))
|
| 57 |
return f"parameters[{len(parameters) - 1}]"
|
| 58 |
-
|
| 59 |
return f"{int(expr)}"
|
| 60 |
-
|
| 61 |
return (
|
| 62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 63 |
)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
else:
|
| 72 |
-
return f'{_func}({", ".join(args)})'
|
| 73 |
|
| 74 |
|
| 75 |
jax_initialized = False
|
|
|
|
| 55 |
if issubclass(expr.func, sympy.Float):
|
| 56 |
parameters.append(float(expr))
|
| 57 |
return f"parameters[{len(parameters) - 1}]"
|
| 58 |
+
if issubclass(expr.func, sympy.Integer):
|
| 59 |
return f"{int(expr)}"
|
| 60 |
+
if issubclass(expr.func, sympy.Symbol):
|
| 61 |
return (
|
| 62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 63 |
)
|
| 64 |
+
_func = _jnp_func_lookup[expr.func]
|
| 65 |
+
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
|
| 66 |
+
if _func == MUL:
|
| 67 |
+
return " * ".join(["(" + arg + ")" for arg in args])
|
| 68 |
+
if _func == ADD:
|
| 69 |
+
return " + ".join(["(" + arg + ")" for arg in args])
|
| 70 |
+
return f'{_func}({", ".join(args)})'
|
|
|
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
jax_initialized = False
|
pysr/sr.py
CHANGED
|
@@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
|
|
| 643 |
def tuple_fix(ops):
|
| 644 |
if len(ops) > 1:
|
| 645 |
return ", ".join(ops)
|
| 646 |
-
|
| 647 |
return ""
|
| 648 |
-
|
| 649 |
-
return ops[0] + ","
|
| 650 |
|
| 651 |
def_hyperparams += f"""\n
|
| 652 |
plus=(+)
|
|
@@ -1025,8 +1024,7 @@ def get_hof(
|
|
| 1025 |
|
| 1026 |
if multioutput:
|
| 1027 |
return ret_outputs
|
| 1028 |
-
|
| 1029 |
-
return ret_outputs[0]
|
| 1030 |
|
| 1031 |
|
| 1032 |
def best_row(equations=None):
|
|
@@ -1037,8 +1035,7 @@ def best_row(equations=None):
|
|
| 1037 |
equations = get_hof()
|
| 1038 |
if isinstance(equations, list):
|
| 1039 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
| 1040 |
-
|
| 1041 |
-
return equations.iloc[np.argmax(equations["score"])]
|
| 1042 |
|
| 1043 |
|
| 1044 |
def best_tex(equations=None):
|
|
@@ -1051,8 +1048,7 @@ def best_tex(equations=None):
|
|
| 1051 |
return [
|
| 1052 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
| 1053 |
]
|
| 1054 |
-
|
| 1055 |
-
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
| 1056 |
|
| 1057 |
|
| 1058 |
def best(equations=None):
|
|
@@ -1063,8 +1059,7 @@ def best(equations=None):
|
|
| 1063 |
equations = get_hof()
|
| 1064 |
if isinstance(equations, list):
|
| 1065 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
| 1066 |
-
|
| 1067 |
-
return best_row(equations)["sympy_format"].simplify()
|
| 1068 |
|
| 1069 |
|
| 1070 |
def best_callable(equations=None):
|
|
@@ -1075,8 +1070,7 @@ def best_callable(equations=None):
|
|
| 1075 |
equations = get_hof()
|
| 1076 |
if isinstance(equations, list):
|
| 1077 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
| 1078 |
-
|
| 1079 |
-
return best_row(equations)["lambda_format"]
|
| 1080 |
|
| 1081 |
|
| 1082 |
def _escape_filename(filename):
|
|
@@ -1114,5 +1108,4 @@ class CallableEquation(object):
|
|
| 1114 |
def __call__(self, X):
|
| 1115 |
if self._selection is not None:
|
| 1116 |
return self._lambda(*X[:, self._selection].T)
|
| 1117 |
-
|
| 1118 |
-
return self._lambda(*X.T)
|
|
|
|
| 643 |
def tuple_fix(ops):
|
| 644 |
if len(ops) > 1:
|
| 645 |
return ", ".join(ops)
|
| 646 |
+
if len(ops) == 0:
|
| 647 |
return ""
|
| 648 |
+
return ops[0] + ","
|
|
|
|
| 649 |
|
| 650 |
def_hyperparams += f"""\n
|
| 651 |
plus=(+)
|
|
|
|
| 1024 |
|
| 1025 |
if multioutput:
|
| 1026 |
return ret_outputs
|
| 1027 |
+
return ret_outputs[0]
|
|
|
|
| 1028 |
|
| 1029 |
|
| 1030 |
def best_row(equations=None):
|
|
|
|
| 1035 |
equations = get_hof()
|
| 1036 |
if isinstance(equations, list):
|
| 1037 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
| 1038 |
+
return equations.iloc[np.argmax(equations["score"])]
|
|
|
|
| 1039 |
|
| 1040 |
|
| 1041 |
def best_tex(equations=None):
|
|
|
|
| 1048 |
return [
|
| 1049 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
| 1050 |
]
|
| 1051 |
+
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
|
|
|
| 1052 |
|
| 1053 |
|
| 1054 |
def best(equations=None):
|
|
|
|
| 1059 |
equations = get_hof()
|
| 1060 |
if isinstance(equations, list):
|
| 1061 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
| 1062 |
+
return best_row(equations)["sympy_format"].simplify()
|
|
|
|
| 1063 |
|
| 1064 |
|
| 1065 |
def best_callable(equations=None):
|
|
|
|
| 1070 |
equations = get_hof()
|
| 1071 |
if isinstance(equations, list):
|
| 1072 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
| 1073 |
+
return best_row(equations)["lambda_format"]
|
|
|
|
| 1074 |
|
| 1075 |
|
| 1076 |
def _escape_filename(filename):
|
|
|
|
| 1108 |
def __call__(self, X):
|
| 1109 |
if self._selection is not None:
|
| 1110 |
return self._lambda(*X[:, self._selection].T)
|
| 1111 |
+
return self._lambda(*X.T)
|
|
|