Spaces:
Running
Running
Try to fix nb sanitizer
Browse files- pysr/_cli/main.py +15 -6
- pysr/test/test.py +6 -5
- pysr/test/test_cli.py +6 -2
- pysr/test/test_dev.py +6 -2
- pysr/test/test_jax.py +6 -2
- pysr/test/test_startup.py +6 -2
- pysr/test/test_torch.py +6 -2
pysr/_cli/main.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import warnings
|
| 2 |
|
| 3 |
import click
|
|
@@ -55,19 +56,27 @@ def _tests(tests):
|
|
| 55 |
|
| 56 |
Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
|
| 57 |
"""
|
|
|
|
| 58 |
for test in tests.split(","):
|
| 59 |
if test == "main":
|
| 60 |
-
runtests()
|
| 61 |
elif test == "jax":
|
| 62 |
-
runtests_jax()
|
| 63 |
elif test == "torch":
|
| 64 |
-
runtests_torch()
|
| 65 |
elif test == "cli":
|
| 66 |
runtests_cli = get_runtests_cli()
|
| 67 |
-
runtests_cli()
|
| 68 |
elif test == "dev":
|
| 69 |
-
runtests_dev()
|
| 70 |
elif test == "startup":
|
| 71 |
-
runtests_startup()
|
| 72 |
else:
|
| 73 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
import warnings
|
| 3 |
|
| 4 |
import click
|
|
|
|
| 56 |
|
| 57 |
Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
|
| 58 |
"""
|
| 59 |
+
test_cases = []
|
| 60 |
for test in tests.split(","):
|
| 61 |
if test == "main":
|
| 62 |
+
test_cases.extend(runtests(just_tests=True))
|
| 63 |
elif test == "jax":
|
| 64 |
+
test_cases.extend(runtests_jax(just_tests=True))
|
| 65 |
elif test == "torch":
|
| 66 |
+
test_cases.extend(runtests_torch(just_tests=True))
|
| 67 |
elif test == "cli":
|
| 68 |
runtests_cli = get_runtests_cli()
|
| 69 |
+
test_cases.extend(runtests_cli(just_tests=True))
|
| 70 |
elif test == "dev":
|
| 71 |
+
test_cases.extend(runtests_dev(just_tests=True))
|
| 72 |
elif test == "startup":
|
| 73 |
+
test_cases.extend(runtests_startup(just_tests=True))
|
| 74 |
else:
|
| 75 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
| 76 |
+
|
| 77 |
+
loader = unittest.TestLoader()
|
| 78 |
+
suite = unittest.TestSuite()
|
| 79 |
+
for test_case in test_cases:
|
| 80 |
+
suite.addTests(loader.loadTestsFromTestCase(test_case))
|
| 81 |
+
runner = unittest.TextTestRunner()
|
| 82 |
+
return runner.run(suite)
|
pysr/test/test.py
CHANGED
|
@@ -1127,10 +1127,8 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
| 1127 |
# TODO: Determine desired behavior if second .fit() call does not have units
|
| 1128 |
|
| 1129 |
|
| 1130 |
-
def runtests():
|
| 1131 |
"""Run all tests in test.py."""
|
| 1132 |
-
suite = unittest.TestSuite()
|
| 1133 |
-
loader = unittest.TestLoader()
|
| 1134 |
test_cases = [
|
| 1135 |
TestPipeline,
|
| 1136 |
TestBest,
|
|
@@ -1139,8 +1137,11 @@ def runtests():
|
|
| 1139 |
TestLaTeXTable,
|
| 1140 |
TestDimensionalConstraints,
|
| 1141 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1142 |
for test_case in test_cases:
|
| 1143 |
-
|
| 1144 |
-
suite.addTests(tests)
|
| 1145 |
runner = unittest.TextTestRunner()
|
| 1146 |
return runner.run(suite)
|
|
|
|
| 1127 |
# TODO: Determine desired behavior if second .fit() call does not have units
|
| 1128 |
|
| 1129 |
|
| 1130 |
+
def runtests(just_tests=False):
|
| 1131 |
"""Run all tests in test.py."""
|
|
|
|
|
|
|
| 1132 |
test_cases = [
|
| 1133 |
TestPipeline,
|
| 1134 |
TestBest,
|
|
|
|
| 1137 |
TestLaTeXTable,
|
| 1138 |
TestDimensionalConstraints,
|
| 1139 |
]
|
| 1140 |
+
if just_tests:
|
| 1141 |
+
return test_cases
|
| 1142 |
+
suite = unittest.TestSuite()
|
| 1143 |
+
loader = unittest.TestLoader()
|
| 1144 |
for test_case in test_cases:
|
| 1145 |
+
suite.addTests(loader.loadTestsFromTestCase(test_case))
|
|
|
|
| 1146 |
runner = unittest.TextTestRunner()
|
| 1147 |
return runner.run(suite)
|
pysr/test/test_cli.py
CHANGED
|
@@ -68,11 +68,15 @@ def get_runtests():
|
|
| 68 |
self.assertEqual(result.output.strip(), expected.strip())
|
| 69 |
self.assertEqual(result.exit_code, 0)
|
| 70 |
|
| 71 |
-
def runtests():
|
| 72 |
"""Run all tests in cliTest.py."""
|
|
|
|
|
|
|
|
|
|
| 73 |
loader = unittest.TestLoader()
|
| 74 |
suite = unittest.TestSuite()
|
| 75 |
-
|
|
|
|
| 76 |
runner = unittest.TextTestRunner()
|
| 77 |
return runner.run(suite)
|
| 78 |
|
|
|
|
| 68 |
self.assertEqual(result.output.strip(), expected.strip())
|
| 69 |
self.assertEqual(result.exit_code, 0)
|
| 70 |
|
| 71 |
+
def runtests(just_tests=False):
|
| 72 |
"""Run all tests in cliTest.py."""
|
| 73 |
+
tests = [TestCli]
|
| 74 |
+
if just_tests:
|
| 75 |
+
return tests
|
| 76 |
loader = unittest.TestLoader()
|
| 77 |
suite = unittest.TestSuite()
|
| 78 |
+
for test in tests:
|
| 79 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
| 80 |
runner = unittest.TextTestRunner()
|
| 81 |
return runner.run(suite)
|
| 82 |
|
pysr/test/test_dev.py
CHANGED
|
@@ -47,9 +47,13 @@ class TestDev(unittest.TestCase):
|
|
| 47 |
self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
|
| 48 |
|
| 49 |
|
| 50 |
-
def runtests():
|
|
|
|
|
|
|
|
|
|
| 51 |
suite = unittest.TestSuite()
|
| 52 |
loader = unittest.TestLoader()
|
| 53 |
-
|
|
|
|
| 54 |
runner = unittest.TextTestRunner()
|
| 55 |
return runner.run(suite)
|
|
|
|
| 47 |
self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
|
| 48 |
|
| 49 |
|
| 50 |
+
def runtests(just_tests=False):
|
| 51 |
+
tests = [TestDev]
|
| 52 |
+
if just_tests:
|
| 53 |
+
return tests
|
| 54 |
suite = unittest.TestSuite()
|
| 55 |
loader = unittest.TestLoader()
|
| 56 |
+
for test in tests:
|
| 57 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
| 58 |
runner = unittest.TextTestRunner()
|
| 59 |
return runner.run(suite)
|
pysr/test/test_jax.py
CHANGED
|
@@ -121,10 +121,14 @@ class TestJAX(unittest.TestCase):
|
|
| 121 |
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|
| 122 |
|
| 123 |
|
| 124 |
-
def runtests():
|
| 125 |
"""Run all tests in test_jax.py."""
|
|
|
|
|
|
|
|
|
|
| 126 |
loader = unittest.TestLoader()
|
| 127 |
suite = unittest.TestSuite()
|
| 128 |
-
|
|
|
|
| 129 |
runner = unittest.TextTestRunner()
|
| 130 |
return runner.run(suite)
|
|
|
|
| 121 |
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|
| 122 |
|
| 123 |
|
| 124 |
+
def runtests(just_tests=False):
|
| 125 |
"""Run all tests in test_jax.py."""
|
| 126 |
+
tests = [TestJAX]
|
| 127 |
+
if just_tests:
|
| 128 |
+
return tests
|
| 129 |
loader = unittest.TestLoader()
|
| 130 |
suite = unittest.TestSuite()
|
| 131 |
+
for test in tests:
|
| 132 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
| 133 |
runner = unittest.TextTestRunner()
|
| 134 |
return runner.run(suite)
|
pysr/test/test_startup.py
CHANGED
|
@@ -143,9 +143,13 @@ class TestStartup(unittest.TestCase):
|
|
| 143 |
self.assertEqual(result.returncode, 0)
|
| 144 |
|
| 145 |
|
| 146 |
-
def runtests():
|
|
|
|
|
|
|
|
|
|
| 147 |
suite = unittest.TestSuite()
|
| 148 |
loader = unittest.TestLoader()
|
| 149 |
-
|
|
|
|
| 150 |
runner = unittest.TextTestRunner()
|
| 151 |
return runner.run(suite)
|
|
|
|
| 143 |
self.assertEqual(result.returncode, 0)
|
| 144 |
|
| 145 |
|
| 146 |
+
def runtests(just_tests=False):
|
| 147 |
+
tests = [TestStartup]
|
| 148 |
+
if just_tests:
|
| 149 |
+
return tests
|
| 150 |
suite = unittest.TestSuite()
|
| 151 |
loader = unittest.TestLoader()
|
| 152 |
+
for test in tests:
|
| 153 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
| 154 |
runner = unittest.TextTestRunner()
|
| 155 |
return runner.run(suite)
|
pysr/test/test_torch.py
CHANGED
|
@@ -184,10 +184,14 @@ class TestTorch(unittest.TestCase):
|
|
| 184 |
np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
|
| 185 |
|
| 186 |
|
| 187 |
-
def runtests():
|
| 188 |
"""Run all tests in test_torch.py."""
|
|
|
|
|
|
|
|
|
|
| 189 |
loader = unittest.TestLoader()
|
| 190 |
suite = unittest.TestSuite()
|
| 191 |
-
|
|
|
|
| 192 |
runner = unittest.TextTestRunner()
|
| 193 |
return runner.run(suite)
|
|
|
|
| 184 |
np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
|
| 185 |
|
| 186 |
|
| 187 |
+
def runtests(just_tests=False):
|
| 188 |
"""Run all tests in test_torch.py."""
|
| 189 |
+
tests = [TestTorch]
|
| 190 |
+
if just_tests:
|
| 191 |
+
return tests
|
| 192 |
loader = unittest.TestLoader()
|
| 193 |
suite = unittest.TestSuite()
|
| 194 |
+
for test in tests:
|
| 195 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
| 196 |
runner = unittest.TextTestRunner()
|
| 197 |
return runner.run(suite)
|