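"""Tag near-duplicate prompts in a conversation dataset.

Conversations whose concatenated user turns are identical are grouped; any
group larger than a high-percentile count is marked high_freq, and only a
random sample of each such group keeps sampled=True. Example invocation
(file names are illustrative):

    python dedup.py --input_file conversations.json --output_dir output
"""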
import os
import json
import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()

    output_dir = args.output_dir
    input_file = args.input_file
    with open(input_file) as f:
        data = json.load(f)
    os.makedirs(output_dir, exist_ok=True)
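    # Each row is expected to carry an arena-style conversation (schema
    # inferred from the fields accessed below), e.g.
    #   {"conversation_a": [{"role": "user", "content": "..."},
    #                       {"role": "assistant", "content": "..."}], ...}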
    # Preprocessing: build a dedup key per row by concatenating all user
    # turns of conversation_a, truncated to the first 10,000 characters
    all_convs_new = []
    convs = []
    for row in data:
        conv = ""
        for turn in row["conversation_a"]:
            if turn["role"] == "user":
                conv += f"{turn['content']}\n"
        convs.append(conv[:10000])
        row["post_process_conv"] = conv[:10000]
        all_convs_new.append(row)
    df = pd.DataFrame(all_convs_new)
    print("Number of conversations:", len(df))
    prompt_counts = df["post_process_conv"].value_counts()
    # Select the top 20 most frequent prompts
    top_prompts = prompt_counts.head(20)
    print(top_prompts)
    # Determine the percentile count
    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile * 100} percentile count: {percentile_cutoff}")
    # Prompts that are more common than the percentile cutoff
    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )
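    # For each over-represented prompt, keep int(percentile_cutoff) randomly
    # chosen rows (sampled=True) and mark the rest sampled=False, so that
    # filtering on the tag later removes the excess duplicates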
    # Initialize every row's dedup_tag to the default
    # (not high-frequency, kept in the sample)
    dedup_tags = np.array(
        [{"high_freq": False, "sampled": True} for _ in range(len(df))]
    )
    high_freq_groups = df.groupby("post_process_conv")
    for prompt in tqdm(high_frequency_prompts):
        df_high_freq = high_freq_groups.get_group(prompt)
        sampled_indices = df_high_freq.sample(
            n=int(percentile_cutoff), random_state=42
        ).index
        dedup_tags[df_high_freq.index] = {"high_freq": True, "sampled": False}
        dedup_tags[sampled_indices] = {"high_freq": True, "sampled": True}
| df["dedup_tag"] = dedup_tags | |
| # drop intermediate columns (post_process_conv) | |
| df = df.drop(columns=["post_process_conv"]) | |
| df.to_json( | |
| os.path.join(output_dir, "dedup.json"), | |
| orient="records", | |
| indent=4, | |
| force_ascii=False, | |
| ) | |
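    # Downstream consumers can then drop the flagged duplicates, e.g.
    # (hypothetical usage):
    #   df = pd.read_json("output/dedup.json")
    #   df = df[df["dedup_tag"].apply(lambda t: t["sampled"])]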