Spaces:
Sleeping
Sleeping
Commit
·
a0c6429
1
Parent(s):
90d24f5
Fix JAX test
Browse files- test/test_jax.py +6 -7
test/test_jax.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import unittest
|
| 2 |
import numpy as np
|
| 3 |
-
from pysr import sympy2jax,
|
| 4 |
import pandas as pd
|
| 5 |
from jax import numpy as jnp
|
| 6 |
from jax import random
|
|
@@ -35,18 +35,17 @@ class TestJAX(unittest.TestCase):
|
|
| 35 |
"equation_file.csv.bkup", sep="|"
|
| 36 |
)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
"equation_file.csv",
|
| 40 |
-
n_features=2,
|
| 41 |
-
variables_names="x1 x2 x3".split(" "),
|
| 42 |
-
extra_sympy_mappings={},
|
| 43 |
output_jax_format=True,
|
|
|
|
| 44 |
multioutput=False,
|
| 45 |
nout=1,
|
| 46 |
selection=[1, 2, 3],
|
| 47 |
)
|
| 48 |
|
| 49 |
-
model =
|
|
|
|
| 50 |
jformat = model.jax()
|
| 51 |
|
| 52 |
np.testing.assert_almost_equal(
|
|
|
|
| 1 |
import unittest
|
| 2 |
import numpy as np
|
| 3 |
+
from pysr import sympy2jax, PySRRegressor
|
| 4 |
import pandas as pd
|
| 5 |
from jax import numpy as jnp
|
| 6 |
from jax import random
|
|
|
|
| 35 |
"equation_file.csv.bkup", sep="|"
|
| 36 |
)
|
| 37 |
|
| 38 |
+
model = PySRRegressor(
|
| 39 |
+
equation_file="equation_file.csv",
|
|
|
|
|
|
|
|
|
|
| 40 |
output_jax_format=True,
|
| 41 |
+
variables_names="x1 x2 x3".split(" "),
|
| 42 |
multioutput=False,
|
| 43 |
nout=1,
|
| 44 |
selection=[1, 2, 3],
|
| 45 |
)
|
| 46 |
|
| 47 |
+
model.n_features = 2
|
| 48 |
+
model.refresh()
|
| 49 |
jformat = model.jax()
|
| 50 |
|
| 51 |
np.testing.assert_almost_equal(
|