Spaces:
Sleeping
Sleeping
Commit
·
ea010a7
1
Parent(s):
34fadcf
Catch domain errors during classical optimization
Browse files- julia/sr.jl +22 -13
- pysr/sr.py +1 -1
julia/sr.jl
CHANGED
|
@@ -689,21 +689,30 @@ function optimizeConstants(member::PopMember)::PopMember
|
|
| 689 |
algorithm = Optim.NelderMead
|
| 690 |
end
|
| 691 |
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
|
|
|
|
|
|
| 698 |
end
|
| 699 |
-
end
|
| 700 |
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
end
|
| 708 |
return member
|
| 709 |
end
|
|
|
|
| 689 |
algorithm = Optim.NelderMead
|
| 690 |
end
|
| 691 |
|
| 692 |
+
try
|
| 693 |
+
result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
|
| 694 |
+
# Try other initial conditions:
|
| 695 |
+
for i=1:nrestarts
|
| 696 |
+
tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
|
| 697 |
+
if tmpresult.minimum < result.minimum
|
| 698 |
+
result = tmpresult
|
| 699 |
+
end
|
| 700 |
end
|
|
|
|
| 701 |
|
| 702 |
+
if Optim.converged(result)
|
| 703 |
+
setConstants(member.tree, result.minimizer)
|
| 704 |
+
member.score = convert(Float32, result.minimum)
|
| 705 |
+
member.birth = getTime()
|
| 706 |
+
else
|
| 707 |
+
setConstants(member.tree, x0)
|
| 708 |
+
end
|
| 709 |
+
catch error
|
| 710 |
+
# Fine if optimization encountered domain error, just return x0
|
| 711 |
+
if isa(error, AssertionError)
|
| 712 |
+
setConstants(member.tree, x0)
|
| 713 |
+
else
|
| 714 |
+
throw(error)
|
| 715 |
+
end
|
| 716 |
end
|
| 717 |
return member
|
| 718 |
end
|
pysr/sr.py
CHANGED
|
@@ -113,7 +113,7 @@ def pysr(X=None, y=None, threads=4,
|
|
| 113 |
y = eval(eval_str)
|
| 114 |
print("Running on", eval_str)
|
| 115 |
|
| 116 |
-
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['
|
| 117 |
|
| 118 |
def_hyperparams = f"""include("{pkg_directory}/operators.jl")
|
| 119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
|
|
|
| 113 |
y = eval(eval_str)
|
| 114 |
print("Running on", eval_str)
|
| 115 |
|
| 116 |
+
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
| 117 |
|
| 118 |
def_hyperparams = f"""include("{pkg_directory}/operators.jl")
|
| 119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|