Spaces:
Running
Running
test: list-like variable complexity
Browse files- pysr/test/test.py +27 -1
pysr/test/test.py
CHANGED
|
@@ -172,6 +172,26 @@ class TestPipeline(unittest.TestCase):
|
|
| 172 |
self.assertLessEqual(mse1, 1e-4)
|
| 173 |
self.assertLessEqual(mse2, 1e-4)
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def test_multioutput_weighted_with_callable_temp_equation(self):
|
| 176 |
X = self.X.copy()
|
| 177 |
y = X[:, [0, 1]] ** 2
|
|
@@ -1053,8 +1073,14 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
| 1053 |
"""This just checks the number of units passed"""
|
| 1054 |
use_custom_variable_names = False
|
| 1055 |
variable_names = None
|
|
|
|
| 1056 |
weights = None
|
| 1057 |
-
args = (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1058 |
valid_units = [
|
| 1059 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
| 1060 |
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|
|
|
|
| 172 |
self.assertLessEqual(mse1, 1e-4)
|
| 173 |
self.assertLessEqual(mse2, 1e-4)
|
| 174 |
|
| 175 |
+
def test_custom_variable_complexity(self):
|
| 176 |
+
y = self.X[:, [0, 1]] ** 2
|
| 177 |
+
model = PySRRegressor(
|
| 178 |
+
binary_operators=["*", "+"],
|
| 179 |
+
verbosity=0,
|
| 180 |
+
**self.default_test_kwargs,
|
| 181 |
+
early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7",
|
| 182 |
+
)
|
| 183 |
+
model.fit(
|
| 184 |
+
self.X,
|
| 185 |
+
y,
|
| 186 |
+
complexity_of_variables=[2, 3] + [100 for _ in range(self.X.shape[1] - 2)],
|
| 187 |
+
)
|
| 188 |
+
equations = model.equations_
|
| 189 |
+
self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
|
| 190 |
+
self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
|
| 191 |
+
|
| 192 |
+
self.assertEqual(model.get_best()[0]["complexity"], 5)
|
| 193 |
+
self.assertEqual(model.get_best()[1]["complexity"], 7)
|
| 194 |
+
|
| 195 |
def test_multioutput_weighted_with_callable_temp_equation(self):
|
| 196 |
X = self.X.copy()
|
| 197 |
y = X[:, [0, 1]] ** 2
|
|
|
|
| 1073 |
"""This just checks the number of units passed"""
|
| 1074 |
use_custom_variable_names = False
|
| 1075 |
variable_names = None
|
| 1076 |
+
complexity_of_variables = 1
|
| 1077 |
weights = None
|
| 1078 |
+
args = (
|
| 1079 |
+
use_custom_variable_names,
|
| 1080 |
+
variable_names,
|
| 1081 |
+
complexity_of_variables,
|
| 1082 |
+
weights,
|
| 1083 |
+
)
|
| 1084 |
valid_units = [
|
| 1085 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
| 1086 |
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|