Spaces:
Sleeping
Sleeping
Commit
·
62d539c
1
Parent(s):
db11d11
Clean up anti-patterns
Browse files- pysr/sr.py +6 -9
pysr/sr.py
CHANGED
|
@@ -289,14 +289,14 @@ def pysr(
|
|
| 289 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
| 290 |
|
| 291 |
if extra_jax_mappings is not None:
|
| 292 |
-
for
|
| 293 |
if not isinstance(value, str):
|
| 294 |
raise NotImplementedError(
|
| 295 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 296 |
)
|
| 297 |
|
| 298 |
if extra_torch_mappings is not None:
|
| 299 |
-
for
|
| 300 |
if not callable(value):
|
| 301 |
raise NotImplementedError(
|
| 302 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
|
@@ -797,8 +797,7 @@ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs
|
|
| 797 |
def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
| 798 |
def_hyperparams = ""
|
| 799 |
for op_list in [binary_operators, unary_operators]:
|
| 800 |
-
for i in
|
| 801 |
-
op = op_list[i]
|
| 802 |
is_user_defined_operator = "(" in op
|
| 803 |
|
| 804 |
if is_user_defined_operator:
|
|
@@ -806,8 +805,8 @@ def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
|
| 806 |
# Cut off from the first non-alphanumeric char:
|
| 807 |
first_non_char = [
|
| 808 |
j
|
| 809 |
-
for j in
|
| 810 |
-
if not (
|
| 811 |
][0]
|
| 812 |
function_name = op[:first_non_char]
|
| 813 |
op_list[i] = function_name
|
|
@@ -823,9 +822,7 @@ def _handle_feature_selection(
|
|
| 823 |
X = X[:, selection]
|
| 824 |
|
| 825 |
if use_custom_variable_names:
|
| 826 |
-
variable_names = [
|
| 827 |
-
variable_names[selection[i]] for i in range(len(selection))
|
| 828 |
-
]
|
| 829 |
else:
|
| 830 |
selection = None
|
| 831 |
return X, variable_names, selection
|
|
|
|
| 289 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
| 290 |
|
| 291 |
if extra_jax_mappings is not None:
|
| 292 |
+
for value in extra_jax_mappings.values():
|
| 293 |
if not isinstance(value, str):
|
| 294 |
raise NotImplementedError(
|
| 295 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 296 |
)
|
| 297 |
|
| 298 |
if extra_torch_mappings is not None:
|
| 299 |
+
for value in extra_jax_mappings.values():
|
| 300 |
if not callable(value):
|
| 301 |
raise NotImplementedError(
|
| 302 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
|
|
|
| 797 |
def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
| 798 |
def_hyperparams = ""
|
| 799 |
for op_list in [binary_operators, unary_operators]:
|
| 800 |
+
for i, op in enumerate(op_list):
|
|
|
|
| 801 |
is_user_defined_operator = "(" in op
|
| 802 |
|
| 803 |
if is_user_defined_operator:
|
|
|
|
| 805 |
# Cut off from the first non-alphanumeric char:
|
| 806 |
first_non_char = [
|
| 807 |
j
|
| 808 |
+
for j, char in enumerate(op)
|
| 809 |
+
if not (char.isalpha() or char.isdigit())
|
| 810 |
][0]
|
| 811 |
function_name = op[:first_non_char]
|
| 812 |
op_list[i] = function_name
|
|
|
|
| 822 |
X = X[:, selection]
|
| 823 |
|
| 824 |
if use_custom_variable_names:
|
| 825 |
+
variable_names = [variable_names[i] for i in selection]
|
|
|
|
|
|
|
| 826 |
else:
|
| 827 |
selection = None
|
| 828 |
return X, variable_names, selection
|