Spaces:
Running
Running
test: more typing info
Browse files- pysr/julia_helpers.py +5 -4
- pysr/sr.py +7 -7
pysr/julia_helpers.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
"""Functions for initializing the Julia environment and installing deps."""
|
| 2 |
|
| 3 |
-
from typing import Any, Callable, cast
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
from juliacall import convert as jl_convert # type: ignore
|
|
|
|
| 7 |
|
| 8 |
from .deprecated import init_julia, install
|
| 9 |
from .julia_import import jl
|
|
@@ -26,7 +27,7 @@ def _escape_filename(filename):
|
|
| 26 |
return str_repr
|
| 27 |
|
| 28 |
|
| 29 |
-
def _load_cluster_manager(cluster_manager):
|
| 30 |
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
|
| 31 |
return jl.seval(f"addprocs_{cluster_manager}")
|
| 32 |
|
|
@@ -37,13 +38,13 @@ def jl_array(x):
|
|
| 37 |
return jl_convert(jl.Array, x)
|
| 38 |
|
| 39 |
|
| 40 |
-
def jl_serialize(obj):
|
| 41 |
buf = jl.IOBuffer()
|
| 42 |
Serialization.serialize(buf, obj)
|
| 43 |
return np.array(jl.take_b(buf))
|
| 44 |
|
| 45 |
|
| 46 |
-
def jl_deserialize(s):
|
| 47 |
if s is None:
|
| 48 |
return s
|
| 49 |
buf = jl.IOBuffer()
|
|
|
|
| 1 |
"""Functions for initializing the Julia environment and installing deps."""
|
| 2 |
|
| 3 |
+
from typing import Any, Callable, Union, cast
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
from juliacall import convert as jl_convert # type: ignore
|
| 7 |
+
from numpy.typing import NDArray
|
| 8 |
|
| 9 |
from .deprecated import init_julia, install
|
| 10 |
from .julia_import import jl
|
|
|
|
| 27 |
return str_repr
|
| 28 |
|
| 29 |
|
| 30 |
+
def _load_cluster_manager(cluster_manager: str):
|
| 31 |
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
|
| 32 |
return jl.seval(f"addprocs_{cluster_manager}")
|
| 33 |
|
|
|
|
| 38 |
return jl_convert(jl.Array, x)
|
| 39 |
|
| 40 |
|
| 41 |
+
def jl_serialize(obj: Any) -> NDArray[np.uint8]:
|
| 42 |
buf = jl.IOBuffer()
|
| 43 |
Serialization.serialize(buf, obj)
|
| 44 |
return np.array(jl.take_b(buf))
|
| 45 |
|
| 46 |
|
| 47 |
+
def jl_deserialize(s: Union[NDArray[np.uint8], None]):
|
| 48 |
if s is None:
|
| 49 |
return s
|
| 50 |
buf = jl.IOBuffer()
|
pysr/sr.py
CHANGED
|
@@ -667,19 +667,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 667 |
```
|
| 668 |
"""
|
| 669 |
|
| 670 |
-
equations_:
|
| 671 |
n_features_in_: int
|
| 672 |
feature_names_in_: ArrayLike[str]
|
| 673 |
display_feature_names_in_: ArrayLike[str]
|
| 674 |
-
X_units_:
|
| 675 |
-
y_units_:
|
| 676 |
nout_: int
|
| 677 |
-
selection_mask_:
|
| 678 |
tempdir_: Path
|
| 679 |
equation_file_: Union[str, Path]
|
| 680 |
-
julia_state_stream_:
|
| 681 |
-
julia_options_stream_:
|
| 682 |
-
equation_file_contents_:
|
| 683 |
show_pickle_warnings_: bool
|
| 684 |
|
| 685 |
def __init__(
|
|
|
|
| 667 |
```
|
| 668 |
"""
|
| 669 |
|
| 670 |
+
equations_: Union[pd.DataFrame, List[pd.DataFrame], None]
|
| 671 |
n_features_in_: int
|
| 672 |
feature_names_in_: ArrayLike[str]
|
| 673 |
display_feature_names_in_: ArrayLike[str]
|
| 674 |
+
X_units_: Union[ArrayLike[str], None]
|
| 675 |
+
y_units_: Union[str, ArrayLike[str], None]
|
| 676 |
nout_: int
|
| 677 |
+
selection_mask_: Union[NDArray[np.bool_], None]
|
| 678 |
tempdir_: Path
|
| 679 |
equation_file_: Union[str, Path]
|
| 680 |
+
julia_state_stream_: Union[NDArray[np.uint8], None]
|
| 681 |
+
julia_options_stream_: Union[NDArray[np.uint8], None]
|
| 682 |
+
equation_file_contents_: Union[List[pd.DataFrame], None]
|
| 683 |
show_pickle_warnings_: bool
|
| 684 |
|
| 685 |
def __init__(
|