Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # author: Martin Fajčík | |
| # modified by: Jan Doležal | |
| import csv | |
| import random | |
| import numpy as np | |
| from bokeh.plotting import figure | |
| from bokeh.models import LabelSet, LogScale, ColumnDataSource, tickers | |
| from bokeh.models import LinearColorMapper, HoverTool | |
| from bokeh.models import CustomJS | |
| from bokeh.palettes import Turbo256 # A color palette with enough colors | |
def bokeh2html(obj):
    """Render a Bokeh object as an embeddable HTML fragment.

    The fragment is the CDN resource tags, the placeholder ``div`` and the
    ``script`` returned by ``bokeh.embed.components``, joined by newlines in
    that order.
    """
    from bokeh.embed import components
    from bokeh.resources import CDN

    script_tag, div_tag = components(obj, CDN)
    resource_tags = CDN.render()
    return f"{resource_tags}\n{div_tag}\n{script_tag}"
def bokeh2fullhtml(obj):
    """Render a Bokeh object as a complete, standalone HTML document.

    The document head carries the Bokeh CDN resources and a small style
    sheet; the body carries a "Loading..." placeholder followed by the
    ``div`` and ``script`` produced by ``bokeh.embed.components``.
    """
    from bokeh.embed import components
    from bokeh.resources import CDN

    script_tag, div_tag = components(obj, CDN)
    resource_tags = CDN.render()
    # Doubled braces below are literal CSS braces inside the f-string.
    # NOTE(review): the `blink` keyframes are declared but no visible rule
    # references them - confirm whether an `animation` property is missing.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
{resource_tags}
<style>
.spinner {{
padding-top: 50px;
padding-left: 50px;
position: absolute;
font-size: 20px;
}}
@keyframes blink {{
0%,100% {{opacity:1;}} 50% {{opacity:0.3;}}
}}
</style>
</head>
<body>
<div id="spinner" class="spinner">⌛ Loading...</div>
{div_tag}
{script_tag}
</body>
</html>"""
def bokeh2iframe(obj, height=820):
    """Wrap a Bokeh object in a vertically resizable ``<iframe>`` snippet.

    The full HTML document produced by ``bokeh2fullhtml`` is HTML-escaped and
    placed in the iframe's ``srcdoc`` attribute.

    Args:
        obj: Bokeh object to render.
        height: Initial height of the wrapping ``div`` in pixels.

    Returns:
        The HTML snippet as a string.
    """
    import html

    # Escaping (including quotes) keeps the document intact inside the
    # double-quoted srcdoc attribute.
    escaped_document = html.escape(bokeh2fullhtml(obj))
    return f'''
<div
style="
width: 100%;
height: {height}px;
resize: vertical;
overflow: hidden;
border: 1px solid var(--border-color-primary);
border-radius: var(--block-radius);
"
>
<iframe
srcdoc="{escaped_document}"
style="
width: 100%;
height: 100%;
"
></iframe>
</div>
'''
def bokeh2json(obj):
    """Serialize a Bokeh object by adding it as the root of a fresh Document.

    Returns whatever ``Document.to_json()`` produces.
    NOTE(review): despite the ``json_str`` naming used in this file, the
    return type of ``to_json()`` depends on the installed Bokeh version -
    confirm whether it is a string or a JSON-like dict.
    """
    from bokeh.document import Document

    document = Document()
    document.add_root(obj)
    return document.to_json()
def json2bokeh(json_str):
    """Rebuild a Bokeh object from a serialized Document.

    Args:
        json_str: Value accepted by ``Document.from_json`` (for example the
            output of ``Document.to_json()``).

    Returns:
        The first root of the reconstructed document.
    """
    from bokeh.document import Document

    return Document.from_json(json_str).roots[0]
def bokeh_copy(obj):
    """Return a copy of a Bokeh object made by a serialize/deserialize round trip."""
    return json2bokeh(bokeh2json(obj))
# Fit a polynomial to the points and sample the fitted curve for plotting
def fit_curve(x, y, degree=1):
    """Least-squares fit a polynomial and evaluate it on an even grid.

    Args:
        x: Sequence of x coordinates.
        y: Sequence of y coordinates, same length as ``x``.
        degree: Degree of the fitted polynomial.

    Returns:
        ``(x_fit, y_fit)``: 100 evenly spaced x values spanning
        ``[min(x), max(x)]`` and the polynomial evaluated at them.
    """
    coefficients = np.polyfit(x, y, degree)
    grid = np.linspace(min(x), max(x), 100)
    return grid, np.polyval(coefficients, grid)
# Split paired samples into inliers and outliers using the 1.5 * IQR rule
def remove_outliers(x, y):
    """Separate points whose x or y lies outside the 1.5 * IQR fences.

    A point is kept only if both its x and its y value lie within
    ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`` of the respective coordinate.

    Args:
        x: Sequence of x values.
        y: Sequence of y values, paired with ``x``.

    Returns:
        ``(x_in, y_in, x_out, y_out)`` as NumPy arrays: the kept points
        followed by the rejected points.
    """
    xs = np.array(x)
    ys = np.array(y)

    def inside_fences(values):
        # Quartiles and fences are computed per coordinate.
        q1, q3 = np.percentile(values, [25, 75])
        spread = q3 - q1
        return (values >= q1 - 1.5 * spread) & (values <= q3 + 1.5 * spread)

    keep = inside_fences(xs) & inside_fences(ys)
    drop = ~keep
    return xs[keep], ys[keep], xs[drop], ys[drop]
def get_ldb_records(name_map, csv_file_path):
    """Load leaderboard rows from a CSV file, keyed by model title.

    Args:
        name_map: Mapping whose *values* are the model titles allowed in the
            CSV ``Model`` column. The keys are not used here.
        csv_file_path: Path to a CSV file with a header row that contains a
            ``Model`` column. The file is read as UTF-8.

    Returns:
        Dict mapping each model title to its row (as produced by
        ``csv.DictReader``). If a title occurs more than once, the last row
        wins.

    Raises:
        KeyError: If a row's ``Model`` value is not among the values of
            ``name_map`` (or the file has no ``Model`` column).
    """
    known_titles = set(name_map.values())
    ldb_records = {}
    # newline="" is required by the csv module so that newlines embedded in
    # quoted fields are preserved; the encoding is explicit so the result
    # does not depend on the platform's locale.
    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as file:
        for row in csv.DictReader(file):
            model_title = row["Model"]
            if model_title not in known_titles:
                raise KeyError(model_title)
            ldb_records[model_title] = row
    return ldb_records
def _iqr_inlier_mask(values):
    """Return a boolean mask of the entries of *values* inside the 1.5 * IQR fences."""
    if values.size == 0:
        return np.zeros(0, dtype=bool)
    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1
    return (values >= q1 - 1.5 * iqr) & (values <= q3 + 1.5 * iqr)


def create_scatter_plot_with_curve_with_variances_named(category, variance_across_categories, x, y, sizes, model_names, ldb_records):
    """Build a log-x scatter plot of models with a fitted trend line.

    Each model is one marker: its position is ``(x, y)``, its marker size is
    derived from its variance, its symbol from its ``Type`` in
    ``ldb_records`` and its colour from a shuffled Turbo palette. Points are
    split into inliers and outliers with the 1.5 * IQR rule on both
    coordinates; both groups are drawn, but only inliers feed the dashed
    degree-1 trend line.

    Args:
        category: Text used as the y-axis label.
        variance_across_categories: Mapping model name -> variance. Models
            missing from the mapping are treated as variance 0.
        x: Per-model x values (the axis is labelled "Model Size (B parameters)").
        y: Per-model y values (shown as "Performance" in the tooltip).
        sizes: Per-model values shown as "Model Size (B parameters)" in the
            tooltip.
        model_names: Per-model names, aligned with ``x``, ``y`` and ``sizes``.
        ldb_records: Mapping model name -> record with a ``'Type'`` entry.

    Returns:
        The configured Bokeh figure.

    Raises:
        KeyError: If a model name is missing from ``ldb_records``.
    """
    FONTSIZE = 12
    min_marker_size = 5
    max_marker_size = 30

    x_arr = np.asarray(x)
    y_arr = np.asarray(y)
    sizes_arr = np.asarray(sizes)
    names_arr = np.asarray(model_names)

    # One positional mask drives every per-model column. Re-deriving the
    # split from x values (membership tests) mis-assigns models that share
    # the same x but differ in outlier status and yields columns of unequal
    # length. The rule mirrors the 1.5 * IQR filtering used for outlier
    # removal elsewhere in this module.
    inlier_mask = _iqr_inlier_mask(x_arr) & _iqr_inlier_mask(y_arr)
    outlier_mask = ~inlier_mask

    # Variance range is computed once; a zero span (all variances equal, or
    # an empty mapping) would otherwise divide by zero.
    variance_values = list(variance_across_categories.values())
    variance_min = min(variance_values, default=0)
    variance_max = max(variance_values, default=0)
    variance_span = variance_max - variance_min

    def scale_variance_to_size(variance):
        # Linear map of variance onto [min_marker_size, max_marker_size].
        if variance_span == 0:
            return (min_marker_size + max_marker_size) / 2
        scaled = min_marker_size + (variance - variance_min) * (max_marker_size - min_marker_size) / variance_span
        # The default variance 0 for unknown models may fall outside the
        # observed range; clamp so a marker size never becomes negative.
        return min(max(scaled, min_marker_size), max_marker_size)

    # Assign symbols to the model types
    # https://docs.bokeh.org/en/latest/docs/examples/basic/scatters/markers.html
    symbol_by_type = {
        'chat': 'circle',
        'pretrained': 'triangle',
        'ensemble': 'star',
    }

    # Spread the palette over the models; the stride must stay >= 1 even for
    # zero models or more models than palette entries.
    stride = max(1, len(Turbo256) // max(1, len(model_names)))
    color_palette = list(Turbo256[::stride])
    random.shuffle(color_palette)

    def build_source(mask, color_offset):
        # All columns are sliced with the same mask so they stay aligned.
        names = names_arr[mask]
        variances = [variance_across_categories.get(name, 0) for name in names]
        return ColumnDataSource(data={
            'x': x_arr[mask],
            'y': y_arr[mask],
            'sizes': sizes_arr[mask],  # Original model sizes (tooltip only)
            'marker_sizes': [scale_variance_to_size(var) for var in variances],
            'model_names': names,
            'variance': variances,
            'color': [color_palette[(i + color_offset) % len(color_palette)] for i in range(len(names))],
            'symbol': [symbol_by_type.get(ldb_records[name]['Type'], 'diamond') for name in names],
        })

    source_filtered = build_source(inlier_mask, 0)
    # Outlier colours continue where the inlier colours stopped.
    source_outliers = build_source(outlier_mask, int(inlier_mask.sum()))

    # Create a figure for the category
    p = figure(
        output_backend="svg",
        sizing_mode="stretch_width",
        height=800,
        tools="pan,wheel_zoom,box_zoom,save,reset",
        active_scroll="wheel_zoom",
        tooltips=[
            ("Model", "@model_names"),
            ("Model Size (B parameters)", "@sizes"),
            ("Variance", "@variance"),
            ("Performance", "@y"),
        ]
    )

    # Plot inliers and outliers with per-point colour, symbol and size
    for source in (source_filtered, source_outliers):
        p.scatter('x', 'y', size='marker_sizes', source=source, fill_alpha=0.6, color='color', marker='symbol')

    # Fit and plot a degree-1 trend line through the inliers; a line needs
    # at least two points.
    x_filtered = x_arr[inlier_mask]
    y_filtered = y_arr[inlier_mask]
    if len(x_filtered) >= 2:
        x_fit, y_fit = fit_curve(x_filtered, y_filtered, degree=1)
        p.line(x_fit, y_fit, line_color='gray', line_width=2, line_dash='dashed')

    # Add name labels (with slight offset to avoid overlapping the markers)
    for source in (source_filtered, source_outliers):
        p.add_layout(LabelSet(
            x='x',
            y='y',
            text='model_names',
            source=source,
            x_offset=5,
            y_offset=8,
            text_font_size=f"{FONTSIZE-2}pt",
            text_color='black',
        ))

    # Axis labels
    p.xaxis.axis_label = 'Model Size (B parameters)'
    p.yaxis.axis_label = f'{category}'

    # Axis label and tick label font sizes
    p.xaxis.axis_label_text_font_size = f"{FONTSIZE}pt"
    p.yaxis.axis_label_text_font_size = f"{FONTSIZE}pt"
    p.xaxis.major_label_text_font_size = f"{FONTSIZE}pt"
    p.yaxis.major_label_text_font_size = f"{FONTSIZE}pt"

    # Logarithmic x axis
    p.x_scale = LogScale()
    p.xaxis.ticker = tickers.LogTicker()

    p.xaxis.axis_label_text_font_style = "normal"
    p.yaxis.axis_label_text_font_style = "normal"
    return p
def create_heatmap(data_matrix, original_scores,
                   selected_rows=None,
                   hide_scores_tasks=None,
                   plot_width=None,
                   plot_height=None,
                   x_axis_label="Model",
                   y_axis_label="Task",
                   x_axis_visible=True,
                   y_axis_visible=True,
                   transpose=False,
                   ):
    """Build a Bokeh heatmap with one cell per (index label, column label).

    The index of ``data_matrix`` goes on the x axis and its columns on the
    y axis. Cell colour comes from ``data_matrix`` (mapped on [0, 1] with
    Viridis256); the number printed in the cell and shown in the tooltip
    comes from ``original_scores`` at the same position.

    Args:
        data_matrix: pandas DataFrame of values used for colouring.
        original_scores: pandas DataFrame with the same labels as
            ``data_matrix``, holding the values to display.
        selected_rows: Optional selector applied to both frames with
            ``frame[selected_rows]`` (after the optional transpose).
            NOTE(review): with pandas ``[]`` indexing a list of labels
            selects columns while a boolean mask selects rows - confirm
            what callers pass.
        hide_scores_tasks: Column labels whose cells get no printed number.
            ``None`` means none are hidden.
        plot_width: Total width in pixels. When ``None`` the width is derived
            from the number of index entries (22 px per cell plus 500 px),
            and a JS callback then adjusts it so the inner plot area matches
            the cell grid.
        plot_height: Total height in pixels; handled like ``plot_width``
            using the number of columns.
        x_axis_label: Label of the x axis (swapped with ``y_axis_label``
            when ``transpose`` is true).
        y_axis_label: Label of the y axis.
        x_axis_visible: Whether the x axis is drawn (swapped with
            ``y_axis_visible`` when ``transpose`` is true).
        y_axis_visible: Whether the y axis is drawn.
        transpose: Transpose both frames before plotting; also moves the
            toolbar to the right, the x axis above, and reverses the y order.

    Returns:
        The configured Bokeh figure.
    """
    FONTSIZE = 9
    cell_size = 22
    hidden_tasks = () if hide_scores_tasks is None else hide_scores_tasks

    if transpose:
        data_matrix = data_matrix.T
        original_scores = original_scores.T
        x_axis_label, y_axis_label = y_axis_label, x_axis_label
        x_axis_visible, y_axis_visible = y_axis_visible, x_axis_visible
        toolbar_location = "right"
        x_axis_location = "above"
    else:
        toolbar_location = "below"
        x_axis_location = "below"

    # Apply the selection BEFORE deriving ranges and sizes, so that the
    # y range and the cell grid describe the data that is actually plotted.
    if selected_rows is not None:
        data_matrix = data_matrix[selected_rows]
        original_scores = original_scores[selected_rows]

    column_labels = list(data_matrix.columns)
    y_range = list(reversed(column_labels)) if transpose else column_labels

    n_rows, n_cols = data_matrix.shape
    plot_inner_width = None
    plot_inner_height = None
    if plot_width is None:
        plot_inner_width = n_rows * cell_size
        plot_width = plot_inner_width + 500
    if plot_height is None:
        plot_inner_height = n_cols * cell_size
        plot_height = plot_inner_height + 500

    # Index labels on the x axis, column labels on the y axis
    p = figure(
        output_backend="svg",
        sizing_mode="fixed",
        width=plot_width,
        height=plot_height,
        x_range=list(data_matrix.index),
        y_range=y_range,
        toolbar_location=toolbar_location,
        tools="pan,wheel_zoom,box_zoom,reset,save",
        active_drag=None,
        x_axis_label=x_axis_label,
        y_axis_label=y_axis_label,
        x_axis_location=x_axis_location,
    )

    # Viridis runs from dark (low values) to light (high values)
    color_mapper = LinearColorMapper(palette='Viridis256', low=0, high=1)

    # Flatten the matrix into per-cell columns for Bokeh
    heatmap_data = {
        'x': [],
        'y': [],
        'colors': [],
        'model_names': [],
        'scores': [],
    }
    label_data = {
        'x': [],
        'y': [],
        'value': [],
        'text_color': [],
    }

    for row_label, row_values in data_matrix.iterrows():
        for col_label, plot_score in row_values.items():
            original_score = original_scores.loc[row_label, col_label]

            heatmap_data['x'].append(row_label)
            heatmap_data['y'].append(col_label)
            heatmap_data['colors'].append(plot_score)
            heatmap_data['model_names'].append(row_label)
            heatmap_data['scores'].append(original_score)

            if col_label in hidden_tasks:
                continue
            label_data['x'].append(row_label)
            label_data['y'].append(col_label)
            label_data['value'].append(round(original_score))
            # Low values sit on dark cells and need light text, and vice versa.
            label_data['text_color'].append('white' if plot_score <= 0.6 else 'black')

    heatmap_source = ColumnDataSource(heatmap_data)
    label_source = ColumnDataSource(label_data)

    # Draw the cells
    p.rect(x='x', y='y', width=1, height=1, source=heatmap_source,
           line_color=None, fill_color={'field': 'colors', 'transform': color_mapper})

    # Hover tooltip with the displayed (original) score
    hover = HoverTool()
    hover.tooltips = [(x_axis_label, "@x"), (y_axis_label, "@y"), ("DWS", "@scores")]
    p.add_tools(hover)

    # Printed numbers with per-cell text colour
    labels = LabelSet(x='x', y='y', text='value', source=label_source,
                      text_color='text_color', text_align='center', text_baseline='middle',
                      text_font_size=f"{FONTSIZE}pt")
    p.add_layout(labels)

    # Plot appearance
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.xaxis.major_label_orientation = "vertical"
    p.yaxis.major_label_text_font_size = f"{FONTSIZE}pt"
    p.xaxis.major_label_text_font_size = f"{FONTSIZE}pt"

    # Axis label fonts
    p.xaxis.axis_label_text_font_size = f"{FONTSIZE + 5}pt"
    p.yaxis.axis_label_text_font_size = f"{FONTSIZE + 5}pt"
    p.xaxis.axis_label_text_font_style = "normal"
    p.yaxis.axis_label_text_font_style = "normal"

    # Axis visibility
    p.xaxis.visible = x_axis_visible
    p.yaxis.visible = y_axis_visible

    # When a dimension was derived from the cell grid, resize the figure in
    # the browser so the inner plot area matches the requested grid size.
    if plot_inner_width is not None:
        p.js_on_change('inner_width', CustomJS(args=dict(p=p, target=plot_inner_width), code="""
            // current inner width of the plot area
            const iw = p.inner_width;
            // calculate the margin between full width and inner plot area
            const margin = p.width - iw;
            // adjust total width so that inner width matches the desired target
            p.width = target + margin;
            // remove only this callback from the inner_width callbacks array
            const cbs = p.js_property_callbacks.inner_width;
            for (let i = 0; i < cbs.length; i++) {
            if (cbs[i] === this) {
            cbs.splice(i, 1);
            break;
            }
            }
            """))
    if plot_inner_height is not None:
        p.js_on_change('inner_height', CustomJS(args=dict(p=p, target=plot_inner_height), code="""
            // current inner height of the plot area
            const ih = p.inner_height;
            // calculate the margin between full height and inner plot area
            const margin = p.height - ih;
            // adjust total height so that inner height matches the desired target
            p.height = target + margin;
            // remove only this callback from the inner_height callbacks array
            const cbs = p.js_property_callbacks.inner_height;
            for (let i = 0; i < cbs.length; i++) {
            if (cbs[i] === this) {
            cbs.splice(i, 1);
            break;
            }
            }
            """))
    return p
| # EOF | |