Spaces:
Sleeping
Sleeping
Commit
·
6d58816
1
Parent(s):
49b163d
Refactor backend loading
Browse files- pysr/julia_helpers.py +22 -0
- pysr/sr.py +7 -14
pysr/julia_helpers.py
CHANGED
|
@@ -4,6 +4,7 @@ import subprocess
|
|
| 4 |
import warnings
|
| 5 |
from pathlib import Path
|
| 6 |
import os
|
|
|
|
| 7 |
|
| 8 |
from .version import __version__, __symbolic_regression_jl_version__
|
| 9 |
|
|
@@ -230,3 +231,24 @@ def _version_assertion():
|
|
| 230 |
"PySR requires Julia 1.6.0 or greater. "
|
| 231 |
"Please update your Julia installation."
|
| 232 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import warnings
|
| 5 |
from pathlib import Path
|
| 6 |
import os
|
| 7 |
+
from julia.api import JuliaError
|
| 8 |
|
| 9 |
from .version import __version__, __symbolic_regression_jl_version__
|
| 10 |
|
|
|
|
| 231 |
"PySR requires Julia 1.6.0 or greater. "
|
| 232 |
"Please update your Julia installation."
|
| 233 |
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _load_cluster_manager(Main, cluster_manager):
|
| 237 |
+
Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
|
| 238 |
+
return Main.eval(f"addprocs_{cluster_manager}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _update_julia_project(Main, julia_project, is_shared, io_arg):
|
| 242 |
+
try:
|
| 243 |
+
if is_shared:
|
| 244 |
+
_add_sr_to_julia_project(Main, io_arg)
|
| 245 |
+
Main.eval(f"Pkg.resolve({io_arg})")
|
| 246 |
+
except (JuliaError, RuntimeError) as e:
|
| 247 |
+
raise ImportError(_import_error_string(julia_project)) from e
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _load_backend(Main, julia_project):
|
| 251 |
+
try:
|
| 252 |
+
Main.eval("using SymbolicRegression")
|
| 253 |
+
except (JuliaError, RuntimeError) as e:
|
| 254 |
+
raise ImportError(_import_error_string(julia_project)) from e
|
pysr/sr.py
CHANGED
|
@@ -26,8 +26,9 @@ from .julia_helpers import (
|
|
| 26 |
_process_julia_project,
|
| 27 |
is_julia_version_greater_eq,
|
| 28 |
_escape_filename,
|
| 29 |
-
|
| 30 |
-
|
|
|
|
| 31 |
)
|
| 32 |
from .export_numpy import CallableEquation
|
| 33 |
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
|
|
@@ -1453,8 +1454,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1453 |
Main = init_julia(self.julia_project)
|
| 1454 |
|
| 1455 |
if cluster_manager is not None:
|
| 1456 |
-
|
| 1457 |
-
cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
|
| 1458 |
|
| 1459 |
if not already_ran:
|
| 1460 |
julia_project, is_shared = _process_julia_project(self.julia_project)
|
|
@@ -1470,16 +1470,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1470 |
from julia.api import JuliaError
|
| 1471 |
|
| 1472 |
if self.update:
|
| 1473 |
-
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
Main.eval(f"Pkg.resolve({io_arg})")
|
| 1477 |
-
except (JuliaError, RuntimeError) as e:
|
| 1478 |
-
raise ImportError(_import_error_string(julia_project)) from e
|
| 1479 |
-
try:
|
| 1480 |
-
Main.eval("using SymbolicRegression")
|
| 1481 |
-
except (JuliaError, RuntimeError) as e:
|
| 1482 |
-
raise ImportError(_import_error_string(julia_project)) from e
|
| 1483 |
|
| 1484 |
Main.plus = Main.eval("(+)")
|
| 1485 |
Main.sub = Main.eval("(-)")
|
|
|
|
| 26 |
_process_julia_project,
|
| 27 |
is_julia_version_greater_eq,
|
| 28 |
_escape_filename,
|
| 29 |
+
_load_cluster_manager,
|
| 30 |
+
_update_julia_project,
|
| 31 |
+
_load_backend,
|
| 32 |
)
|
| 33 |
from .export_numpy import CallableEquation
|
| 34 |
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
|
|
|
|
| 1454 |
Main = init_julia(self.julia_project)
|
| 1455 |
|
| 1456 |
if cluster_manager is not None:
|
| 1457 |
+
cluster_manager = _load_cluster_manager(cluster_manager)
|
|
|
|
| 1458 |
|
| 1459 |
if not already_ran:
|
| 1460 |
julia_project, is_shared = _process_julia_project(self.julia_project)
|
|
|
|
| 1470 |
from julia.api import JuliaError
|
| 1471 |
|
| 1472 |
if self.update:
|
| 1473 |
+
_update_julia_project(Main, julia_project, is_shared, io_arg)
|
| 1474 |
+
|
| 1475 |
+
_load_backend(Main, julia_project)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1476 |
|
| 1477 |
Main.plus = Main.eval("(+)")
|
| 1478 |
Main.sub = Main.eval("(-)")
|