Spaces:
Running
Running
Commit
·
e68c63f
1
Parent(s):
a06de5e
Add feature for operator-level size constraints
Browse files- docs/options.md +14 -8
- julia/sr.jl +110 -45
- pysr/sr.py +43 -4
docs/options.md
CHANGED
|
@@ -14,7 +14,7 @@ may find useful include:
|
|
| 14 |
- `maxsize`, `maxdepth`
|
| 15 |
- `batching`, `batchSize`
|
| 16 |
- `variable_names` (or pandas input)
|
| 17 |
-
-
|
| 18 |
- LaTeX, SymPy, and callable equation output
|
| 19 |
|
| 20 |
These are described below
|
|
@@ -129,13 +129,19 @@ alphabetical characters and `_` are used in these names.
|
|
| 129 |
|
| 130 |
## Limiting pow complexity
|
| 131 |
|
| 132 |
-
One can limit the complexity of
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
## LaTeX, SymPy, callables
|
| 141 |
|
|
|
|
| 14 |
- `maxsize`, `maxdepth`
|
| 15 |
- `batching`, `batchSize`
|
| 16 |
- `variable_names` (or pandas input)
|
| 17 |
+
- Constraining operator complexity
|
| 18 |
- LaTeX, SymPy, and callable equation output
|
| 19 |
|
| 20 |
These are described below
|
|
|
|
| 129 |
|
| 130 |
## Limiting pow complexity
|
| 131 |
|
| 132 |
+
One can limit the complexity of specific operators with the `constraints` parameter.
|
| 133 |
+
There is a "maxsize" parameter to PySR, but there is also an operator-level
|
| 134 |
+
"constraints" parameter. One supplies a dict, like so:
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
constraints={'pow': (-1, 1), 'mult': (3, 3), 'cos': 5}
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
This says that a power law x^y can have a base x of arbitrary complexity (-1 means no limit), but an exponent y of complexity at most 1 (e.g., a constant or variable). So (x0 + 3)^5.5 is allowed, but 5.5^(x0 + 3) is not.
|
| 141 |
+
I find this helps a lot for getting more interpretable equations.
|
| 142 |
+
The other terms say that each multiplication can only have sub-expressions
|
| 143 |
+
of up to complexity 3 (e.g., 5.0 + x2) on each side, and cosine can only operate on
|
| 144 |
+
expressions of up to complexity 5 (e.g., 5.0 + x2*x3).
|
| 145 |
|
| 146 |
## LaTeX, SymPy, callables
|
| 147 |
|
julia/sr.jl
CHANGED
|
@@ -646,24 +646,46 @@ mutable struct PopMember
|
|
| 646 |
|
| 647 |
end
|
| 648 |
|
| 649 |
-
# Check if any
|
| 650 |
-
function
|
| 651 |
if tree.degree == 0
|
| 652 |
-
return
|
| 653 |
elseif tree.degree == 1
|
| 654 |
-
return
|
| 655 |
else
|
| 656 |
-
if
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
end
|
| 664 |
-
else
|
| 665 |
-
return 0 + deepPow(tree.l) + deepPow(tree.r)
|
| 666 |
end
|
|
|
|
|
|
|
|
|
|
| 667 |
end
|
| 668 |
end
|
| 669 |
|
|
@@ -671,61 +693,104 @@ end
|
|
| 671 |
# exp(-delta/T) defines probability of accepting a change
|
| 672 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
| 673 |
prev = member.tree
|
| 674 |
-
tree =
|
| 675 |
#TODO - reconsider this
|
| 676 |
if batching
|
| 677 |
-
beforeLoss = scoreFuncBatch(
|
| 678 |
else
|
| 679 |
beforeLoss = member.score
|
| 680 |
end
|
| 681 |
|
| 682 |
mutationChoice = rand()
|
| 683 |
-
weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
|
| 684 |
-
cur_weights = copy(mutationWeights) .* 1.0
|
| 685 |
#More constants => more likely to do constant mutation
|
|
|
|
|
|
|
| 686 |
cur_weights[1] *= weightAdjustmentMutateConstant
|
| 687 |
-
n = countNodes(
|
| 688 |
-
depth = countDepth(
|
| 689 |
|
| 690 |
# If equation too big, don't add new operators
|
| 691 |
if n >= curmaxsize || depth >= maxdepth
|
| 692 |
cur_weights[3] = 0.0
|
| 693 |
cur_weights[4] = 0.0
|
| 694 |
end
|
| 695 |
-
|
| 696 |
cur_weights /= sum(cur_weights)
|
| 697 |
cweights = cumsum(cur_weights)
|
| 698 |
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
end
|
| 709 |
-
elseif mutationChoice < cweights[4]
|
| 710 |
-
tree = insertRandomOp(tree)
|
| 711 |
-
elseif mutationChoice < cweights[5]
|
| 712 |
-
tree = deleteRandomOp(tree)
|
| 713 |
-
elseif mutationChoice < cweights[6]
|
| 714 |
-
tree = simplifyTree(tree) # Sometimes we simplify tree
|
| 715 |
-
tree = combineOperators(tree) # See if repeated constants at outer levels
|
| 716 |
-
return PopMember(tree, beforeLoss)
|
| 717 |
-
elseif mutationChoice < cweights[7]
|
| 718 |
-
tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
|
| 719 |
-
else
|
| 720 |
-
return PopMember(tree, beforeLoss)
|
| 721 |
-
end
|
| 722 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
|
| 724 |
-
|
| 725 |
-
if limitPowComplexity && (deepPow(tree) > 0)
|
| 726 |
-
return PopMember(copyNode(prev), beforeLoss)
|
| 727 |
end
|
|
|
|
| 728 |
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
if batching
|
| 731 |
afterLoss = scoreFuncBatch(tree)
|
|
|
|
| 646 |
|
| 647 |
end
|
| 648 |
|
| 649 |
+
# Check if any binary operators are overly complex
|
| 650 |
+
function flagBinOperatorComplexity(tree::Node, op::Int)::Bool
|
| 651 |
if tree.degree == 0
|
| 652 |
+
return false
|
| 653 |
elseif tree.degree == 1
|
| 654 |
+
return flagBinOperatorComplexity(tree.l, op)
|
| 655 |
else
|
| 656 |
+
if tree.op == op
|
| 657 |
+
overly_complex = (
|
| 658 |
+
((bin_constraints[op][1] > -1) &&
|
| 659 |
+
(countNodes(tree.l) > bin_constraints[op][1]))
|
| 660 |
+
||
|
| 661 |
+
((bin_constraints[op][2] > -1) &&
|
| 662 |
+
(countNodes(tree.r) > bin_constraints[op][2]))
|
| 663 |
+
)
|
| 664 |
+
if overly_complex
|
| 665 |
+
return true
|
| 666 |
+
end
|
| 667 |
+
end
|
| 668 |
+
return (flagBinOperatorComplexity(tree.l, op) || flagBinOperatorComplexity(tree.r, op))
|
| 669 |
+
end
|
| 670 |
+
end
|
| 671 |
+
|
| 672 |
+
# Check if any unary operators are overly complex
|
| 673 |
+
function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
|
| 674 |
+
if tree.degree == 0
|
| 675 |
+
return false
|
| 676 |
+
elseif tree.degree == 1
|
| 677 |
+
if tree.op == op
|
| 678 |
+
overly_complex = (
|
| 679 |
+
(una_constraints[op] > -1) &&
|
| 680 |
+
(countNodes(tree.l) > una_constraints[op])
|
| 681 |
+
)
|
| 682 |
+
if overly_complex
|
| 683 |
+
return true
|
| 684 |
end
|
|
|
|
|
|
|
| 685 |
end
|
| 686 |
+
return flagUnaOperatorComplexity(tree.l, op)
|
| 687 |
+
else
|
| 688 |
+
return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
|
| 689 |
end
|
| 690 |
end
|
| 691 |
|
|
|
|
| 693 |
# exp(-delta/T) defines probability of accepting a change
|
| 694 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
| 695 |
prev = member.tree
|
| 696 |
+
tree = prev
|
| 697 |
#TODO - reconsider this
|
| 698 |
if batching
|
| 699 |
+
beforeLoss = scoreFuncBatch(prev)
|
| 700 |
else
|
| 701 |
beforeLoss = member.score
|
| 702 |
end
|
| 703 |
|
| 704 |
mutationChoice = rand()
|
|
|
|
|
|
|
| 705 |
#More constants => more likely to do constant mutation
|
| 706 |
+
weightAdjustmentMutateConstant = min(8, countConstants(prev))/8.0
|
| 707 |
+
cur_weights = copy(mutationWeights) .* 1.0
|
| 708 |
cur_weights[1] *= weightAdjustmentMutateConstant
|
| 709 |
+
n = countNodes(prev)
|
| 710 |
+
depth = countDepth(prev)
|
| 711 |
|
| 712 |
# If equation too big, don't add new operators
|
| 713 |
if n >= curmaxsize || depth >= maxdepth
|
| 714 |
cur_weights[3] = 0.0
|
| 715 |
cur_weights[4] = 0.0
|
| 716 |
end
|
|
|
|
| 717 |
cur_weights /= sum(cur_weights)
|
| 718 |
cweights = cumsum(cur_weights)
|
| 719 |
|
| 720 |
+
successful_mutation = false
|
| 721 |
+
#TODO: Currently we don't take the flag below into account
|
| 722 |
+
is_success_always_possible = true
|
| 723 |
+
attempts = 0
|
| 724 |
+
max_attempts = 10
|
| 725 |
+
|
| 726 |
+
#############################################
|
| 727 |
+
# Mutations
|
| 728 |
+
#############################################
|
| 729 |
+
while (!successful_mutation) && attempts < max_attempts
|
| 730 |
+
tree = copyNode(prev)
|
| 731 |
+
successful_mutation = true
|
| 732 |
+
if mutationChoice < cweights[1]
|
| 733 |
+
tree = mutateConstant(tree, T)
|
| 734 |
+
|
| 735 |
+
is_success_always_possible = true
|
| 736 |
+
# Mutating a constant shouldn't invalidate an already-valid function
|
| 737 |
+
|
| 738 |
+
elseif mutationChoice < cweights[2]
|
| 739 |
+
tree = mutateOperator(tree)
|
| 740 |
+
|
| 741 |
+
is_success_always_possible = true
|
| 742 |
+
# Can always mutate to the same operator
|
| 743 |
+
|
| 744 |
+
elseif mutationChoice < cweights[3]
|
| 745 |
+
if rand() < 0.5
|
| 746 |
+
tree = appendRandomOp(tree)
|
| 747 |
+
else
|
| 748 |
+
tree = prependRandomOp(tree)
|
| 749 |
+
end
|
| 750 |
+
is_success_always_possible = false
|
| 751 |
+
# Can potentially have a situation without success
|
| 752 |
+
elseif mutationChoice < cweights[4]
|
| 753 |
+
tree = insertRandomOp(tree)
|
| 754 |
+
is_success_always_possible = false
|
| 755 |
+
elseif mutationChoice < cweights[5]
|
| 756 |
+
tree = deleteRandomOp(tree)
|
| 757 |
+
is_success_always_possible = true
|
| 758 |
+
elseif mutationChoice < cweights[6]
|
| 759 |
+
tree = simplifyTree(tree) # Sometimes we simplify tree
|
| 760 |
+
tree = combineOperators(tree) # See if repeated constants at outer levels
|
| 761 |
+
return PopMember(tree, beforeLoss)
|
| 762 |
+
|
| 763 |
+
is_success_always_possible = true
|
| 764 |
+
# Simplification shouldn't hurt complexity; unless some non-symmetric constraint
|
| 765 |
+
# to commutative operator...
|
| 766 |
+
|
| 767 |
+
elseif mutationChoice < cweights[7]
|
| 768 |
+
tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
|
| 769 |
+
|
| 770 |
+
is_success_always_possible = true
|
| 771 |
+
else # no mutation applied
|
| 772 |
+
return PopMember(tree, beforeLoss)
|
| 773 |
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
|
| 775 |
+
# Check for illegal equations
|
| 776 |
+
for i=1:nbin
|
| 777 |
+
if successful_mutation && flagBinOperatorComplexity(tree, i)
|
| 778 |
+
successful_mutation = false
|
| 779 |
+
end
|
| 780 |
+
end
|
| 781 |
+
for i=1:nuna
|
| 782 |
+
if successful_mutation && flagUnaOperatorComplexity(tree, i)
|
| 783 |
+
successful_mutation = false
|
| 784 |
+
end
|
| 785 |
+
end
|
| 786 |
|
| 787 |
+
attempts += 1
|
|
|
|
|
|
|
| 788 |
end
|
| 789 |
+
#############################################
|
| 790 |
|
| 791 |
+
if !successful_mutation
|
| 792 |
+
return PopMember(copyNode(prev), beforeLoss)
|
| 793 |
+
end
|
| 794 |
|
| 795 |
if batching
|
| 796 |
afterLoss = scoreFuncBatch(tree)
|
pysr/sr.py
CHANGED
|
@@ -89,7 +89,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 89 |
batchSize=50,
|
| 90 |
select_k_features=None,
|
| 91 |
warmupMaxsize=0,
|
| 92 |
-
|
|
|
|
| 93 |
threads=None, #deprecated
|
| 94 |
julia_optimization=3,
|
| 95 |
):
|
|
@@ -166,9 +167,11 @@ def pysr(X=None, y=None, weights=None,
|
|
| 166 |
a small number up to the maxsize (if greater than 0).
|
| 167 |
If greater than 0, says how many cycles before the maxsize
|
| 168 |
is increased.
|
| 169 |
-
:param
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 173 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 174 |
(as strings).
|
|
@@ -176,6 +179,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 176 |
"""
|
| 177 |
if threads is not None:
|
| 178 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
|
|
|
|
|
|
| 179 |
if maxdepth is None:
|
| 180 |
maxdepth = maxsize
|
| 181 |
|
|
@@ -207,6 +212,17 @@ def pysr(X=None, y=None, weights=None,
|
|
| 207 |
if populations is None:
|
| 208 |
populations = procs
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
| 211 |
|
| 212 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
@@ -247,7 +263,30 @@ def pysr(X=None, y=None, weights=None,
|
|
| 247 |
function_name = op[:first_non_char]
|
| 248 |
op_list[i] = function_name
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
|
|
|
| 251 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
| 252 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
| 253 |
const ns=10;
|
|
|
|
| 89 |
batchSize=50,
|
| 90 |
select_k_features=None,
|
| 91 |
warmupMaxsize=0,
|
| 92 |
+
constraints={},
|
| 93 |
+
limitPowComplexity=False, #deprecated
|
| 94 |
threads=None, #deprecated
|
| 95 |
julia_optimization=3,
|
| 96 |
):
|
|
|
|
| 167 |
a small number up to the maxsize (if greater than 0).
|
| 168 |
If greater than 0, says how many cycles before the maxsize
|
| 169 |
is increased.
|
| 170 |
+
:param constraints: dict of int (unary) or 2-tuples (binary),
|
| 171 |
+
this enforces maxsize constraints on the individual
|
| 172 |
+
arguments of operators. E.g., `'pow': (-1, 1)`
|
| 173 |
+
says that power laws can have any complexity left argument, but only
|
| 174 |
+
1 complexity exponent. Use this to force more interpretable solutions.
|
| 175 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 176 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 177 |
(as strings).
|
|
|
|
| 179 |
"""
|
| 180 |
if threads is not None:
|
| 181 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
| 182 |
+
if limitPowComplexity:
|
| 183 |
+
raise ValueError("The limitPowComplexity kwarg is deprecated. Use constraints.")
|
| 184 |
if maxdepth is None:
|
| 185 |
maxdepth = maxsize
|
| 186 |
|
|
|
|
| 212 |
if populations is None:
|
| 213 |
populations = procs
|
| 214 |
|
| 215 |
+
#arbitrary complexity by default
|
| 216 |
+
for op in unary_operators:
|
| 217 |
+
if op not in constraints:
|
| 218 |
+
constraints[op] = -1
|
| 219 |
+
for op in binary_operators:
|
| 220 |
+
if op not in constraints:
|
| 221 |
+
constraints[op] = (-1, -1)
|
| 222 |
+
if op in ['mult', 'plus', 'sub']:
|
| 223 |
+
if constraints[op][0] != constraints[op][1]:
|
| 224 |
+
raise NotImplementedError("You need equal constraints on both sides for +, -, and *, due to simplification strategies.")
|
| 225 |
+
|
| 226 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
| 227 |
|
| 228 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
|
|
| 263 |
function_name = op[:first_non_char]
|
| 264 |
op_list[i] = function_name
|
| 265 |
|
| 266 |
+
constraints_str = "const una_constraints = ["
|
| 267 |
+
first = True
|
| 268 |
+
for op in unary_operators:
|
| 269 |
+
val = constraints[op]
|
| 270 |
+
if not first:
|
| 271 |
+
constraints_str += ", "
|
| 272 |
+
constraints_str += f"{val:d}"
|
| 273 |
+
first = False
|
| 274 |
+
|
| 275 |
+
constraints_str += """]
|
| 276 |
+
const bin_constraints = ["""
|
| 277 |
+
|
| 278 |
+
first = True
|
| 279 |
+
for op in binary_operators:
|
| 280 |
+
tup = constraints[op]
|
| 281 |
+
if not first:
|
| 282 |
+
constraints_str += ", "
|
| 283 |
+
constraints_str += f"({tup[0]:d}, {tup[1]:d})"
|
| 284 |
+
first = False
|
| 285 |
+
constraints_str += "]"
|
| 286 |
+
|
| 287 |
+
|
| 288 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
| 289 |
+
{constraints_str}
|
| 290 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
| 291 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
| 292 |
const ns=10;
|