Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Commit 
							
							·
						
						0b8c16d
	
1
								Parent(s):
							
							ab74236
								
upload plot
Browse files- app.py +6 -1
- src/plt.py +53 -0
- src/utils.py +12 -0
    	
        app.py
    CHANGED
    
    | @@ -5,6 +5,7 @@ from apscheduler.schedulers.background import BackgroundScheduler | |
| 5 | 
             
            from datasets import load_dataset
         | 
| 6 | 
             
            from src.utils import load_all_data
         | 
| 7 | 
             
            from src.md import ABOUT_TEXT, TOP_TEXT
         | 
|  | |
| 8 | 
             
            import numpy as np
         | 
| 9 |  | 
| 10 | 
             
            api = HfApi()
         | 
| @@ -210,7 +211,11 @@ with gr.Blocks() as app: | |
| 210 | 
             
                            sample_display = gr.Markdown("{sampled data loads here}")
         | 
| 211 |  | 
| 212 | 
             
                        button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
         | 
| 213 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 214 |  | 
| 215 | 
             
            # Load data when app starts, TODO make this used somewhere...
         | 
| 216 | 
             
            # def load_data_on_start():
         | 
|  | |
| 5 | 
             
            from datasets import load_dataset
         | 
| 6 | 
             
            from src.utils import load_all_data
         | 
| 7 | 
             
            from src.md import ABOUT_TEXT, TOP_TEXT
         | 
| 8 | 
            +
            from src.plt import plot_avg_correlation
         | 
| 9 | 
             
            import numpy as np
         | 
| 10 |  | 
| 11 | 
             
            api = HfApi()
         | 
|  | |
| 211 | 
             
                            sample_display = gr.Markdown("{sampled data loads here}")
         | 
| 212 |  | 
| 213 | 
             
                        button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
         | 
| 214 | 
            +
                    # removed plot because not pretty enough
         | 
| 215 | 
            +
                    # with gr.TabItem("Model Correlation"):
         | 
| 216 | 
            +
                    #     with gr.Row():
         | 
| 217 | 
            +
                    #         plot = plot_avg_correlation(herm_data_avg, prefs_data)
         | 
| 218 | 
            +
                    #         gr.Plot(plot)
         | 
| 219 |  | 
| 220 | 
             
            # Load data when app starts, TODO make this used somewhere...
         | 
| 221 | 
             
            # def load_data_on_start():
         | 
    	
        src/plt.py
    ADDED
    
    | @@ -0,0 +1,53 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import matplotlib.pyplot as plt
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            from .utils import undo_hyperlink
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def plot_avg_correlation(df1, df2):
                """
                Scatter the "average" score of every model found in both dataframes.

                Each shared model is drawn as a labelled point whose x value is its
                "average" in df1 and whose y value is its "average" in df2.

                Parameters:
                - df1: pandas DataFrame containing columns "model" and "average" (x axis).
                - df2: pandas DataFrame containing columns "model" and "average" (y axis).

                Returns:
                - the matplotlib.pyplot module; the figure built here is left as the
                  current figure.
                """
                # Models present in both frames, kept in df1 order so that the colour
                # assigned to each model is reproducible (a set intersection has no
                # stable ordering). Missing model names are skipped.
                in_both = df1['model'].isin(df2['model'])
                common_models = df1.loc[in_both, 'model'].dropna().unique()

                # Apply the font settings before the figure and axes are created so
                # that everything drawn on them uses the same sizes.
                # NOTE: this updates matplotlib's global rcParams and persists after
                # the call.
                plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14, 'axes.titlesize': 14})

                # Set up the plot
                plt.figure(figsize=(13, 6), constrained_layout=True)

                # Restrict both axes to the 0.475 - 0.8 window
                plt.xlim(0.475, 0.8)
                plt.ylim(0.475, 0.8)

                for model in common_models:
                    # "average" values recorded for the current model in each frame
                    x_values = df1.loc[df1['model'] == model, 'average'].values
                    y_values = df2.loc[df2['model'] == model, 'average'].values

                    plt.scatter(x_values, y_values, label=model)

                    # The model column holds HTML links; recover the display name.
                    # undo_hyperlink reports "No text found" for plain-text entries,
                    # which are shown as "Random".
                    m_name = undo_hyperlink(model)
                    if m_name == "No text found":
                        m_name = "Random"

                    # plt.text takes scalar coordinates, so label each point on its
                    # own, just left of the marker, instead of passing whole arrays.
                    for x_value, y_value in zip(x_values, y_values):
                        plt.text(
                            float(x_value) - .005,
                            float(y_value),
                            m_name,
                            horizontalalignment='right',
                            verticalalignment='center',
                        )

                # TODO: add a correlation line; it has to be computed on the two
                # "average" columns after aligning them by model.

                plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
                plt.ylabel('Pref. Test Sets Avg.', fontsize=16)
                return plt
         | 
    	
        src/utils.py
    CHANGED
    
    | @@ -3,6 +3,7 @@ from pathlib import Path | |
| 3 | 
             
            from datasets import load_dataset
         | 
| 4 | 
             
            import numpy as np
         | 
| 5 | 
             
            import os
         | 
|  | |
| 6 |  | 
| 7 | 
             
            # From Open LLM Leaderboard
         | 
| 8 | 
             
            def model_hyperlink(link, model_name):
         | 
| @@ -10,6 +11,17 @@ def model_hyperlink(link, model_name): | |
| 10 | 
             
                    return "random"
         | 
| 11 | 
             
                return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
         | 
| 12 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 13 | 
             
            # Define a function to fetch and process data
         | 
| 14 | 
             
            def load_all_data(data_repo, subdir:str, subsubsets=False):    # use HF api to pull the git repo
         | 
| 15 | 
             
                dir = Path(data_repo)
         | 
|  | |
| 3 | 
             
            from datasets import load_dataset
         | 
| 4 | 
             
            import numpy as np
         | 
| 5 | 
             
            import os
         | 
| 6 | 
            +
            import re
         | 
| 7 |  | 
| 8 | 
             
            # From Open LLM Leaderboard
         | 
| 9 | 
             
            def model_hyperlink(link, model_name):
         | 
|  | |
| 11 | 
             
                    return "random"
         | 
| 12 | 
             
                return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
         | 
| 13 |  | 
| 14 | 
            +
            def undo_hyperlink(html_string):
                """
                Extract the first run of text found between a '>' and the next '<'.

                Parameters:
                - html_string: string expected to contain an HTML element such as
                  '<a ...>name</a>'.

                Returns:
                - the text between the first '>' and the following '<', or the
                  sentinel "No text found" when there is no such text or the input
                  is not a string (e.g. None or NaN).
                """
                # Non-string values would make re.search raise TypeError; report
                # them with the same sentinel as a string without any tag text.
                if not isinstance(html_string, str):
                    return "No text found"

                # Capture the content between '>' and '<' directly rather than
                # slicing the delimiters off the full match.
                match = re.search(r'>([^<]+)<', html_string)
                if match is None:
                    return "No text found"
                return match.group(1)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
             
            # Define a function to fetch and process data
         | 
| 26 | 
             
            def load_all_data(data_repo, subdir:str, subsubsets=False):    # use HF api to pull the git repo
         | 
| 27 | 
             
                dir = Path(data_repo)
         | 

