Spaces:
Sleeping
Sleeping
Add unittests for units checks
Browse files- pysr/test/test.py +39 -3
pysr/test/test.py
CHANGED
|
@@ -19,6 +19,7 @@ from ..sr import (
|
|
| 19 |
_handle_feature_selection,
|
| 20 |
_csv_filename_to_pkl_filename,
|
| 21 |
idx_model_selection,
|
|
|
|
| 22 |
)
|
| 23 |
from ..export_latex import to_latex
|
| 24 |
|
|
@@ -932,12 +933,47 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
| 932 |
self.assertLess(model.get_best()["loss"], 1e-6)
|
| 933 |
self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
|
| 934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
|
| 936 |
# TODO: add tests for:
|
| 937 |
# - custom operators + dimensions
|
| 938 |
-
# - invalid number of dimensions
|
| 939 |
-
# - X
|
| 940 |
-
# - y
|
| 941 |
# - no constants, so that it needs to find the right fraction
|
| 942 |
# - custom dimensional_constraint_penalty
|
| 943 |
|
|
|
|
| 19 |
_handle_feature_selection,
|
| 20 |
_csv_filename_to_pkl_filename,
|
| 21 |
idx_model_selection,
|
| 22 |
+
_check_assertions,
|
| 23 |
)
|
| 24 |
from ..export_latex import to_latex
|
| 25 |
|
|
|
|
| 933 |
self.assertLess(model.get_best()["loss"], 1e-6)
|
| 934 |
self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
|
| 935 |
|
| 936 |
+
def test_unit_checks(self):
|
| 937 |
+
"""This just checks the number of units passed"""
|
| 938 |
+
use_custom_variable_names = False
|
| 939 |
+
variable_names = None
|
| 940 |
+
weights = None
|
| 941 |
+
args = (use_custom_variable_names, variable_names, weights)
|
| 942 |
+
valid_units = [
|
| 943 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
| 944 |
+
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|
| 945 |
+
(np.ones((10, 1)), np.ones(10), None, "m/s"),
|
| 946 |
+
(np.ones((10, 1)), np.ones(10), None, ["m/s"]),
|
| 947 |
+
(np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
|
| 948 |
+
(np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", "km"]),
|
| 949 |
+
]
|
| 950 |
+
for X, y, X_units, y_units in valid_units:
|
| 951 |
+
_check_assertions(
|
| 952 |
+
X,
|
| 953 |
+
*args,
|
| 954 |
+
y,
|
| 955 |
+
X_units,
|
| 956 |
+
y_units,
|
| 957 |
+
)
|
| 958 |
+
invalid_units = [
|
| 959 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
|
| 960 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "m"),
|
| 961 |
+
(np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
|
| 962 |
+
(np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
|
| 963 |
+
]
|
| 964 |
+
for X, y, X_units, y_units in invalid_units:
|
| 965 |
+
with self.assertRaises(ValueError):
|
| 966 |
+
_check_assertions(
|
| 967 |
+
X,
|
| 968 |
+
*args,
|
| 969 |
+
y,
|
| 970 |
+
X_units,
|
| 971 |
+
y_units,
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
|
| 975 |
# TODO: add tests for:
|
| 976 |
# - custom operators + dimensions
|
|
|
|
|
|
|
|
|
|
| 977 |
# - no constants, so that it needs to find the right fraction
|
| 978 |
# - custom dimensional_constraint_penalty
|
| 979 |
|