Spaces:
Sleeping
Sleeping
Commit
·
e7ede78
1
Parent(s):
a88a169
Allow user to pass extra torch operators
Browse files- pysr/export_torch.py +2 -2
pysr/export_torch.py
CHANGED
|
@@ -160,7 +160,7 @@ def _initialize_torch():
|
|
| 160 |
return self._node(symbols)
|
| 161 |
|
| 162 |
|
| 163 |
-
def sympy2torch(expression, symbols_in):
|
| 164 |
"""Returns a module for a given sympy expression with trainable parameters;
|
| 165 |
|
| 166 |
This function will assume the input to the module is a matrix X, where
|
|
@@ -170,4 +170,4 @@ def sympy2torch(expression, symbols_in):
|
|
| 170 |
|
| 171 |
_initialize_torch()
|
| 172 |
|
| 173 |
-
return SingleSymPyModule(expression, symbols_in)
|
|
|
|
| 160 |
return self._node(symbols)
|
| 161 |
|
| 162 |
|
| 163 |
+
def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
|
| 164 |
"""Returns a module for a given sympy expression with trainable parameters;
|
| 165 |
|
| 166 |
This function will assume the input to the module is a matrix X, where
|
|
|
|
| 170 |
|
| 171 |
_initialize_torch()
|
| 172 |
|
| 173 |
+
return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)
|