Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from test_functions.Ackley10D import * | |
| from test_functions.Ackley2D import * | |
| from test_functions.Ackley6D import * | |
| from test_functions.HeatExchanger import * | |
| from test_functions.CantileverBeam import * | |
| from test_functions.Car import * | |
| from test_functions.CompressionSpring import * | |
| from test_functions.GKXWC1 import * | |
| from test_functions.GKXWC2 import * | |
| from test_functions.HeatExchanger import * | |
| from test_functions.JLH1 import * | |
| from test_functions.JLH2 import * | |
| from test_functions.KeaneBump import * | |
| from test_functions.GKXWC1 import * | |
| from test_functions.GKXWC2 import * | |
| from test_functions.PressureVessel import * | |
| from test_functions.ReinforcedConcreteBeam import * | |
| from test_functions.SpeedReducer import * | |
| from test_functions.ThreeTruss import * | |
| from test_functions.WeldedBeam import * | |
| # Import other objective functions as needed | |
| import time | |
| from Rosen_PFN4BO import * | |
| from PIL import Image | |
def s(input_string):
    """Return ``input_string`` unchanged (identity helper).

    No call site for this helper is visible in this file.
    """
    passthrough = input_string
    return passthrough
# Value reported as the "best" objective before any feasible design has been
# found, per problem key. Problems not listed here fall back to -1e10.
_INITIAL_BEST = {
    "CantileverBeam.png": -82500,
    "CompressionSpring.png": -8,
    "HeatExchanger.png": -30000,
    "ThreeTruss.png": -300,
    "Reinforcement.png": -440,
    "PressureVessel.png": -40000,
    "SpeedReducer.png": -3200,
    "WeldedBeam.png": -35,
    "Car.png": -35,
}


def optimize(objective_function, iteration_input, progress=gr.Progress()):
    """Run PFN-based constrained Bayesian optimization and plot its convergence.

    Args:
        objective_function: Key into ``objective_functions`` (the gallery image
            file name, e.g. ``"WeldedBeam.png"``).
        iteration_input: Number of optimization iterations; coerced to ``int``.
        progress: Gradio progress tracker used to wrap the iteration loop.

    Returns:
        The figure built by ``create_convergence_plot``.

    Raises:
        KeyError: If ``objective_function`` is not a key of ``objective_functions``.
    """
    problem = objective_functions[objective_function]
    scale = problem['scaling']
    evaluate = problem['function']
    n_iter = int(iteration_input)
    model_path = 'final_models/model_hebo_morebudget_9_unused_features_3.pt'

    fallback_best = torch.tensor(_INITIAL_BEST.get(objective_function, -1e10))
    current_best = fallback_best

    # Initial design: 20 uniform random points; `scale` maps them into the
    # problem's domain before evaluation.
    trained_X = torch.rand(20, problem['dim'])
    trained_gx, trained_Y = evaluate(scale(trained_X))

    convergence = []
    time_conv = []
    start_time = time.time()
    for _ in progress.tqdm(range(n_iter)):
        # (1) Random candidate pool in the same unit space as trained_X.
        X_pen = torch.rand(1000, trained_X.shape[1])
        # (2) PFN inference: expected improvement and per-constraint
        #     feasibility probabilities for each candidate.
        ei, p_feas = Rosen_PFN_Parallel(model_path,
                                        trained_X,
                                        trained_Y,
                                        trained_gx,
                                        X_pen,
                                        'power',
                                        'ei'
                                        )
        # (3) Constrained EI: EI times every constraint's feasibility
        #     probability (one column of p_feas per constraint).
        #     NOTE(review): assumes ei is 1-D over candidates — confirm.
        CEI = ei
        for jj in range(p_feas.shape[1]):
            CEI = CEI * p_feas[:, jj]
        # (4) Append the candidate with the highest constrained EI.
        best_candidate = X_pen[torch.argmax(CEI), :].unsqueeze(0)
        trained_X = torch.cat([trained_X, best_candidate])
        # (5) Evaluate the augmented data set once; the result feeds both the
        #     progress tracking below and the next iteration's PFN call.
        trained_gx, trained_Y = evaluate(scale(trained_X))

        # Best objective among designs with every constraint value <= 0.
        feasible = (trained_gx <= 0).all(dim=1)
        if feasible.any():
            current_best = torch.max(trained_Y[feasible])
        else:
            current_best = fallback_best
        convergence.append(current_best.abs())
        time_conv.append(time.time() - start_time)

    total_time = time.time() - start_time
    return create_convergence_plot(objective_function, n_iter,
                                   time_conv,
                                   convergence, total_time)
def create_radar_chart(X_scaled):
    """Draw a polar ("radar") chart of the column means of ``X_scaled``.

    Args:
        X_scaled: 2-D tensor of designs; the mean is taken over dim 0 and
            each of the ``shape[1]`` columns becomes one spoke labelled
            ``x1``, ``x2``, ...

    Returns:
        The matplotlib figure, already detached from pyplot via ``plt.close``.
    """
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    n_dims = X_scaled.shape[1]
    names = ['x' + str(k + 1) for k in range(n_dims)]
    means = X_scaled.mean(dim=0).numpy()

    theta = np.linspace(0, 2 * np.pi, n_dims, endpoint=False).tolist()
    # Repeat the first point so the drawn polygon closes on itself.
    closed_theta = theta + [theta[0]]
    closed_means = np.append(means, means[0])

    ax.fill(closed_theta, closed_means, color='green', alpha=0.25)
    ax.plot(closed_theta, closed_means, color='green', linewidth=2)
    ax.set_yticklabels([])
    ax.set_xticks(theta)
    # Each spoke label shows the dimension name and its mean value.
    tick_text = [f'{name}\n({mean:.2f})' for name, mean in zip(names, means)]
    ax.set_xticklabels(tick_text)
    ax.set_title("Selected Design", size=15, color='black', y=1.1)
    plt.close(fig)
    return fig
# File-name prefix of the pre-computed GP-CBO baseline tensors
# (``<prefix>_CEI_Avg_Time.pt`` / ``<prefix>_CEI_Avg_Obj.pt``) per problem key.
_GP_DATA_PREFIX = {
    "CantileverBeam.png": "CantileverBeam",
    "CompressionSpring.png": "CompressionSpring",
    "HeatExchanger.png": "HeatExchanger",
    "ThreeTruss.png": "ThreeTruss",
    "Reinforcement.png": "ReinforcedConcreteBeam",
    "PressureVessel.png": "PressureVessel",
    "SpeedReducer.png": "SpeedReducer",
    "WeldedBeam.png": "WeldedBeam",
    "Car.png": "Car",
}

# Height of the dashed "Optimal Value" reference line per problem key.
_OPTIMAL_VALUE = {
    "CantileverBeam.png": 50000,
    "CompressionSpring.png": 0,
    "HeatExchanger.png": 4700,
    "ThreeTruss.png": 262,
    "Reinforcement.png": 355,
    "PressureVessel.png": 5000,
    "SpeedReducer.png": 2650,
    "WeldedBeam.png": 3.3,
    "Car.png": 25,
}


def create_convergence_plot(objective_function, iteration_input, time_conv, convergence, TOTAL_TIME):
    """Plot realtime PFN-CBO convergence against stored GP-CBO baseline data.

    Args:
        objective_function: Problem key (gallery image file name). Keys without
            stored GP data or a known optimum simply omit those curves.
        iteration_input: Number of iterations; truncates the GP baseline and
            appears in the title. Coerced to ``int``.
        time_conv: Elapsed seconds at each PFN-CBO iteration (x values).
        convergence: Best objective magnitude at each iteration (y values).
        TOTAL_TIME: Total runtime; accepted for interface compatibility but
            not used by this function.

    Returns:
        The matplotlib figure, already detached from pyplot via ``plt.close``.
    """
    n_iter = int(iteration_input)
    fig, ax = plt.subplots()

    # Realtime optimization data.
    ax.plot(time_conv, convergence, '^-', label='PFN-CBO (Realtime)')

    # Stored GP baseline, if this problem has one (otherwise skip it instead
    # of failing on undefined data).
    prefix = _GP_DATA_PREFIX.get(objective_function)
    if prefix is not None:
        gp_time = torch.load(f'{prefix}_CEI_Avg_Time.pt')
        gp_obj = torch.load(f'{prefix}_CEI_Avg_Obj.pt')
        ax.plot(gp_time[:n_iter], gp_obj[:n_iter], '^-', label='GP-CBO (Data)')

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Objective Value (Minimization)')
    ax.set_title('Convergence Plot for {t} iterations'.format(t=n_iter))

    optimum = _OPTIMAL_VALUE.get(objective_function)
    if optimum is not None:
        ax.axhline(y=optimum, color='red', linestyle='--', label='Optimal Value')
    ax.legend(loc='best')

    # Message anchored at the centre of the axes (right-aligned to that point).
    # NOTE(review): `optimize` appends one value per iteration even when no
    # design is feasible, so an empty list means zero iterations ran — confirm
    # this message is the intended one for that case.
    if len(convergence) == 0:
        ax.text(0.5, 0.5, 'No Feasible Design Found', transform=ax.transAxes, fontsize=12,
                verticalalignment='top', horizontalalignment='right')
    plt.close(fig)
    return fig
# Registry of selectable problems, keyed by gallery image file name. Each entry
# holds the image path (same as the key), the evaluation function, the scaling
# function applied to unit-space samples before evaluation, and the input
# dimension. Tuple order below fixes the gallery order.
_PROBLEM_SPECS = (
    # ("ThreeTruss.png", ThreeTruss, ThreeTruss_Scaling, 2),  # disabled
    ("CompressionSpring.png", CompressionSpring, CompressionSpring_Scaling, 3),
    ("Reinforcement.png", ReinforcedConcreteBeam, ReinforcedConcreteBeam_Scaling, 3),
    ("PressureVessel.png", PressureVessel, PressureVessel_Scaling, 4),
    ("SpeedReducer.png", SpeedReducer, SpeedReducer_Scaling, 7),
    ("WeldedBeam.png", WeldedBeam, WeldedBeam_Scaling, 4),
    ("HeatExchanger.png", HeatExchanger, HeatExchanger_Scaling, 8),
    ("CantileverBeam.png", CantileverBeam, CantileverBeam_Scaling, 10),
    ("Car.png", Car, Car_Scaling, 11),
)
objective_functions = {
    key: {"image": key, "function": func, "scaling": scaling, "dim": dim}
    for key, func, scaling, dim in _PROBLEM_SPECS
}
# Gallery contents: the registry keys, in insertion order.
image_paths = list(objective_functions)
def submit_action(objective_function_choices, iteration_input):
    """Handle the "Submit" button: optimize the currently selected problem.

    Args:
        objective_function_choices: Problem key held by the hidden ``img_key``
            component; empty string or ``None`` when nothing is selected.
        iteration_input: Iteration count from the slider.

    Returns:
        The convergence figure from ``optimize``, or ``None`` when no problem
        has been selected.
    """
    # Nothing selected yet: "" is the initial component value; None is
    # tolerated too instead of crashing in len().
    if not objective_function_choices:
        return None
    return optimize(objective_function_choices, iteration_input)
def clear_output():
    """Build the reset values for the "Clear" button.

    The tuple matches the outputs ``[gallery, convergence_plot,
    iteration_input, img_key, formulation]``: an emptied gallery with no
    selection, no plot, 15 iterations, an empty key, and the default
    formulation image.
    """
    gallery_reset = gr.update(value=[], selected=None)
    plot_reset = None
    default_iterations = 15
    key_reset = gr.Markdown("")
    default_formulation = 'Formulation_default.png'
    return gallery_reset, plot_reset, default_iterations, key_reset, default_formulation
def reset_gallery():
    """Return a Gradio update restoring the gallery to all problem images."""
    full_gallery = image_paths
    return gr.update(value=full_gallery)
# Problem-formulation image displayed for each selectable problem key.
_FORMULATION_IMAGES = {
    "CantileverBeam.png": 'Cantilever_formulation.png',
    "CompressionSpring.png": 'Compressed_Formulation.png',
    "HeatExchanger.png": 'Heat_Formulation.png',
    "Reinforcement.png": 'Reinforce_Formulation.png',
    "PressureVessel.png": 'Pressure_Formulation.png',
    "SpeedReducer.png": 'Speed_Formulation.png',
    "WeldedBeam.png": 'Welded_Formulation.png',
    "Car.png": 'Car_Formulation_2.png',
}

# Build the demo UI: description, inputs (gallery + slider + buttons) on the
# left, outputs (convergence plot + formulation image) on the right.
with gr.Blocks() as demo:
    # Centered title and description.
    gr.HTML(
        """
<div style="text-align: center;">
<p style="text-align: center; font-size:30px;"><b>
Constrained Bayesian Optimization with Pre-trained Transformers
</b></p>
<p style="text-align: center; font-size:18px;"><b>
Paper: <a href="https://arxiv.org/abs/2404.04495">
Fast and Accurate Bayesian Optimization with Pre-trained Transformers for Constrained Engineering Problems</a>
</b></p>
<p style="text-align: left;font-size:18px;">
Explore our interactive demo that uses PFN (Prior-Data Fitted Networks) for solving constrained Bayesian optimization problems!
</p>
<p style="text-align: left;font-size:24px;"><b>
Get Started:
</b> </p>
<p style="text-align: left;font-size:18px;">
<ol style="text-align: left;font-size:18px;text-indent: 30px;">
<li> <b>Select a Problem:</b> Click on an image from the problem gallery to choose your objective function. </li>
<li> <b>Set Iterations:</b> Adjust the slider to set the number of iterations for the optimization process. </li>
<li> <b>Run Optimization:</b> Click "Submit" to start the optimization. Use "Clear" if you need to reselect your parameters. </li>
</ol>
</p>
</div>
"""
    )
    gr.HTML(
        """
<p style="text-align: left;font-size:24px;"><b>
Result Display:
</b> </p>
<p style="text-align: left;font-size:18px;">
<ol style="text-align: left;font-size:18px;text-indent: 30px;">
<li> <b>Panel Display:</b> Shows the problem formulation and the optimization results. </li>
<li> <b>Convergence Plot:</b> Visualizes the best observed objective against the algorithm's runtime over the chosen iterations. </li>
<ul>
<li> <b>PFN-CBO:</b> Displays results from real-time optimization. </li>
<li> <b>GP-CBO:</b> Provides pre-computed data from our past experiments, as GP real-time runs are impractical for a demo. </li>
</ul>
</ol>
</p>
"""
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column(variant='compact'):
            with gr.Row():
                gr.Markdown("## Select a problem (objective): ")
                # Hidden holder for the selected problem key (gallery file name).
                img_key = gr.Markdown(value="", visible=False)
            gallery = gr.Gallery(value=image_paths, label="Objectives",
                                 object_fit='contain',
                                 columns=3, rows=3, elem_id="gallery")
            gr.Markdown("## Enter iteration Number: ")
            iteration_input = gr.Slider(label="Iterations:", minimum=15, maximum=50, step=1, value=15)
            # Row for the Clear and Submit buttons.
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit", variant="primary")
        # Right column: outputs.
        with gr.Column():
            gr.Markdown("""
## Convergence Plot:
""")
            convergence_plot = gr.Plot(label="Convergence Plot")
            gr.Markdown("")
            gr.Markdown("## Problem formulation: ")
            formulation = gr.Image(value='Formulation_default.png', label="Eq")

    def handle_select(evt: gr.SelectData):
        """Return the selected problem key and its formulation image path.

        Keys without a dedicated formulation image fall back to the default
        image instead of raising.
        """
        # NOTE(review): relies on the gallery select payload exposing
        # value['image']['orig_name'] — confirm against the Gradio version.
        key = evt.value['image']['orig_name']
        return key, _FORMULATION_IMAGES.get(key, 'Formulation_default.png')

    gallery.select(fn=handle_select, inputs=None, outputs=[img_key, formulation])
    submit_button.click(
        submit_action,
        inputs=[img_key, iteration_input],
        outputs=convergence_plot,
    )
    # Step 1: clear outputs and the selection; Step 2: reload the gallery.
    clear_button.click(
        clear_output,
        inputs=None,
        outputs=[gallery, convergence_plot, iteration_input, img_key, formulation]
    ).then(
        reset_gallery,
        inputs=None,
        outputs=gallery
    )
demo.launch(share=True)