Spaces:
Sleeping
Sleeping
Commit
·
5617815
1
Parent(s):
b5b74c3
Make y flat if only one output feature
Browse files- pysr/sr.py +5 -4
pysr/sr.py
CHANGED
|
@@ -278,12 +278,13 @@ def pysr(X=None, y=None, weights=None,
|
|
| 278 |
if X is None:
|
| 279 |
X, y = _using_test_input(X, test, y)
|
| 280 |
|
| 281 |
-
if len(y.shape) == 2:
|
| 282 |
-
multioutput = True
|
| 283 |
-
nout = y.shape[1]
|
| 284 |
-
elif len(y.shape) == 1:
|
| 285 |
multioutput = False
|
| 286 |
nout = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
else:
|
| 288 |
raise NotImplementedError("y shape not supported!")
|
| 289 |
|
|
|
|
| 278 |
if X is None:
|
| 279 |
X, y = _using_test_input(X, test, y)
|
| 280 |
|
| 281 |
+
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
|
|
|
|
|
|
|
|
|
| 282 |
multioutput = False
|
| 283 |
nout = 1
|
| 284 |
+
y = y.reshape(-1)
|
| 285 |
+
elif len(y.shape) == 2:
|
| 286 |
+
multioutput = True
|
| 287 |
+
nout = y.shape[1]
|
| 288 |
else:
|
| 289 |
raise NotImplementedError("y shape not supported!")
|
| 290 |
|