Spaces:
Runtime error
Runtime error
| """ | |
| Clean chatbot arena chat log. | |
| Usage: | |
| python3 clean_chat_data.py --mode conv_release | |
| """ | |
| import argparse | |
| import datetime | |
| import json | |
| import os | |
| from pytz import timezone | |
| import time | |
| from tqdm import tqdm | |
| from fastchat.serve.monitor.basic_stats import NUM_SERVERS | |
| from fastchat.serve.monitor.clean_battle_data import ( | |
| to_openai_format, | |
| replace_model_name, | |
| ) | |
| from fastchat.utils import detect_language | |
# Canonical server-overload message, lower-cased for case-insensitive
# substring matching; conversations containing it are dropped downstream.
NETWORK_ERROR_MSG = (
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
)
def get_log_files(max_num_files=None):
    """Collect existing per-server conversation logs for Apr-Nov 2023.

    Scans ~/fastchat_logs/server{i}/{date}-conv.json for every server and
    every calendar date (days 1-32 are probed; nonexistent dates simply
    have no file). Returns the last `max_num_files` paths found, or all
    of them when `max_num_files` is falsy.
    """
    filenames = []
    for month in range(4, 12):
        for day in range(1, 33):
            date_str = f"2023-{month:02d}-{day:02d}"
            for server_idx in range(NUM_SERVERS):
                path = os.path.expanduser(
                    f"~/fastchat_logs/server{server_idx}/{date_str}-conv.json"
                )
                if os.path.exists(path):
                    filenames.append(path)
    limit = max_num_files or len(filenames)
    return filenames[-limit:]
def clean_chat_data(log_files, action_type):
    """Read raw log files and return cleaned, deduplicated chat records.

    Args:
        log_files: paths to JSONL conversation logs (one JSON object per line).
        action_type: row type to keep ("chat", "upvote", "downvote",
            "leftvote", or "rightvote").

    Returns:
        A list of chat dicts sorted by timestamp, deduplicated on
        conversation_id (the latest record wins for each id).
    """
    raw_data = []
    for filename in tqdm(log_files, desc="read files"):
        # Logs may still be syncing when we read them; retry briefly.
        lines = None
        for _ in range(5):
            try:
                lines = open(filename).readlines()
                break
            except FileNotFoundError:
                time.sleep(2)
        if lines is None:
            # Bug fix: previously a persistently missing file either raised
            # NameError (first iteration) or silently reused the previous
            # file's `lines`, duplicating its rows.
            print(f"skip unreadable file: {filename}")
            continue
        for l in lines:
            row = json.loads(l)
            if row["type"] == action_type:
                raw_data.append(row)

    all_models = set()
    all_ips = dict()
    chats = []
    ct_invalid_conv_id = 0
    ct_invalid = 0
    ct_network_error = 0
    for row in raw_data:
        try:
            # Single-model actions carry one state; battle votes carry two.
            if action_type in ["chat", "upvote", "downvote"]:
                state = row["state"]
                model = row["model"]
            elif action_type == "leftvote":
                state = row["states"][0]
                model = row["states"][0]["model_name"]
            elif action_type == "rightvote":
                state = row["states"][1]
                model = row["states"][1]["model_name"]
            conversation_id = state["conv_id"]
        except KeyError:
            ct_invalid_conv_id += 1
            continue

        if conversation_id is None:
            ct_invalid_conv_id += 1
            continue

        conversation = to_openai_format(state["messages"][state["offset"] :])
        if not isinstance(model, str):
            ct_invalid += 1
            continue
        model = replace_model_name(model)

        try:
            # Detect language from the first user message of the conversation.
            lang_code = detect_language(state["messages"][state["offset"]][1])
        except IndexError:
            ct_invalid += 1
            continue

        if not all(isinstance(x["content"], str) for x in conversation):
            ct_invalid += 1
            continue

        messages = "".join([x["content"] for x in conversation]).lower()
        if NETWORK_ERROR_MSG in messages:
            ct_network_error += 1
            continue

        # Pseudonymize: map each IP to a small sequential user id.
        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = len(all_ips)
        user_id = all_ips[ip]

        chats.append(
            dict(
                conversation_id=conversation_id,
                model=model,
                conversation=conversation,
                turn=len(conversation) // 2,
                language=lang_code,
                user_id=user_id,
                tstamp=row["tstamp"],
            )
        )
        all_models.update([model])

    chats.sort(key=lambda x: x["tstamp"])
    if chats:
        last_updated_tstamp = chats[-1]["tstamp"]
        last_updated_datetime = datetime.datetime.fromtimestamp(
            last_updated_tstamp, tz=timezone("US/Pacific")
        ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        # Bug fix: avoid IndexError when no valid chats survive filtering.
        last_updated_datetime = "N/A"

    # Deduplicate on conversation_id, keeping the newest record for each.
    dedup_chats = []
    visited_conv_ids = set()
    for i in reversed(range(len(chats))):
        if chats[i]["conversation_id"] in visited_conv_ids:
            continue
        visited_conv_ids.add(chats[i]["conversation_id"])
        dedup_chats.append(chats[i])

    print(
        f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
    )
    print(
        f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
    )
    print(f"#models: {len(all_models)}, {all_models}")
    print(f"last-updated: {last_updated_datetime}")

    return list(reversed(dedup_chats))
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--action-type", type=str, default="chat") | |
| parser.add_argument("--max-num-files", type=int) | |
| args = parser.parse_args() | |
| log_files = get_log_files(args.max_num_files) | |
| chats = clean_chat_data(log_files, args.action_type) | |
| last_updated_tstamp = chats[-1]["tstamp"] | |
| cutoff_date = datetime.datetime.fromtimestamp( | |
| last_updated_tstamp, tz=timezone("US/Pacific") | |
| ).strftime("%Y%m%d") | |
| output = f"clean_{args.action_type}_conv_{cutoff_date}.json" | |
| with open(output, "w") as fout: | |
| json.dump(chats, fout, indent=2, ensure_ascii=False) | |
| print(f"Write cleaned data to {output}") | |