Commit 04f3f2f (parent 5db0d89): Fully document jax/torch export

docs/options.md CHANGED (+39 -2)
@@ -15,7 +15,8 @@ may find useful include:
 - `batching`, `batchSize`
 - `variable_names` (or pandas input)
 - Constraining operator complexity
-- LaTeX, SymPy
+- LaTeX, SymPy
+- Callable exports: numpy, pytorch, jax
 - `loss`
 
 These are described below

@@ -144,7 +145,7 @@ The other terms say that each multiplication can only have sub-expressions
 of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
 expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).
 
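(Editor's note: a minimal sketch of how the constraint described in this context might be written. The `constraints` dictionary format, with a 2-tuple bounding each side of a binary operator and a single integer bounding a unary operator's argument, is assumed from PySR's API of this era; `X` and `y` are as elsewhere in these docs.)

```python
# Hypothetical call matching the constraints described above: each side
# of a multiplication limited to complexity 3, and the argument of
# cosine limited to complexity 5.
equations = pysr(
    X, y,
    binary_operators=["mult", "plus"],
    unary_operators=["cos", "exp"],
    constraints={"mult": (3, 3), "cos": 5},
)
```
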
-## LaTeX, SymPy
+## LaTeX, SymPy
 
 The `pysr` command will return a pandas dataframe. The `sympy_format`
 column gives sympy equations, and the `lambda_format` gives callable

@@ -159,6 +160,42 @@ for the best equation, using the `score` column to sort equations.
 `best_latex()` returns the LaTeX form of this, and `best_callable()`
 returns a callable function.
 
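(Editor's note: a hedged sketch of these helpers. It assumes `pysr`, `best`, `best_latex`, and `best_callable` are importable from the `pysr` module and that the no-argument helpers act on the most recent run, as the surrounding text suggests; the target function is a toy example.)

```python
import numpy as np
from pysr import pysr, best, best_latex, best_callable

X = np.random.randn(100, 5)
y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2  # toy target

equations = pysr(X, y, niterations=5)

print(best())        # best equation, selected via the `score` column
print(best_latex())  # its LaTeX form
f = best_callable()  # callable (numpy) form of the same equation
print(f(X)[:5])      # evaluate it on the training inputs
```
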
+
+## Callable exports: numpy, pytorch, jax
+
+By default, the dataframe of equations will contain columns
+with the identifier `lambda_format`. These are simple functions
+which correspond to the equation, but executed
+with numpy functions. You can pass your `X` matrix to these functions
+just as you did to the `pysr` call. Thus, this allows
+you to numerically evaluate the equations on new inputs.
+
+
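(Editor's note: for instance, a sketch assuming `equations` is the dataframe returned by `pysr` above and that each callable takes the full feature matrix.)

```python
import numpy as np

# New inputs with the same number of features as the original X:
X_new = np.random.randn(50, 5)

row = equations.iloc[-1]    # e.g., the last (most complex) equation
f = row["lambda_format"]    # numpy-backed callable for that equation
y_new = f(X_new)            # numerically evaluate the equation
```
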
+One can do the same thing for PyTorch, which uses code
+from [sympytorch](https://github.com/patrick-kidger/sympytorch),
+and for JAX, which uses code from
+[sympy2jax](https://github.com/MilesCranmer/sympy2jax).
+
+For torch, set the argument `output_torch_format=True`, which
+will generate a column `torch_format`. Each element of this column
+is a PyTorch module which runs the equation, using PyTorch functions,
+over `X` (as a PyTorch tensor). This is differentiable, and the
+parameters of this PyTorch module correspond to the learned parameters
+in the equation; they are trainable.
+
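(Editor's note: a hedged sketch of using `torch_format`. The column name and module behaviour come from the text above; the training step is illustrative, with `pysr`, `X`, and `y` as before.)

```python
import torch

equations = pysr(X, y, output_torch_format=True)
module = equations.iloc[-1]["torch_format"]   # a PyTorch module

X_t = torch.tensor(X, dtype=torch.float32)
y_t = torch.tensor(y, dtype=torch.float32)

# The module's parameters are the equation's learned constants,
# so they can be fine-tuned with ordinary PyTorch training:
opt = torch.optim.Adam(module.parameters(), lr=1e-3)
loss = torch.nn.functional.mse_loss(module(X_t), y_t)
loss.backward()
opt.step()
```
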
+For jax, set the argument `output_jax_format=True`, which
+will generate a column `jax_format`. Each element of this column
+is a dictionary containing a `'callable'` (a JAX function)
+and `'parameters'` (a list of parameters in the equation).
+One can execute this function with `element['callable'](X, element['parameters'])`.
+Since the parameter list is a jax array, this also lets you
+train the parameters within JAX (and is differentiable).
+
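(Editor's note: a corresponding sketch for the JAX export, with the dictionary keys as described above; the loss and the single gradient step are illustrative.)

```python
import jax
import jax.numpy as jnp

equations = pysr(X, y, output_jax_format=True)
element = equations.iloc[-1]["jax_format"]
f, params = element["callable"], element["parameters"]

X_j, y_j = jnp.asarray(X), jnp.asarray(y)
y_pred = f(X_j, params)                 # evaluate the equation

# Because `params` is a jax array, the constants can be re-fit in JAX:
mse = lambda p: jnp.mean((f(X_j, p) - y_j) ** 2)
grads = jax.grad(mse)(params)           # gradients w.r.t. the constants
params = params - 1e-3 * grads          # one SGD step
```
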
+If you forget to turn these on when calling the function initially,
+you can re-run `get_hof(output_jax_format=True)`, and it will re-use
+the equations and other state properties, assuming you haven't
+re-run `pysr` in the meantime!
+
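(Editor's note: for example, a sketch assuming `get_hof` is importable from the `pysr` module, matching the call shown in the text.)

```python
from pysr import get_hof

# Recover the jax column from the saved state of the last `pysr` run:
equations = get_hof(output_jax_format=True)
```
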
 ## `loss`
 
 The default loss is mean-square error, and weighted mean-square error.
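(Editor's note: a sketch of the weighted case, assuming `pysr` accepts a per-sample `weights` argument of the same length as `y`, as in PySR's main API of this era.)

```python
import numpy as np

# Per-sample weights for weighted mean-square error:
weights = np.ones(len(y))
weights[:10] = 10.0   # emphasize the first ten samples
equations = pysr(X, y, weights=weights, niterations=5)
```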