Spaces:
Running
Running
Commit
·
e2ae8ef
1
Parent(s):
d781e18
Meta-program branching table rather than array of operators
Browse files- julia/sr.jl +5 -4
- pysr/sr.py +38 -3
julia/sr.jl
CHANGED
|
@@ -281,9 +281,9 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
| 281 |
if cumulator == nothing
|
| 282 |
return nothing
|
| 283 |
end
|
| 284 |
-
|
| 285 |
@inbounds @simd for i=1:clen
|
| 286 |
-
cumulator[i] =
|
| 287 |
end
|
| 288 |
@inbounds for i=1:clen
|
| 289 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
|
@@ -292,7 +292,6 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
| 292 |
end
|
| 293 |
return cumulator
|
| 294 |
else
|
| 295 |
-
op = binops[tree.op]
|
| 296 |
cumulator = evalTreeArray(tree.l, cX)
|
| 297 |
if cumulator == nothing
|
| 298 |
return nothing
|
|
@@ -302,8 +301,10 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
| 302 |
return nothing
|
| 303 |
end
|
| 304 |
|
|
|
|
|
|
|
| 305 |
@inbounds @simd for i=1:clen
|
| 306 |
-
cumulator[i] =
|
| 307 |
end
|
| 308 |
@inbounds for i=1:clen
|
| 309 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
|
|
|
| 281 |
if cumulator == nothing
|
| 282 |
return nothing
|
| 283 |
end
|
| 284 |
+
op_idx = tree.op
|
| 285 |
@inbounds @simd for i=1:clen
|
| 286 |
+
cumulator[i] = UNAOP(op_idx, cumulator[i])
|
| 287 |
end
|
| 288 |
@inbounds for i=1:clen
|
| 289 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
|
|
|
| 292 |
end
|
| 293 |
return cumulator
|
| 294 |
else
|
|
|
|
| 295 |
cumulator = evalTreeArray(tree.l, cX)
|
| 296 |
if cumulator == nothing
|
| 297 |
return nothing
|
|
|
|
| 301 |
return nothing
|
| 302 |
end
|
| 303 |
|
| 304 |
+
op_idx = tree.op
|
| 305 |
+
|
| 306 |
@inbounds @simd for i=1:clen
|
| 307 |
+
cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
|
| 308 |
end
|
| 309 |
@inbounds for i=1:clen
|
| 310 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
pysr/sr.py
CHANGED
|
@@ -242,8 +242,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 242 |
op_list[i] = function_name
|
| 243 |
|
| 244 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
| 245 |
-
const binops =
|
| 246 |
-
const unaops =
|
| 247 |
const ns=10;
|
| 248 |
const parsimony = {parsimony:f}f0
|
| 249 |
const alpha = {alpha:f}f0
|
|
@@ -275,7 +275,42 @@ const mutationWeights = [
|
|
| 275 |
{weightDoNothing:f}
|
| 276 |
]
|
| 277 |
const warmupMaxsize = {warmupMaxsize:d}
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
if X.shape[1] == 1:
|
| 281 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|
|
|
|
| 242 |
op_list[i] = function_name
|
| 243 |
|
| 244 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
| 245 |
+
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
| 246 |
+
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
| 247 |
const ns=10;
|
| 248 |
const parsimony = {parsimony:f}f0
|
| 249 |
const alpha = {alpha:f}f0
|
|
|
|
| 275 |
{weightDoNothing:f}
|
| 276 |
]
|
| 277 |
const warmupMaxsize = {warmupMaxsize:d}
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
op_runner = ""
|
| 281 |
+
if len(binary_operators) > 0:
|
| 282 |
+
op_runner += f"""
|
| 283 |
+
@inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
|
| 284 |
+
if i == 1
|
| 285 |
+
return @fastmath {binary_operators[0]}(x, y)
|
| 286 |
+
"""
|
| 287 |
+
for i in range(1, len(binary_operators)):
|
| 288 |
+
op_runner += f"""
|
| 289 |
+
elseif i == {i+1}
|
| 290 |
+
return @fastmath {binary_operators[i]}(x, y)
|
| 291 |
+
"""
|
| 292 |
+
op_runner += """
|
| 293 |
+
end
|
| 294 |
+
end
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
if len(unary_operators) > 0:
|
| 298 |
+
op_runner += f"""
|
| 299 |
+
@inline function UNAOP(i::Int, x::Float32)::Float32
|
| 300 |
+
if i == 1
|
| 301 |
+
return @fastmath {unary_operators[0]}(x)
|
| 302 |
+
"""
|
| 303 |
+
for i in range(1, len(unary_operators)):
|
| 304 |
+
op_runner += f"""
|
| 305 |
+
elseif i == {i+1}
|
| 306 |
+
return @fastmath {unary_operators[i]}(x)
|
| 307 |
+
"""
|
| 308 |
+
op_runner += """
|
| 309 |
+
end
|
| 310 |
+
end
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
def_hyperparams += op_runner
|
| 314 |
|
| 315 |
if X.shape[1] == 1:
|
| 316 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|