Spaces:
Running
Running
Commit
·
aef1f27
1
Parent(s):
e63cf2d
Add warning if training on pandas dataframe then torch
Browse files- pysr/sr.py +15 -0
pysr/sr.py
CHANGED
|
@@ -796,6 +796,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 796 |
return sympy.latex(sympy_representation)
|
| 797 |
|
| 798 |
def jax(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
self.set_params(output_jax_format=True)
|
| 800 |
self.refresh()
|
| 801 |
best = self.get_best()
|
|
@@ -804,6 +810,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 804 |
return best["jax_format"]
|
| 805 |
|
| 806 |
def pytorch(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
self.set_params(output_torch_format=True)
|
| 808 |
self.refresh()
|
| 809 |
best = self.get_best()
|
|
@@ -854,6 +866,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 854 |
|
| 855 |
variable_names = list(X.columns)
|
| 856 |
X = np.array(X)
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
if len(X.shape) == 1:
|
| 859 |
X = X[:, None]
|
|
|
|
| 796 |
return sympy.latex(sympy_representation)
|
| 797 |
|
| 798 |
def jax(self):
|
| 799 |
+
if self.using_pandas:
|
| 800 |
+
warnings.warn(
|
| 801 |
+
"PySR's JAX modules are not set up to work with a "
|
| 802 |
+
"model that was trained on pandas dataframes. "
|
| 803 |
+
"Train on an array instead to ensure everything works as planned."
|
| 804 |
+
)
|
| 805 |
self.set_params(output_jax_format=True)
|
| 806 |
self.refresh()
|
| 807 |
best = self.get_best()
|
|
|
|
| 810 |
return best["jax_format"]
|
| 811 |
|
| 812 |
def pytorch(self):
|
| 813 |
+
if self.using_pandas:
|
| 814 |
+
warnings.warn(
|
| 815 |
+
"PySR's PyTorch modules are not set up to work with a "
|
| 816 |
+
"model that was trained on pandas dataframes. "
|
| 817 |
+
"Train on an array instead to ensure everything works as planned."
|
| 818 |
+
)
|
| 819 |
self.set_params(output_torch_format=True)
|
| 820 |
self.refresh()
|
| 821 |
best = self.get_best()
|
|
|
|
| 866 |
|
| 867 |
variable_names = list(X.columns)
|
| 868 |
X = np.array(X)
|
| 869 |
+
self.using_pandas = True
|
| 870 |
+
else:
|
| 871 |
+
self.using_pandas = False
|
| 872 |
|
| 873 |
if len(X.shape) == 1:
|
| 874 |
X = X[:, None]
|