Spaces:
Running
Running
| from collections import OrderedDict | |
| import gradio as gr | |
| import numpy as np | |
| from data import TEST_EQUATIONS | |
| from gradio.components.base import Component | |
| from plots import plot_example_data, plot_pareto_curve | |
| from processing import processing, stop | |
| class ExampleData: | |
| def __init__(self, demo: gr.Blocks) -> None: | |
| with gr.Column(scale=1): | |
| self.example_plot = gr.Plot() | |
| with gr.Column(scale=1): | |
| self.test_equation = gr.Radio( | |
| TEST_EQUATIONS, value=TEST_EQUATIONS[0], label="Test Equation" | |
| ) | |
| self.num_points = gr.Slider( | |
| minimum=10, | |
| maximum=1000, | |
| value=200, | |
| label="Number of Data Points", | |
| step=1, | |
| ) | |
| self.noise_level = gr.Slider( | |
| minimum=0, maximum=1, value=0.05, label="Noise Level" | |
| ) | |
| self.data_seed = gr.Number(value=0, label="Random Seed") | |
| # Set up plotting: | |
| eqn_components = [ | |
| self.test_equation, | |
| self.num_points, | |
| self.noise_level, | |
| self.data_seed, | |
| ] | |
| for eqn_component in eqn_components: | |
| eqn_component.change( | |
| plot_example_data, | |
| eqn_components, | |
| self.example_plot, | |
| show_progress=False, | |
| ) | |
| demo.load(plot_example_data, eqn_components, self.example_plot) | |
| class UploadData: | |
| def __init__(self) -> None: | |
| self.file_input = gr.File(label="Upload a CSV File") | |
| self.label = gr.Markdown( | |
| "The rightmost column of your CSV file will be used as the target variable." | |
| ) | |
| class Data: | |
| def __init__(self, demo: gr.Blocks) -> None: | |
| with gr.Tab("Example Data"): | |
| self.example_data = ExampleData(demo) | |
| with gr.Tab("Upload Data"): | |
| self.upload_data = UploadData() | |
| class BasicSettings: | |
| def __init__(self) -> None: | |
| self.binary_operators = gr.CheckboxGroup( | |
| choices=["+", "-", "*", "/", "^", "max", "min", "mod", "cond"], | |
| label="Binary Operators", | |
| value=["+", "-", "*", "/"], | |
| ) | |
| self.unary_operators = gr.CheckboxGroup( | |
| choices=[ | |
| "sin", | |
| "cos", | |
| "tan", | |
| "exp", | |
| "log", | |
| "square", | |
| "cube", | |
| "sqrt", | |
| "abs", | |
| "erf", | |
| "relu", | |
| "round", | |
| "sign", | |
| ], | |
| label="Unary Operators", | |
| value=["sin"], | |
| ) | |
| self.niterations = gr.Slider( | |
| minimum=1, | |
| maximum=1000, | |
| value=40, | |
| label="Number of Iterations", | |
| step=1, | |
| ) | |
| self.maxsize = gr.Slider( | |
| minimum=7, | |
| maximum=100, | |
| value=20, | |
| label="Maximum Complexity", | |
| step=1, | |
| ) | |
| self.parsimony = gr.Number( | |
| value=0.0032, | |
| label="Parsimony Coefficient", | |
| ) | |
| class AdvancedSettings: | |
| def __init__(self) -> None: | |
| self.populations = gr.Slider( | |
| minimum=2, | |
| maximum=100, | |
| value=15, | |
| label="Number of Populations", | |
| step=1, | |
| ) | |
| self.population_size = gr.Slider( | |
| minimum=2, | |
| maximum=1000, | |
| value=33, | |
| label="Population Size", | |
| step=1, | |
| ) | |
| self.ncycles_per_iteration = gr.Number( | |
| value=550, | |
| label="Cycles per Iteration", | |
| ) | |
| self.elementwise_loss = gr.Radio( | |
| ["L2DistLoss()", "L1DistLoss()", "LogitDistLoss()", "HuberLoss()"], | |
| value="L2DistLoss()", | |
| label="Loss Function", | |
| ) | |
| self.adaptive_parsimony_scaling = gr.Number( | |
| value=20.0, | |
| label="Adaptive Parsimony Scaling", | |
| ) | |
| self.optimizer_algorithm = gr.Radio( | |
| ["BFGS", "NelderMead"], | |
| value="BFGS", | |
| label="Optimizer Algorithm", | |
| ) | |
| self.optimizer_iterations = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| value=8, | |
| label="Optimizer Iterations", | |
| step=1, | |
| ) | |
| self.batching = gr.Checkbox( | |
| value=False, | |
| label="Batching", | |
| ) | |
| self.batch_size = gr.Slider( | |
| minimum=2, | |
| maximum=1000, | |
| value=50, | |
| label="Batch Size", | |
| step=1, | |
| ) | |
| class GradioSettings: | |
| def __init__(self) -> None: | |
| self.plot_update_delay = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| value=3, | |
| label="Plot Update Delay", | |
| ) | |
| self.force_run = gr.Checkbox( | |
| value=False, | |
| label="Ignore Warnings", | |
| ) | |
| class Settings: | |
| def __init__(self): | |
| with gr.Tab("Basic Settings"): | |
| self.basic_settings = BasicSettings() | |
| with gr.Tab("Advanced Settings"): | |
| self.advanced_settings = AdvancedSettings() | |
| with gr.Tab("Gradio Settings"): | |
| self.gradio_settings = GradioSettings() | |
| class Results: | |
| def __init__(self): | |
| with gr.Tab("Pareto Front"): | |
| self.pareto = gr.Plot() | |
| with gr.Tab("Predictions"): | |
| self.predictions_plot = gr.Plot() | |
| self.df = gr.Dataframe( | |
| headers=["complexity", "loss", "equation"], | |
| datatype=["number", "number", "str"], | |
| wrap=True, | |
| column_widths=[75, 75, 200], | |
| interactive=False, | |
| ) | |
| self.messages = gr.Textbox(label="Messages", value="", interactive=False) | |
| def flatten_attributes( | |
| component_group, absolute_name: str, d: OrderedDict | |
| ) -> OrderedDict: | |
| if not hasattr(component_group, "__dict__"): | |
| return d | |
| for name, elem in component_group.__dict__.items(): | |
| new_absolute_name = absolute_name + "." + name | |
| if name.startswith("_"): | |
| # Private attribute | |
| continue | |
| elif elem in d.values(): | |
| # Don't duplicate any tiems | |
| continue | |
| elif isinstance(elem, Component): | |
| # Only add components to dict | |
| d[new_absolute_name] = elem | |
| else: | |
| flatten_attributes(elem, new_absolute_name, d) | |
| return d | |
| class AppInterface: | |
| def __init__(self, demo: gr.Blocks) -> None: | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| self.data = Data(demo) | |
| with gr.Row(): | |
| self.settings = Settings() | |
| with gr.Column(scale=2): | |
| self.results = Results() | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| self.stop = gr.Button(value="Stop") | |
| with gr.Column(scale=1, min_width=200): | |
| self.run = gr.Button() | |
| # Update plot when dataframe is updated: | |
| self.results.df.change( | |
| plot_pareto_curve, | |
| inputs=[self.results.df, self.settings.basic_settings.maxsize], | |
| outputs=[self.results.pareto], | |
| show_progress=False, | |
| ) | |
| ignore = ["df", "predictions_plot", "pareto", "messages"] | |
| self.run.click( | |
| create_processing_function(self, ignore=ignore), | |
| inputs=[ | |
| v | |
| for k, v in flatten_attributes(self, "interface", OrderedDict()).items() | |
| if last_part(k) not in ignore | |
| ], | |
| outputs=[ | |
| self.results.df, | |
| self.results.predictions_plot, | |
| self.results.messages, | |
| ], | |
| show_progress=True, | |
| ) | |
| self.stop.click(stop) | |
| def last_part(k: str) -> str: | |
| return k.split(".")[-1] | |
| def create_processing_function(interface: AppInterface, ignore=[]): | |
| d = flatten_attributes(interface, "interface", OrderedDict()) | |
| keys = [k for k in map(last_part, d.keys()) if k not in ignore] | |
| _, idx, counts = np.unique(keys, return_index=True, return_counts=True) | |
| if np.any(counts > 1): | |
| raise AssertionError("Bad keys: " + ",".join(np.array(keys)[idx[counts > 1]])) | |
| def f(*components): | |
| n = len(components) | |
| assert n == len(keys) | |
| for output in processing(**{keys[i]: components[i] for i in range(n)}): | |
| yield output | |
| return f | |
| def main(): | |
| with gr.Blocks(theme="default") as demo: | |
| _ = AppInterface(demo) | |
| demo.launch(debug=True) | |
| if __name__ == "__main__": | |
| main() | |