Spaces:
Running
Running
Commit
·
a15823e
1
Parent(s):
c41cf33
Reduce precision of JAX tests
Browse files- test/test_jax.py +3 -3
test/test_jax.py
CHANGED
|
@@ -49,7 +49,7 @@ class TestJAX(unittest.TestCase):
|
|
| 49 |
np.testing.assert_almost_equal(
|
| 50 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
| 51 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
| 52 |
-
decimal=
|
| 53 |
)
|
| 54 |
|
| 55 |
def test_pipeline(self):
|
|
@@ -110,5 +110,5 @@ class TestJAX(unittest.TestCase):
|
|
| 110 |
np_output = np_prediction(X.values)
|
| 111 |
jax_output = jax_prediction(X.values)
|
| 112 |
|
| 113 |
-
np.testing.assert_almost_equal(y.values, np_output, decimal=
|
| 114 |
-
np.testing.assert_almost_equal(y.values, jax_output, decimal=
|
|
|
|
| 49 |
np.testing.assert_almost_equal(
|
| 50 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
| 51 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
| 52 |
+
decimal=3,
|
| 53 |
)
|
| 54 |
|
| 55 |
def test_pipeline(self):
|
|
|
|
| 110 |
np_output = np_prediction(X.values)
|
| 111 |
jax_output = jax_prediction(X.values)
|
| 112 |
|
| 113 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=3)
|
| 114 |
+
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|