Spaces:
Running
Running
Commit
·
8c55475
1
Parent(s):
45b290b
Allow custom selection of X matrix in torch/jax modules
Browse files- pysr/export_jax.py +4 -1
- pysr/export_torch.py +9 -3
pysr/export_jax.py
CHANGED
|
@@ -90,7 +90,7 @@ def _initialize_jax():
|
|
| 90 |
jsp = _jsp
|
| 91 |
|
| 92 |
|
| 93 |
-
def sympy2jax(expression, symbols_in):
|
| 94 |
"""Returns a function f and its parameters;
|
| 95 |
the function takes an input matrix, and a list of arguments:
|
| 96 |
f(X, parameters)
|
|
@@ -171,6 +171,9 @@ def sympy2jax(expression, symbols_in):
|
|
| 171 |
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
|
| 172 |
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
|
| 173 |
text = f"def {hash_string}(X, parameters):\n"
|
|
|
|
|
|
|
|
|
|
| 174 |
text += " return "
|
| 175 |
text += functional_form_text
|
| 176 |
ldict = {}
|
|
|
|
| 90 |
jsp = _jsp
|
| 91 |
|
| 92 |
|
| 93 |
+
def sympy2jax(expression, symbols_in, selection=None):
|
| 94 |
"""Returns a function f and its parameters;
|
| 95 |
the function takes an input matrix, and a list of arguments:
|
| 96 |
f(X, parameters)
|
|
|
|
| 171 |
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
|
| 172 |
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
|
| 173 |
text = f"def {hash_string}(X, parameters):\n"
|
| 174 |
+
if selection is not None:
|
| 175 |
+
# Impose the feature selection:
|
| 176 |
+
text += f" X = X[:, {list(selection)}]"
|
| 177 |
text += " return "
|
| 178 |
text += functional_form_text
|
| 179 |
ldict = {}
|
pysr/export_torch.py
CHANGED
|
@@ -137,7 +137,7 @@ def _initialize_torch():
|
|
| 137 |
class SingleSymPyModule(torch.nn.Module):
|
| 138 |
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
| 139 |
def __init__(self, expression, symbols_in,
|
| 140 |
-
|
| 141 |
super().__init__(**kwargs)
|
| 142 |
|
| 143 |
if extra_funcs is None:
|
|
@@ -147,18 +147,22 @@ def _initialize_torch():
|
|
| 147 |
_memodict = {}
|
| 148 |
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
|
| 149 |
self._expression_string = str(expression)
|
|
|
|
| 150 |
self.symbols_in = [str(symbol) for symbol in symbols_in]
|
| 151 |
|
| 152 |
def __repr__(self):
|
| 153 |
return f"{type(self).__name__}(expression={self._expression_string})"
|
| 154 |
|
| 155 |
def forward(self, X):
|
|
|
|
|
|
|
| 156 |
symbols = {symbol: X[:, i]
|
| 157 |
for i, symbol in enumerate(self.symbols_in)}
|
| 158 |
return self._node(symbols)
|
| 159 |
|
| 160 |
|
| 161 |
-
def sympy2torch(expression, symbols_in,
|
|
|
|
| 162 |
"""Returns a module for a given sympy expression with trainable parameters;
|
| 163 |
|
| 164 |
This function will assume the input to the module is a matrix X, where
|
|
@@ -168,4 +172,6 @@ def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
|
|
| 168 |
|
| 169 |
_initialize_torch()
|
| 170 |
|
| 171 |
-
return SingleSymPyModule(expression, symbols_in,
|
|
|
|
|
|
|
|
|
| 137 |
class SingleSymPyModule(torch.nn.Module):
|
| 138 |
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
| 139 |
def __init__(self, expression, symbols_in,
|
| 140 |
+
selection=None, extra_funcs=None, **kwargs):
|
| 141 |
super().__init__(**kwargs)
|
| 142 |
|
| 143 |
if extra_funcs is None:
|
|
|
|
| 147 |
_memodict = {}
|
| 148 |
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
|
| 149 |
self._expression_string = str(expression)
|
| 150 |
+
self._selection = selection
|
| 151 |
self.symbols_in = [str(symbol) for symbol in symbols_in]
|
| 152 |
|
| 153 |
def __repr__(self):
|
| 154 |
return f"{type(self).__name__}(expression={self._expression_string})"
|
| 155 |
|
| 156 |
def forward(self, X):
|
| 157 |
+
if self._selection is not None:
|
| 158 |
+
X = X[:, self._selection]
|
| 159 |
symbols = {symbol: X[:, i]
|
| 160 |
for i, symbol in enumerate(self.symbols_in)}
|
| 161 |
return self._node(symbols)
|
| 162 |
|
| 163 |
|
| 164 |
+
def sympy2torch(expression, symbols_in,
|
| 165 |
+
selection=None, extra_torch_mappings=None):
|
| 166 |
"""Returns a module for a given sympy expression with trainable parameters;
|
| 167 |
|
| 168 |
This function will assume the input to the module is a matrix X, where
|
|
|
|
| 172 |
|
| 173 |
_initialize_torch()
|
| 174 |
|
| 175 |
+
return SingleSymPyModule(expression, symbols_in,
|
| 176 |
+
selection=selection,
|
| 177 |
+
extra_funcs=extra_torch_mappings)
|