Spaces:
Running
Running
Commit
·
b4cb407
1
Parent(s):
ce5b119
Fix feature selection for JAX export
Browse files- pysr/export_jax.py +4 -1
- pysr/sr.py +1 -0
pysr/export_jax.py
CHANGED
|
@@ -109,7 +109,7 @@ def _initialize_jax():
|
|
| 109 |
jsp = _jsp
|
| 110 |
|
| 111 |
|
| 112 |
-
def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
|
| 113 |
"""Returns a function f and its parameters;
|
| 114 |
the function takes an input matrix, and a list of arguments:
|
| 115 |
f(X, parameters)
|
|
@@ -192,6 +192,9 @@ def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
|
|
| 192 |
)
|
| 193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
| 194 |
text = f"def {hash_string}(X, parameters):\n"
|
|
|
|
|
|
|
|
|
|
| 195 |
text += " return "
|
| 196 |
text += functional_form_text
|
| 197 |
ldict = {}
|
|
|
|
| 109 |
jsp = _jsp
|
| 110 |
|
| 111 |
|
| 112 |
+
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
|
| 113 |
"""Returns a function f and its parameters;
|
| 114 |
the function takes an input matrix, and a list of arguments:
|
| 115 |
f(X, parameters)
|
|
|
|
| 192 |
)
|
| 193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
| 194 |
text = f"def {hash_string}(X, parameters):\n"
|
| 195 |
+
if selection is not None:
|
| 196 |
+
# Impose the feature selection:
|
| 197 |
+
text += f" X = X[:, {list(selection)}]\n"
|
| 198 |
text += " return "
|
| 199 |
text += functional_form_text
|
| 200 |
ldict = {}
|
pysr/sr.py
CHANGED
|
@@ -1740,6 +1740,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1740 |
func, params = sympy2jax(
|
| 1741 |
eqn,
|
| 1742 |
sympy_symbols,
|
|
|
|
| 1743 |
extra_jax_mappings=self.extra_jax_mappings,
|
| 1744 |
)
|
| 1745 |
jax_format.append({"callable": func, "parameters": params})
|
|
|
|
| 1740 |
func, params = sympy2jax(
|
| 1741 |
eqn,
|
| 1742 |
sympy_symbols,
|
| 1743 |
+
selection=self.selection_mask_,
|
| 1744 |
extra_jax_mappings=self.extra_jax_mappings,
|
| 1745 |
)
|
| 1746 |
jax_format.append({"callable": func, "parameters": params})
|