Spaces:
Running
Running
Commit
·
3dff82f
1
Parent(s):
e7b4ea9
Make __init__ not modify parameters again
Browse files- pysr/sr.py +120 -126
pysr/sr.py
CHANGED
|
@@ -759,8 +759,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 759 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
| 760 |
)
|
| 761 |
|
| 762 |
-
self._process_params()
|
| 763 |
-
|
| 764 |
def __repr__(self):
|
| 765 |
"""
|
| 766 |
Prints all current equations fitted by the model.
|
|
@@ -865,105 +863,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 865 |
f"{self.model_selection} is not a valid model selection strategy."
|
| 866 |
)
|
| 867 |
|
| 868 |
-
def _process_params(self):
|
| 869 |
-
"""
|
| 870 |
-
Perform validation on the parameters defined in init for the
|
| 871 |
-
dataset specified in :term`fit`, and update them if necessary.
|
| 872 |
-
For example, this will change :param`binary_operators`
|
| 873 |
-
into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
|
| 874 |
-
|
| 875 |
-
Raises
|
| 876 |
-
------
|
| 877 |
-
ValueError
|
| 878 |
-
Raised when on of the following occurs: `tournament_selection_n`
|
| 879 |
-
parameter is larger than `population_size`; `maxsize` is
|
| 880 |
-
less than 7; invalid `extra_jax_mappings` or
|
| 881 |
-
`extra_torch_mappings`; invalid optimizer algorithms.
|
| 882 |
-
|
| 883 |
-
"""
|
| 884 |
-
# Handle None values for instance parameters:
|
| 885 |
-
if self.binary_operators is None:
|
| 886 |
-
self.binary_operators = "+ * - /".split(" ")
|
| 887 |
-
if self.unary_operators is None:
|
| 888 |
-
self.unary_operators = []
|
| 889 |
-
if self.extra_sympy_mappings is None:
|
| 890 |
-
self.extra_sympy_mappings = {}
|
| 891 |
-
if self.constraints is None:
|
| 892 |
-
self.constraints = {}
|
| 893 |
-
if self.multithreading is None:
|
| 894 |
-
# Default is multithreading=True, unless explicitly set,
|
| 895 |
-
# or procs is set to 0 (serial mode).
|
| 896 |
-
self.multithreading = self.procs != 0 and self.cluster_manager is None
|
| 897 |
-
if self.update_verbosity is None:
|
| 898 |
-
self.update_verbosity = self.verbosity
|
| 899 |
-
if self.maxdepth is None:
|
| 900 |
-
self.maxdepth = self.maxsize
|
| 901 |
-
|
| 902 |
-
# Handle type conversion for instance parameters:
|
| 903 |
-
if isinstance(self.binary_operators, str):
|
| 904 |
-
self.binary_operators = [self.binary_operators]
|
| 905 |
-
if isinstance(self.unary_operators, str):
|
| 906 |
-
self.unary_operators = [self.unary_operators]
|
| 907 |
-
|
| 908 |
-
# Warn if instance parameters are not sensible values:
|
| 909 |
-
if self.batch_size < 1:
|
| 910 |
-
warnings.warn(
|
| 911 |
-
"Given :param`batch_size` must be greater than or equal to one. "
|
| 912 |
-
":param`batch_size` has been increased to equal one."
|
| 913 |
-
)
|
| 914 |
-
self.batch_size = 1
|
| 915 |
-
|
| 916 |
-
# Ensure instance parameters are allowable values:
|
| 917 |
-
# ValueError - Incompatible values
|
| 918 |
-
if self.tournament_selection_n > self.population_size:
|
| 919 |
-
raise ValueError(
|
| 920 |
-
"tournament_selection_n parameter must be smaller than population_size."
|
| 921 |
-
)
|
| 922 |
-
|
| 923 |
-
if self.maxsize > 40:
|
| 924 |
-
warnings.warn(
|
| 925 |
-
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
| 926 |
-
)
|
| 927 |
-
elif self.maxsize < 7:
|
| 928 |
-
raise ValueError("PySR requires a maxsize of at least 7")
|
| 929 |
-
|
| 930 |
-
if self.extra_jax_mappings is not None:
|
| 931 |
-
for value in self.extra_jax_mappings.values():
|
| 932 |
-
if not isinstance(value, str):
|
| 933 |
-
raise ValueError(
|
| 934 |
-
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 935 |
-
)
|
| 936 |
-
else:
|
| 937 |
-
self.extra_jax_mappings = {}
|
| 938 |
-
|
| 939 |
-
if self.extra_torch_mappings is not None:
|
| 940 |
-
for value in self.extra_jax_mappings.values():
|
| 941 |
-
if not callable(value):
|
| 942 |
-
raise ValueError(
|
| 943 |
-
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 944 |
-
)
|
| 945 |
-
else:
|
| 946 |
-
self.extra_torch_mappings = {}
|
| 947 |
-
|
| 948 |
-
# NotImplementedError - Values that could be supported at a later time
|
| 949 |
-
if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
|
| 950 |
-
raise NotImplementedError(
|
| 951 |
-
f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
|
| 952 |
-
)
|
| 953 |
-
|
| 954 |
-
# Handle presentation of the progress bar:
|
| 955 |
-
buffer_available = "buffer" in sys.stdout.__dir__()
|
| 956 |
-
if self.progress is not None:
|
| 957 |
-
if self.progress and not buffer_available:
|
| 958 |
-
warnings.warn(
|
| 959 |
-
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
| 960 |
-
)
|
| 961 |
-
self.progress = False
|
| 962 |
-
else:
|
| 963 |
-
self.progress = buffer_available
|
| 964 |
-
|
| 965 |
-
return self
|
| 966 |
-
|
| 967 |
def _setup_equation_file(self):
|
| 968 |
"""
|
| 969 |
Sets the full pathname of the equation file, using :param`tempdir` and
|
|
@@ -1016,6 +915,39 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1016 |
|
| 1017 |
"""
|
| 1018 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
if isinstance(X, pd.DataFrame):
|
| 1020 |
if variable_names:
|
| 1021 |
variable_names = None
|
|
@@ -1165,23 +1097,82 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1165 |
global already_ran
|
| 1166 |
global Main
|
| 1167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1168 |
# Start julia backend processes
|
| 1169 |
if Main is None:
|
| 1170 |
-
if
|
| 1171 |
os.environ["JULIA_NUM_THREADS"] = str(self.procs)
|
| 1172 |
|
| 1173 |
Main = init_julia()
|
| 1174 |
|
| 1175 |
-
if
|
| 1176 |
-
Main.eval(f"import ClusterManagers: addprocs_{
|
| 1177 |
-
cluster_manager = Main.eval(f"addprocs_{
|
| 1178 |
-
else:
|
| 1179 |
-
cluster_manager = None
|
| 1180 |
|
| 1181 |
if not already_ran:
|
| 1182 |
julia_project, is_shared = _get_julia_project(self.julia_project)
|
| 1183 |
Main.eval("using Pkg")
|
| 1184 |
-
io = "devnull" if
|
| 1185 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
| 1186 |
|
| 1187 |
Main.eval(
|
|
@@ -1211,39 +1202,35 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1211 |
|
| 1212 |
# TODO(mcranmer): These functions should be part of this class.
|
| 1213 |
binary_operators, unary_operators = _maybe_create_inline_operators(
|
| 1214 |
-
binary_operators=
|
| 1215 |
)
|
| 1216 |
constraints = _process_constraints(
|
| 1217 |
binary_operators=binary_operators,
|
| 1218 |
unary_operators=unary_operators,
|
| 1219 |
-
constraints=
|
| 1220 |
)
|
| 1221 |
|
| 1222 |
una_constraints = [constraints[op] for op in unary_operators]
|
| 1223 |
bin_constraints = [constraints[op] for op in binary_operators]
|
| 1224 |
|
| 1225 |
# Parse dict into Julia Dict for nested constraints::
|
| 1226 |
-
if
|
| 1227 |
nested_constraints_str = "Dict("
|
| 1228 |
-
for outer_k, outer_v in
|
| 1229 |
nested_constraints_str += f"({outer_k}) => Dict("
|
| 1230 |
for inner_k, inner_v in outer_v.items():
|
| 1231 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
| 1232 |
nested_constraints_str += "), "
|
| 1233 |
nested_constraints_str += ")"
|
| 1234 |
nested_constraints = Main.eval(nested_constraints_str)
|
| 1235 |
-
else:
|
| 1236 |
-
nested_constraints = None
|
| 1237 |
|
| 1238 |
# Parse dict into Julia Dict for complexities:
|
| 1239 |
-
if
|
| 1240 |
complexity_of_operators_str = "Dict("
|
| 1241 |
-
for k, v in
|
| 1242 |
complexity_of_operators_str += f"({k}) => {v}, "
|
| 1243 |
complexity_of_operators_str += ")"
|
| 1244 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
| 1245 |
-
else:
|
| 1246 |
-
complexity_of_operators = None
|
| 1247 |
|
| 1248 |
Main.custom_loss = Main.eval(self.loss)
|
| 1249 |
|
|
@@ -1274,14 +1261,14 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1274 |
hofFile=_escape_filename(self.equation_file_),
|
| 1275 |
npopulations=int(self.populations),
|
| 1276 |
batching=self.batching,
|
| 1277 |
-
batchSize=int(min([
|
| 1278 |
mutationWeights=mutationWeights,
|
| 1279 |
probPickFirst=self.tournament_selection_p,
|
| 1280 |
ns=self.tournament_selection_n,
|
| 1281 |
# These have the same name:
|
| 1282 |
parsimony=self.parsimony,
|
| 1283 |
alpha=self.alpha,
|
| 1284 |
-
maxdepth=
|
| 1285 |
fast_cycle=self.fast_cycle,
|
| 1286 |
migration=self.migration,
|
| 1287 |
hofMigration=self.hof_migration,
|
|
@@ -1302,7 +1289,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1302 |
perturbationFactor=self.perturbation_factor,
|
| 1303 |
annealing=self.annealing,
|
| 1304 |
stateReturn=True, # Required for state saving.
|
| 1305 |
-
progress=
|
| 1306 |
timeout_in_seconds=self.timeout_in_seconds,
|
| 1307 |
crossoverProbability=self.crossover_probability,
|
| 1308 |
skip_mutation_failures=self.skip_mutation_failures,
|
|
@@ -1313,6 +1300,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1313 |
# Convert data to desired precision
|
| 1314 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
| 1315 |
|
|
|
|
| 1316 |
Main.X = np.array(X, dtype=np_dtype).T
|
| 1317 |
if len(y.shape) == 1:
|
| 1318 |
Main.y = np.array(y, dtype=np_dtype)
|
|
@@ -1326,7 +1314,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1326 |
else:
|
| 1327 |
Main.weights = None
|
| 1328 |
|
| 1329 |
-
cprocs = 0 if
|
| 1330 |
|
| 1331 |
# Call to Julia backend.
|
| 1332 |
# See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
|
|
@@ -1338,7 +1326,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1338 |
varMap=self.feature_names_in_.tolist(),
|
| 1339 |
options=options,
|
| 1340 |
numprocs=int(cprocs),
|
| 1341 |
-
multithreading=bool(
|
| 1342 |
saved_state=self.raw_julia_state_,
|
| 1343 |
addprocs_function=cluster_manager,
|
| 1344 |
)
|
|
@@ -1714,7 +1702,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1714 |
if self.output_torch_format:
|
| 1715 |
torch_format = []
|
| 1716 |
local_sympy_mappings = {
|
| 1717 |
-
**self.extra_sympy_mappings,
|
| 1718 |
**sympy_mappings,
|
| 1719 |
}
|
| 1720 |
|
|
@@ -1741,7 +1729,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1741 |
eqn,
|
| 1742 |
sympy_symbols,
|
| 1743 |
selection=self.selection_mask_,
|
| 1744 |
-
extra_jax_mappings=
|
|
|
|
|
|
|
| 1745 |
)
|
| 1746 |
jax_format.append({"callable": func, "parameters": params})
|
| 1747 |
|
|
@@ -1753,7 +1743,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1753 |
eqn,
|
| 1754 |
sympy_symbols,
|
| 1755 |
selection=self.selection_mask_,
|
| 1756 |
-
extra_torch_mappings=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1757 |
)
|
| 1758 |
torch_format.append(module)
|
| 1759 |
|
|
|
|
| 759 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
| 760 |
)
|
| 761 |
|
|
|
|
|
|
|
| 762 |
def __repr__(self):
|
| 763 |
"""
|
| 764 |
Prints all current equations fitted by the model.
|
|
|
|
| 863 |
f"{self.model_selection} is not a valid model selection strategy."
|
| 864 |
)
|
| 865 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 866 |
def _setup_equation_file(self):
|
| 867 |
"""
|
| 868 |
Sets the full pathname of the equation file, using :param`tempdir` and
|
|
|
|
| 915 |
|
| 916 |
"""
|
| 917 |
|
| 918 |
+
# Ensure instance parameters are allowable values:
|
| 919 |
+
if self.tournament_selection_n > self.population_size:
|
| 920 |
+
raise ValueError(
|
| 921 |
+
"tournament_selection_n parameter must be smaller than population_size."
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
if self.maxsize > 40:
|
| 925 |
+
warnings.warn(
|
| 926 |
+
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
| 927 |
+
)
|
| 928 |
+
elif self.maxsize < 7:
|
| 929 |
+
raise ValueError("PySR requires a maxsize of at least 7")
|
| 930 |
+
|
| 931 |
+
if self.extra_jax_mappings is not None:
|
| 932 |
+
for value in self.extra_jax_mappings.values():
|
| 933 |
+
if not isinstance(value, str):
|
| 934 |
+
raise ValueError(
|
| 935 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
if self.extra_torch_mappings is not None:
|
| 939 |
+
for value in self.extra_jax_mappings.values():
|
| 940 |
+
if not callable(value):
|
| 941 |
+
raise ValueError(
|
| 942 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
# NotImplementedError - Values that could be supported at a later time
|
| 946 |
+
if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
|
| 947 |
+
raise NotImplementedError(
|
| 948 |
+
f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
if isinstance(X, pd.DataFrame):
|
| 952 |
if variable_names:
|
| 953 |
variable_names = None
|
|
|
|
| 1097 |
global already_ran
|
| 1098 |
global Main
|
| 1099 |
|
| 1100 |
+
# These are the parameters which may be modified from the ones
|
| 1101 |
+
# specified in init, so we define them here locally:
|
| 1102 |
+
binary_operators = self.binary_operators
|
| 1103 |
+
unary_operators = self.unary_operators
|
| 1104 |
+
constraints = self.constraints
|
| 1105 |
+
nested_constraints = self.nested_constraints
|
| 1106 |
+
complexity_of_operators = self.complexity_of_operators
|
| 1107 |
+
multithreading = self.multithreading
|
| 1108 |
+
update_verbosity = self.update_verbosity
|
| 1109 |
+
maxdepth = self.maxdepth
|
| 1110 |
+
batch_size = self.batch_size
|
| 1111 |
+
progress = self.progress
|
| 1112 |
+
cluster_manager = self.cluster_manager
|
| 1113 |
+
|
| 1114 |
+
# TODO: Clean this up into a readable format, such that
|
| 1115 |
+
# a function call automatically configures each default.
|
| 1116 |
+
|
| 1117 |
+
# Deal with default values, and type conversions:
|
| 1118 |
+
if binary_operators is None:
|
| 1119 |
+
binary_operators = "+ * - /".split(" ")
|
| 1120 |
+
elif isinstance(binary_operators, str):
|
| 1121 |
+
binary_operators = [binary_operators]
|
| 1122 |
+
|
| 1123 |
+
if unary_operators is None:
|
| 1124 |
+
unary_operators = []
|
| 1125 |
+
elif isinstance(unary_operators, str):
|
| 1126 |
+
unary_operators = [unary_operators]
|
| 1127 |
+
|
| 1128 |
+
if constraints is None:
|
| 1129 |
+
constraints = {}
|
| 1130 |
+
|
| 1131 |
+
if multithreading is None:
|
| 1132 |
+
# Default is multithreading=True, unless explicitly set,
|
| 1133 |
+
# or procs is set to 0 (serial mode).
|
| 1134 |
+
multithreading = self.procs != 0 and cluster_manager is None
|
| 1135 |
+
|
| 1136 |
+
if update_verbosity is None:
|
| 1137 |
+
update_verbosity = self.verbosity
|
| 1138 |
+
|
| 1139 |
+
if maxdepth is None:
|
| 1140 |
+
maxdepth = self.maxsize
|
| 1141 |
+
|
| 1142 |
+
# Warn if instance parameters are not sensible values:
|
| 1143 |
+
if batch_size < 1:
|
| 1144 |
+
warnings.warn(
|
| 1145 |
+
"Given :param`batch_size` must be greater than or equal to one. "
|
| 1146 |
+
":param`batch_size` has been increased to equal one."
|
| 1147 |
+
)
|
| 1148 |
+
batch_size = 1
|
| 1149 |
+
|
| 1150 |
+
# Handle presentation of the progress bar:
|
| 1151 |
+
buffer_available = "buffer" in sys.stdout.__dir__()
|
| 1152 |
+
if progress is not None:
|
| 1153 |
+
if progress and not buffer_available:
|
| 1154 |
+
warnings.warn(
|
| 1155 |
+
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
| 1156 |
+
)
|
| 1157 |
+
progress = False
|
| 1158 |
+
else:
|
| 1159 |
+
progress = buffer_available
|
| 1160 |
+
|
| 1161 |
# Start julia backend processes
|
| 1162 |
if Main is None:
|
| 1163 |
+
if multithreading:
|
| 1164 |
os.environ["JULIA_NUM_THREADS"] = str(self.procs)
|
| 1165 |
|
| 1166 |
Main = init_julia()
|
| 1167 |
|
| 1168 |
+
if cluster_manager is not None:
|
| 1169 |
+
Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
|
| 1170 |
+
cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
|
|
|
|
|
|
|
| 1171 |
|
| 1172 |
if not already_ran:
|
| 1173 |
julia_project, is_shared = _get_julia_project(self.julia_project)
|
| 1174 |
Main.eval("using Pkg")
|
| 1175 |
+
io = "devnull" if update_verbosity == 0 else "stderr"
|
| 1176 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
| 1177 |
|
| 1178 |
Main.eval(
|
|
|
|
| 1202 |
|
| 1203 |
# TODO(mcranmer): These functions should be part of this class.
|
| 1204 |
binary_operators, unary_operators = _maybe_create_inline_operators(
|
| 1205 |
+
binary_operators=binary_operators, unary_operators=unary_operators
|
| 1206 |
)
|
| 1207 |
constraints = _process_constraints(
|
| 1208 |
binary_operators=binary_operators,
|
| 1209 |
unary_operators=unary_operators,
|
| 1210 |
+
constraints=constraints,
|
| 1211 |
)
|
| 1212 |
|
| 1213 |
una_constraints = [constraints[op] for op in unary_operators]
|
| 1214 |
bin_constraints = [constraints[op] for op in binary_operators]
|
| 1215 |
|
| 1216 |
# Parse dict into Julia Dict for nested constraints::
|
| 1217 |
+
if nested_constraints is not None:
|
| 1218 |
nested_constraints_str = "Dict("
|
| 1219 |
+
for outer_k, outer_v in nested_constraints.items():
|
| 1220 |
nested_constraints_str += f"({outer_k}) => Dict("
|
| 1221 |
for inner_k, inner_v in outer_v.items():
|
| 1222 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
| 1223 |
nested_constraints_str += "), "
|
| 1224 |
nested_constraints_str += ")"
|
| 1225 |
nested_constraints = Main.eval(nested_constraints_str)
|
|
|
|
|
|
|
| 1226 |
|
| 1227 |
# Parse dict into Julia Dict for complexities:
|
| 1228 |
+
if complexity_of_operators is not None:
|
| 1229 |
complexity_of_operators_str = "Dict("
|
| 1230 |
+
for k, v in complexity_of_operators.items():
|
| 1231 |
complexity_of_operators_str += f"({k}) => {v}, "
|
| 1232 |
complexity_of_operators_str += ")"
|
| 1233 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
|
|
|
|
|
|
| 1234 |
|
| 1235 |
Main.custom_loss = Main.eval(self.loss)
|
| 1236 |
|
|
|
|
| 1261 |
hofFile=_escape_filename(self.equation_file_),
|
| 1262 |
npopulations=int(self.populations),
|
| 1263 |
batching=self.batching,
|
| 1264 |
+
batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
|
| 1265 |
mutationWeights=mutationWeights,
|
| 1266 |
probPickFirst=self.tournament_selection_p,
|
| 1267 |
ns=self.tournament_selection_n,
|
| 1268 |
# These have the same name:
|
| 1269 |
parsimony=self.parsimony,
|
| 1270 |
alpha=self.alpha,
|
| 1271 |
+
maxdepth=maxdepth,
|
| 1272 |
fast_cycle=self.fast_cycle,
|
| 1273 |
migration=self.migration,
|
| 1274 |
hofMigration=self.hof_migration,
|
|
|
|
| 1289 |
perturbationFactor=self.perturbation_factor,
|
| 1290 |
annealing=self.annealing,
|
| 1291 |
stateReturn=True, # Required for state saving.
|
| 1292 |
+
progress=progress,
|
| 1293 |
timeout_in_seconds=self.timeout_in_seconds,
|
| 1294 |
crossoverProbability=self.crossover_probability,
|
| 1295 |
skip_mutation_failures=self.skip_mutation_failures,
|
|
|
|
| 1300 |
# Convert data to desired precision
|
| 1301 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
| 1302 |
|
| 1303 |
+
# This converts the data into a Julia array:
|
| 1304 |
Main.X = np.array(X, dtype=np_dtype).T
|
| 1305 |
if len(y.shape) == 1:
|
| 1306 |
Main.y = np.array(y, dtype=np_dtype)
|
|
|
|
| 1314 |
else:
|
| 1315 |
Main.weights = None
|
| 1316 |
|
| 1317 |
+
cprocs = 0 if multithreading else self.procs
|
| 1318 |
|
| 1319 |
# Call to Julia backend.
|
| 1320 |
# See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
|
|
|
|
| 1326 |
varMap=self.feature_names_in_.tolist(),
|
| 1327 |
options=options,
|
| 1328 |
numprocs=int(cprocs),
|
| 1329 |
+
multithreading=bool(multithreading),
|
| 1330 |
saved_state=self.raw_julia_state_,
|
| 1331 |
addprocs_function=cluster_manager,
|
| 1332 |
)
|
|
|
|
| 1702 |
if self.output_torch_format:
|
| 1703 |
torch_format = []
|
| 1704 |
local_sympy_mappings = {
|
| 1705 |
+
**(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
|
| 1706 |
**sympy_mappings,
|
| 1707 |
}
|
| 1708 |
|
|
|
|
| 1729 |
eqn,
|
| 1730 |
sympy_symbols,
|
| 1731 |
selection=self.selection_mask_,
|
| 1732 |
+
extra_jax_mappings=(
|
| 1733 |
+
self.extra_jax_mappings if self.extra_jax_mappings else {}
|
| 1734 |
+
),
|
| 1735 |
)
|
| 1736 |
jax_format.append({"callable": func, "parameters": params})
|
| 1737 |
|
|
|
|
| 1743 |
eqn,
|
| 1744 |
sympy_symbols,
|
| 1745 |
selection=self.selection_mask_,
|
| 1746 |
+
extra_torch_mappings=(
|
| 1747 |
+
self.extra_torch_mappings
|
| 1748 |
+
if self.extra_torch_mappings
|
| 1749 |
+
else {}
|
| 1750 |
+
),
|
| 1751 |
)
|
| 1752 |
torch_format.append(module)
|
| 1753 |
|