Spaces:
Running
Running
Commit
·
e0c68fc
1
Parent(s):
a29e818
Propagate and check torch/jax mappings
Browse files- pysr/export_jax.py +14 -5
- pysr/sr.py +26 -2
pysr/export_jax.py
CHANGED
|
@@ -51,7 +51,7 @@ _jnp_func_lookup = {
|
|
| 51 |
}
|
| 52 |
|
| 53 |
|
| 54 |
-
def sympy2jaxtext(expr, parameters, symbols_in):
|
| 55 |
if issubclass(expr.func, sympy.Float):
|
| 56 |
parameters.append(float(expr))
|
| 57 |
return f"parameters[{len(parameters) - 1}]"
|
|
@@ -61,8 +61,15 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
| 61 |
return (
|
| 62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 63 |
)
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if _func == MUL:
|
| 67 |
return " * ".join(["(" + arg + ")" for arg in args])
|
| 68 |
if _func == ADD:
|
|
@@ -92,7 +99,7 @@ def _initialize_jax():
|
|
| 92 |
jsp = _jsp
|
| 93 |
|
| 94 |
|
| 95 |
-
def sympy2jax(expression, symbols_in, selection=None):
|
| 96 |
"""Returns a function f and its parameters;
|
| 97 |
the function takes an input matrix, and a list of arguments:
|
| 98 |
f(X, parameters)
|
|
@@ -170,7 +177,9 @@ def sympy2jax(expression, symbols_in, selection=None):
|
|
| 170 |
global jsp
|
| 171 |
|
| 172 |
parameters = []
|
| 173 |
-
functional_form_text = sympy2jaxtext(
|
|
|
|
|
|
|
| 174 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
| 175 |
text = f"def {hash_string}(X, parameters):\n"
|
| 176 |
if selection is not None:
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
|
| 54 |
+
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
| 55 |
if issubclass(expr.func, sympy.Float):
|
| 56 |
parameters.append(float(expr))
|
| 57 |
return f"parameters[{len(parameters) - 1}]"
|
|
|
|
| 61 |
return (
|
| 62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 63 |
)
|
| 64 |
+
if extra_jax_mappings is None:
|
| 65 |
+
extra_jax_mappings = {}
|
| 66 |
+
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
|
| 67 |
+
args = [
|
| 68 |
+
sympy2jaxtext(
|
| 69 |
+
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
| 70 |
+
)
|
| 71 |
+
for arg in expr.args
|
| 72 |
+
]
|
| 73 |
if _func == MUL:
|
| 74 |
return " * ".join(["(" + arg + ")" for arg in args])
|
| 75 |
if _func == ADD:
|
|
|
|
| 99 |
jsp = _jsp
|
| 100 |
|
| 101 |
|
| 102 |
+
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
|
| 103 |
"""Returns a function f and its parameters;
|
| 104 |
the function takes an input matrix, and a list of arguments:
|
| 105 |
f(X, parameters)
|
|
|
|
| 177 |
global jsp
|
| 178 |
|
| 179 |
parameters = []
|
| 180 |
+
functional_form_text = sympy2jaxtext(
|
| 181 |
+
expression, parameters, symbols_in, extra_jax_mappings
|
| 182 |
+
)
|
| 183 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
| 184 |
text = f"def {hash_string}(X, parameters):\n"
|
| 185 |
if selection is not None:
|
pysr/sr.py
CHANGED
|
@@ -289,6 +289,20 @@ def pysr(
|
|
| 289 |
if len(variable_names) == 0:
|
| 290 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
use_custom_variable_names = len(variable_names) != 0
|
| 293 |
|
| 294 |
_check_assertions(
|
|
@@ -996,14 +1010,24 @@ def get_hof(
|
|
| 996 |
if output_jax_format:
|
| 997 |
from .export_jax import sympy2jax
|
| 998 |
|
| 999 |
-
func, params = sympy2jax(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1000 |
jax_format.append({"callable": func, "parameters": params})
|
| 1001 |
|
| 1002 |
# Torch:
|
| 1003 |
if output_torch_format:
|
| 1004 |
from .export_torch import sympy2torch
|
| 1005 |
|
| 1006 |
-
module = sympy2torch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1007 |
torch_format.append(module)
|
| 1008 |
|
| 1009 |
curMSE = output.loc[i, "MSE"]
|
|
|
|
| 289 |
if len(variable_names) == 0:
|
| 290 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
| 291 |
|
| 292 |
+
if extra_jax_mappings is not None:
|
| 293 |
+
for key, value in extra_jax_mappings:
|
| 294 |
+
if not isinstance(value, str):
|
| 295 |
+
raise NotImplementedError(
|
| 296 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
if extra_torch_mappings is not None:
|
| 300 |
+
for key, value in extra_jax_mappings:
|
| 301 |
+
if not callable(value):
|
| 302 |
+
raise NotImplementedError(
|
| 303 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
use_custom_variable_names = len(variable_names) != 0
|
| 307 |
|
| 308 |
_check_assertions(
|
|
|
|
| 1010 |
if output_jax_format:
|
| 1011 |
from .export_jax import sympy2jax
|
| 1012 |
|
| 1013 |
+
func, params = sympy2jax(
|
| 1014 |
+
eqn,
|
| 1015 |
+
sympy_symbols,
|
| 1016 |
+
selection=selection,
|
| 1017 |
+
extra_jax_mappings=extra_jax_mappings,
|
| 1018 |
+
)
|
| 1019 |
jax_format.append({"callable": func, "parameters": params})
|
| 1020 |
|
| 1021 |
# Torch:
|
| 1022 |
if output_torch_format:
|
| 1023 |
from .export_torch import sympy2torch
|
| 1024 |
|
| 1025 |
+
module = sympy2torch(
|
| 1026 |
+
eqn,
|
| 1027 |
+
sympy_symbols,
|
| 1028 |
+
selection=selection,
|
| 1029 |
+
extra_torch_mappings=extra_torch_mappings,
|
| 1030 |
+
)
|
| 1031 |
torch_format.append(module)
|
| 1032 |
|
| 1033 |
curMSE = output.loc[i, "MSE"]
|