| """ | |
| Usage: | |
| python3 topic_clustering.py --in arena.json --english-only --min-length 32 | |
| python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 | |
| """ | |
import argparse
import json
import pickle
import string

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.cluster import KMeans, AgglomerativeClustering
from tqdm import tqdm

from fastchat.utils import detect_language

def remove_punctuation(input_string):
    # Make a translator object that maps every punctuation character to None
    translator = str.maketrans("", "", string.punctuation)
    # Use the translator object to remove the punctuation
    no_punct = input_string.translate(translator)
    return no_punct
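
# For example, remove_punctuation("Hello, world!") returns "Hello world".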

def read_texts(input_file, min_length, max_length, english_only):
    visited = set()
    texts = []

    with open(input_file, "r") as fin:
        lines = json.load(fin)

    for l in tqdm(lines):
        if "text" in l:
            line_texts = [l["text"]]
        elif "conversation_a" in l:
            line_texts = [
                x["content"] for x in l["conversation_a"] if x["role"] == "user"
            ]
        elif "conversation" in l:
            line_texts = [
                x["content"] for x in l["conversation"] if x["role"] == "user"
            ]
        else:
            # Skip records without a recognized text field so that
            # `line_texts` from a previous record is never reused.
            continue

        for text in line_texts:
            text = text.strip()

            # Filter language
            if english_only:
                lang = detect_language(text)
                if lang != "English":
                    continue

            # Filter short or long prompts
            if min_length:
                if len(text) < min_length:
                    continue
            if max_length:
                if len(text) > max_length:
                    continue

            # De-duplication: compare prompts as sorted bags of lowercased words
            words = sorted([x.lower() for x in remove_punctuation(text).split(" ")])
            words = "".join(words)
            if words in visited:
                continue

            visited.add(words)
            texts.append(text)
    return np.array(texts)
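
# Input format note (a sketch inferred from the keys read above; the field
# values are illustrative): records may look like either of the following.
#   {"text": "What is the capital of France?"}
#   {"conversation_a": [{"role": "user", "content": "..."},
#                       {"role": "assistant", "content": "..."}]}
# "conversation" records follow the same shape as "conversation_a".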

def get_embeddings(texts, model_name, batch_size):
    model = SentenceTransformer(model_name)
    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
        convert_to_tensor=True,
    )
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu()
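
# Note: because the embeddings are L2-normalized above, the cosine similarity
# used below is equivalent to a plain dot product between embeddings.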

def run_k_means(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
    clustering_model.fit(embeddings.numpy())

    centers = torch.from_numpy(clustering_model.cluster_centers_)
    labels = torch.from_numpy(clustering_model.labels_)

    # Sort labels by cluster size (descending)
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    new_centers = torch.empty_like(centers)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i
        new_centers[i] = centers[c]
    return new_centers, new_labels
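
# After relabeling, cluster 0 is always the largest cluster, cluster 1 the
# second largest, and so on; the centers are permuted to match the new labels.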

def run_agg_cluster(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
    clustering_model.fit(embeddings.numpy())
    labels = torch.from_numpy(clustering_model.labels_)

    # Sort labels by cluster size (descending)
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute centers as the mean embedding of each cluster
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(dim=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels

def run_hdbscan_cluster(embeddings):
    import hdbscan

    np.random.seed(42)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
    labels = torch.from_numpy(clusterer.fit_predict(embeddings.numpy()))

    # Sort labels by cluster size (descending)
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute centers as the mean embedding of each cluster
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(dim=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels
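
# Caveat: HDBSCAN labels noise points with -1, and the relabeling above treats
# that noise group as an ordinary cluster, so one output "cluster" may be noise.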

def get_topk_indices(centers, labels, embeddings, topk):
    indices = []
    arange = torch.arange(len(labels))
    counts = torch.unique(labels, return_counts=True)[1]
    topk = min(topk, counts.min().item())
    for i in range(len(centers)):
        tmp_indices = labels == i
        tmp_arange = arange[tmp_indices]
        tmp_embeddings = embeddings[tmp_indices]

        scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
        sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
        indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0))
    return torch.cat(indices)
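
# Note: topk is capped at the size of the smallest cluster so that every
# cluster contributes the same number of indices and the rows stack cleanly.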

def print_topk(texts, labels, topk_indices, show_cut_off):
    ret = ""
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()
        ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n"
        for idx in topk_indices[k]:
            ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n"
        ret += "=" * 40 + "\n\n"
    return ret

def get_cluster_info(texts, labels, topk_indices):
    np.random.seed(42)

    cluster_info = []
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()
        topk_prompts = []
        for idx in topk_indices[k]:
            topk_prompts.append(texts[idx])
        # Draw the same number of random prompts from the whole corpus
        random_prompts = []
        for _ in range(len(topk_indices[k])):
            random_prompts.append(np.random.choice(texts))
        cluster_info.append((num_samples, topk_prompts, random_prompts))

    return cluster_info
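
# Each cluster_info entry is a (num_samples, topk_prompts, random_prompts)
# tuple; the random prompts presumably serve as a corpus-wide baseline when
# inspecting a cluster. This is the object pickled at the bottom of the script.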

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--model", type=str, default="all-mpnet-base-v2")
    # Alternative models: "all-MiniLM-L12-v2", "multi-qa-distilbert-cos-v1"
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--min-length", type=int)
    parser.add_argument("--max-length", type=int)
    parser.add_argument("--english-only", action="store_true")
    parser.add_argument("--num-clusters", type=int, default=20)
    parser.add_argument(
        "--cluster-alg",
        type=str,
        choices=["kmeans", "aggcls", "HDBSCAN"],
        default="kmeans",
    )
    parser.add_argument("--show-top-k", type=int, default=200)
    parser.add_argument("--show-cut-off", type=int, default=512)
    args = parser.parse_args()

    num_clusters = args.num_clusters
    show_top_k = args.show_top_k
    show_cut_off = args.show_cut_off

    texts = read_texts(
        args.input_file, args.min_length, args.max_length, args.english_only
    )
    print(f"#text: {len(texts)}")

    embeddings = get_embeddings(texts, args.model, args.batch_size)

    if args.cluster_alg == "kmeans":
        centers, labels = run_k_means(embeddings, num_clusters)
    elif args.cluster_alg == "aggcls":
        centers, labels = run_agg_cluster(embeddings, num_clusters)
    elif args.cluster_alg == "HDBSCAN":
        centers, labels = run_hdbscan_cluster(embeddings)
    else:
        raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}")

    topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k)
    topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
    num_clusters = len(centers)  # HDBSCAN chooses the number of clusters itself

    # Dump results
    filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
    print(topk_str)
    with open(filename_prefix + "_topk.txt", "w") as fout:
        fout.write(topk_str)

    with open(filename_prefix + "_all.txt", "w") as fout:
        for i in range(len(centers)):
            tmp_indices = labels == i
            tmp_embeddings = embeddings[tmp_indices]
            tmp_texts = texts[tmp_indices.numpy()]

            scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
            sorted_indices = torch.flip(torch.argsort(scores), dims=[0])
            for text, score in zip(
                tmp_texts[sorted_indices.numpy()], scores[sorted_indices]
            ):
                obj = {"cluster": i, "text": text, "sim": score.item()}
                fout.write(json.dumps(obj, ensure_ascii=False) + "\n")

    cluster_info = get_cluster_info(texts, labels, topk_indices)
    with open(filename_prefix + "_cluster.pkl", "wb") as fout:
        pickle.dump(cluster_info, fout)
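
# Example (a sketch; the filename assumes the default kmeans run with 20
# clusters): reading the pickled cluster info back in a separate session.
#
#   import pickle
#   with open("results_c20_kmeans_cluster.pkl", "rb") as fin:
#       cluster_info = pickle.load(fin)
#   num_samples, topk_prompts, random_prompts = cluster_info[0]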