Spaces:
Running
Running
Add setting for plot update rate
Browse files- gui/app.py +91 -77
gui/app.py
CHANGED
|
@@ -5,6 +5,7 @@ import pandas as pd
|
|
| 5 |
import time
|
| 6 |
import multiprocessing as mp
|
| 7 |
from matplotlib import pyplot as plt
|
|
|
|
| 8 |
plt.ioff()
|
| 9 |
import tempfile
|
| 10 |
from typing import Optional, Union
|
|
@@ -18,9 +19,7 @@ empty_df = pd.DataFrame(
|
|
| 18 |
}
|
| 19 |
)
|
| 20 |
|
| 21 |
-
test_equations = [
|
| 22 |
-
"sin(2*x)/x + 0.1*x"
|
| 23 |
-
]
|
| 24 |
|
| 25 |
|
| 26 |
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
|
|
@@ -52,7 +51,7 @@ def _greet_dispatch(
|
|
| 52 |
maxsize,
|
| 53 |
binary_operators,
|
| 54 |
unary_operators,
|
| 55 |
-
|
| 56 |
):
|
| 57 |
"""Load data, then spawn a process to run the greet function."""
|
| 58 |
if file_input is not None:
|
|
@@ -96,7 +95,6 @@ def _greet_dispatch(
|
|
| 96 |
maxsize=maxsize,
|
| 97 |
binary_operators=binary_operators,
|
| 98 |
unary_operators=unary_operators,
|
| 99 |
-
seed=seed,
|
| 100 |
equation_file=equation_file,
|
| 101 |
),
|
| 102 |
)
|
|
@@ -123,7 +121,10 @@ def _greet_dispatch(
|
|
| 123 |
bad_idx.append(i)
|
| 124 |
equations.drop(index=bad_idx, inplace=True)
|
| 125 |
|
| 126 |
-
while
|
|
|
|
|
|
|
|
|
|
| 127 |
time.sleep(0.1)
|
| 128 |
|
| 129 |
yield equations[["Complexity", "Loss", "Equation"]]
|
|
@@ -132,7 +133,6 @@ def _greet_dispatch(
|
|
| 132 |
except pd.errors.EmptyDataError:
|
| 133 |
pass
|
| 134 |
|
| 135 |
-
|
| 136 |
process.join()
|
| 137 |
|
| 138 |
|
|
@@ -144,7 +144,6 @@ def greet(
|
|
| 144 |
maxsize: int,
|
| 145 |
binary_operators: list,
|
| 146 |
unary_operators: list,
|
| 147 |
-
seed: int,
|
| 148 |
equation_file: Union[str, Path],
|
| 149 |
):
|
| 150 |
import pysr
|
|
@@ -180,7 +179,9 @@ def _data_layout():
|
|
| 180 |
label="Number of Data Points",
|
| 181 |
step=1,
|
| 182 |
)
|
| 183 |
-
noise_level = gr.Slider(
|
|
|
|
|
|
|
| 184 |
data_seed = gr.Number(value=0, label="Random Seed")
|
| 185 |
with gr.Tab("Upload Data"):
|
| 186 |
file_input = gr.File(label="Upload a CSV File")
|
|
@@ -199,55 +200,59 @@ def _data_layout():
|
|
| 199 |
|
| 200 |
|
| 201 |
def _settings_layout():
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
return dict(
|
| 245 |
binary_operators=binary_operators,
|
| 246 |
unary_operators=unary_operators,
|
| 247 |
niterations=niterations,
|
| 248 |
maxsize=maxsize,
|
| 249 |
force_run=force_run,
|
| 250 |
-
|
| 251 |
)
|
| 252 |
|
| 253 |
|
|
@@ -286,7 +291,7 @@ def main():
|
|
| 286 |
"maxsize",
|
| 287 |
"binary_operators",
|
| 288 |
"unary_operators",
|
| 289 |
-
"
|
| 290 |
]
|
| 291 |
],
|
| 292 |
outputs=blocks["df"],
|
|
@@ -302,7 +307,6 @@ def main():
|
|
| 302 |
for eqn_component in eqn_components:
|
| 303 |
eqn_component.change(replot, eqn_components, blocks["example_plot"])
|
| 304 |
|
| 305 |
-
|
| 306 |
# Update plot when dataframe is updated:
|
| 307 |
blocks["df"].change(
|
| 308 |
replot_pareto,
|
|
@@ -313,60 +317,70 @@ def main():
|
|
| 313 |
|
| 314 |
demo.launch(debug=True)
|
| 315 |
|
|
|
|
| 316 |
def replot_pareto(df, maxsize):
|
| 317 |
-
plt.rcParams[
|
| 318 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
| 319 |
|
| 320 |
-
if len(df) == 0 or
|
| 321 |
return fig
|
| 322 |
|
| 323 |
# Plotting the data
|
| 324 |
-
ax.loglog(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
# Set the axis limits
|
| 327 |
ax.set_xlim(0.5, maxsize + 1)
|
| 328 |
-
ytop = 2 ** (np.ceil(np.log2(df[
|
| 329 |
-
ybottom = 2 ** (np.floor(np.log2(df[
|
| 330 |
ax.set_ylim(ybottom, ytop)
|
| 331 |
|
| 332 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color=
|
| 333 |
-
ax.spines[
|
| 334 |
-
ax.spines[
|
| 335 |
|
| 336 |
# Range-frame the plot
|
| 337 |
-
for direction in [
|
| 338 |
-
ax.spines[direction].set_position((
|
| 339 |
|
| 340 |
# Delete far ticks
|
| 341 |
-
ax.tick_params(axis=
|
| 342 |
-
ax.tick_params(axis=
|
| 343 |
|
| 344 |
-
ax.set_xlabel(
|
| 345 |
-
ax.set_ylabel(
|
| 346 |
fig.tight_layout(pad=2)
|
| 347 |
|
| 348 |
return fig
|
| 349 |
|
|
|
|
| 350 |
def replot(test_equation, num_points, noise_level, data_seed):
|
| 351 |
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
| 352 |
x = X["x"]
|
| 353 |
|
| 354 |
-
plt.rcParams[
|
| 355 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
| 356 |
|
| 357 |
-
ax.scatter(x, y, alpha=0.7, edgecolors=
|
| 358 |
|
| 359 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color=
|
| 360 |
-
ax.spines[
|
| 361 |
-
ax.spines[
|
| 362 |
|
| 363 |
# Range-frame the plot
|
| 364 |
-
for direction in [
|
| 365 |
-
ax.spines[direction].set_position((
|
| 366 |
|
| 367 |
# Delete far ticks
|
| 368 |
-
ax.tick_params(axis=
|
| 369 |
-
ax.tick_params(axis=
|
| 370 |
|
| 371 |
ax.set_xlabel("x")
|
| 372 |
ax.set_ylabel("y")
|
|
|
|
| 5 |
import time
|
| 6 |
import multiprocessing as mp
|
| 7 |
from matplotlib import pyplot as plt
|
| 8 |
+
|
| 9 |
plt.ioff()
|
| 10 |
import tempfile
|
| 11 |
from typing import Optional, Union
|
|
|
|
| 19 |
}
|
| 20 |
)
|
| 21 |
|
| 22 |
+
test_equations = ["sin(2*x)/x + 0.1*x"]
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
|
|
|
|
| 51 |
maxsize,
|
| 52 |
binary_operators,
|
| 53 |
unary_operators,
|
| 54 |
+
plot_update_delay,
|
| 55 |
):
|
| 56 |
"""Load data, then spawn a process to run the greet function."""
|
| 57 |
if file_input is not None:
|
|
|
|
| 95 |
maxsize=maxsize,
|
| 96 |
binary_operators=binary_operators,
|
| 97 |
unary_operators=unary_operators,
|
|
|
|
| 98 |
equation_file=equation_file,
|
| 99 |
),
|
| 100 |
)
|
|
|
|
| 121 |
bad_idx.append(i)
|
| 122 |
equations.drop(index=bad_idx, inplace=True)
|
| 123 |
|
| 124 |
+
while (
|
| 125 |
+
last_yield_time is not None
|
| 126 |
+
and time.time() - last_yield_time < plot_update_delay
|
| 127 |
+
):
|
| 128 |
time.sleep(0.1)
|
| 129 |
|
| 130 |
yield equations[["Complexity", "Loss", "Equation"]]
|
|
|
|
| 133 |
except pd.errors.EmptyDataError:
|
| 134 |
pass
|
| 135 |
|
|
|
|
| 136 |
process.join()
|
| 137 |
|
| 138 |
|
|
|
|
| 144 |
maxsize: int,
|
| 145 |
binary_operators: list,
|
| 146 |
unary_operators: list,
|
|
|
|
| 147 |
equation_file: Union[str, Path],
|
| 148 |
):
|
| 149 |
import pysr
|
|
|
|
| 179 |
label="Number of Data Points",
|
| 180 |
step=1,
|
| 181 |
)
|
| 182 |
+
noise_level = gr.Slider(
|
| 183 |
+
minimum=0, maximum=1, value=0.05, label="Noise Level"
|
| 184 |
+
)
|
| 185 |
data_seed = gr.Number(value=0, label="Random Seed")
|
| 186 |
with gr.Tab("Upload Data"):
|
| 187 |
file_input = gr.File(label="Upload a CSV File")
|
|
|
|
| 200 |
|
| 201 |
|
| 202 |
def _settings_layout():
|
| 203 |
+
with gr.Tab("Basic Settings"):
|
| 204 |
+
binary_operators = gr.CheckboxGroup(
|
| 205 |
+
choices=["+", "-", "*", "/", "^"],
|
| 206 |
+
label="Binary Operators",
|
| 207 |
+
value=["+", "-", "*", "/"],
|
| 208 |
+
)
|
| 209 |
+
unary_operators = gr.CheckboxGroup(
|
| 210 |
+
choices=[
|
| 211 |
+
"sin",
|
| 212 |
+
"cos",
|
| 213 |
+
"exp",
|
| 214 |
+
"log",
|
| 215 |
+
"square",
|
| 216 |
+
"cube",
|
| 217 |
+
"sqrt",
|
| 218 |
+
"abs",
|
| 219 |
+
"tan",
|
| 220 |
+
],
|
| 221 |
+
label="Unary Operators",
|
| 222 |
+
value=["sin"],
|
| 223 |
+
)
|
| 224 |
+
niterations = gr.Slider(
|
| 225 |
+
minimum=1,
|
| 226 |
+
maximum=1000,
|
| 227 |
+
value=40,
|
| 228 |
+
label="Number of Iterations",
|
| 229 |
+
step=1,
|
| 230 |
+
)
|
| 231 |
+
maxsize = gr.Slider(
|
| 232 |
+
minimum=7,
|
| 233 |
+
maximum=35,
|
| 234 |
+
value=20,
|
| 235 |
+
label="Maximum Complexity",
|
| 236 |
+
step=1,
|
| 237 |
+
)
|
| 238 |
+
force_run = gr.Checkbox(
|
| 239 |
+
value=False,
|
| 240 |
+
label="Ignore Warnings",
|
| 241 |
+
)
|
| 242 |
+
with gr.Tab("Gradio Settings"):
|
| 243 |
+
plot_update_delay = gr.Slider(
|
| 244 |
+
minimum=1,
|
| 245 |
+
maximum=100,
|
| 246 |
+
value=3,
|
| 247 |
+
label="Plot Update Delay",
|
| 248 |
+
)
|
| 249 |
return dict(
|
| 250 |
binary_operators=binary_operators,
|
| 251 |
unary_operators=unary_operators,
|
| 252 |
niterations=niterations,
|
| 253 |
maxsize=maxsize,
|
| 254 |
force_run=force_run,
|
| 255 |
+
plot_update_delay=plot_update_delay,
|
| 256 |
)
|
| 257 |
|
| 258 |
|
|
|
|
| 291 |
"maxsize",
|
| 292 |
"binary_operators",
|
| 293 |
"unary_operators",
|
| 294 |
+
"plot_update_delay",
|
| 295 |
]
|
| 296 |
],
|
| 297 |
outputs=blocks["df"],
|
|
|
|
| 307 |
for eqn_component in eqn_components:
|
| 308 |
eqn_component.change(replot, eqn_components, blocks["example_plot"])
|
| 309 |
|
|
|
|
| 310 |
# Update plot when dataframe is updated:
|
| 311 |
blocks["df"].change(
|
| 312 |
replot_pareto,
|
|
|
|
| 317 |
|
| 318 |
demo.launch(debug=True)
|
| 319 |
|
| 320 |
+
|
| 321 |
def replot_pareto(df, maxsize):
|
| 322 |
+
plt.rcParams["font.family"] = "IBM Plex Mono"
|
| 323 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
| 324 |
|
| 325 |
+
if len(df) == 0 or "Equation" not in df.columns:
|
| 326 |
return fig
|
| 327 |
|
| 328 |
# Plotting the data
|
| 329 |
+
ax.loglog(
|
| 330 |
+
df["Complexity"],
|
| 331 |
+
df["Loss"],
|
| 332 |
+
marker="o",
|
| 333 |
+
linestyle="-",
|
| 334 |
+
color="#333f48",
|
| 335 |
+
linewidth=1.5,
|
| 336 |
+
markersize=6,
|
| 337 |
+
)
|
| 338 |
|
| 339 |
# Set the axis limits
|
| 340 |
ax.set_xlim(0.5, maxsize + 1)
|
| 341 |
+
ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
|
| 342 |
+
ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
|
| 343 |
ax.set_ylim(ybottom, ytop)
|
| 344 |
|
| 345 |
+
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
| 346 |
+
ax.spines["top"].set_visible(False)
|
| 347 |
+
ax.spines["right"].set_visible(False)
|
| 348 |
|
| 349 |
# Range-frame the plot
|
| 350 |
+
for direction in ["bottom", "left"]:
|
| 351 |
+
ax.spines[direction].set_position(("outward", 10))
|
| 352 |
|
| 353 |
# Delete far ticks
|
| 354 |
+
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
| 355 |
+
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
| 356 |
|
| 357 |
+
ax.set_xlabel("Complexity")
|
| 358 |
+
ax.set_ylabel("Loss")
|
| 359 |
fig.tight_layout(pad=2)
|
| 360 |
|
| 361 |
return fig
|
| 362 |
|
| 363 |
+
|
| 364 |
def replot(test_equation, num_points, noise_level, data_seed):
|
| 365 |
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
| 366 |
x = X["x"]
|
| 367 |
|
| 368 |
+
plt.rcParams["font.family"] = "IBM Plex Mono"
|
| 369 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
| 370 |
|
| 371 |
+
ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
|
| 372 |
|
| 373 |
+
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
| 374 |
+
ax.spines["top"].set_visible(False)
|
| 375 |
+
ax.spines["right"].set_visible(False)
|
| 376 |
|
| 377 |
# Range-frame the plot
|
| 378 |
+
for direction in ["bottom", "left"]:
|
| 379 |
+
ax.spines[direction].set_position(("outward", 10))
|
| 380 |
|
| 381 |
# Delete far ticks
|
| 382 |
+
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
| 383 |
+
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
| 384 |
|
| 385 |
ax.set_xlabel("x")
|
| 386 |
ax.set_ylabel("y")
|