| """ | |
| Usage: | |
| python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] | |
| """ | |
import argparse
from concurrent.futures import ThreadPoolExecutor
import json

import numpy as np
from tqdm import tqdm

from fastchat.llm_judge.common import (
    load_questions,
    load_model_answers,
    load_judge_prompts,
    check_data,
    play_a_match_pair,
    play_a_match_single,
    get_model_list,
    Judge,
    MatchPair,
    MatchSingle,
    NEED_REF_CATS,
)

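# Build pairwise matches that compare each model in `models` against the
# baseline model (used by --mode pairwise-baseline).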
def make_match(
    questions,
    models,
    model_answers,
    judge,
    baseline_model,
    ref_answers=None,
    multi_turn=False,
):
    matches = []
    for q in questions:
        if multi_turn and len(q["turns"]) != 2:
            continue
        for i in range(len(models)):
            q_id = q["question_id"]
            m_1 = models[i]
            m_2 = baseline_model
            if m_1 == m_2:
                continue
            a_1 = model_answers[m_1][q_id]
            a_2 = model_answers[baseline_model][q_id]
            if ref_answers is not None:
                ref = ref_answers[judge.model_name][q_id]
                match = MatchPair(
                    dict(q),
                    m_1,
                    m_2,
                    a_1,
                    a_2,
                    judge,
                    ref_answer=ref,
                    multi_turn=multi_turn,
                )
            else:
                match = MatchPair(
                    dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
                )
            matches.append(match)
    return matches

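# Build pairwise matches between every pair of models in `models`
# (used by --mode pairwise-all); the baseline argument is unused here.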
def make_match_all_pairs(
    questions,
    models,
    model_answers,
    judge,
    baseline_model=None,
    ref_answers=None,
    multi_turn=False,
):
    matches = []
    for q in questions:
        if multi_turn and len(q["turns"]) != 2:
            continue
        for i in range(len(models)):
            for j in range(i + 1, len(models)):
                q_id = q["question_id"]
                m_1 = models[i]
                m_2 = models[j]
                a_1 = model_answers[m_1][q_id]
                a_2 = model_answers[m_2][q_id]
                if ref_answers is not None:
                    ref = ref_answers[judge.model_name][q_id]
                    match = MatchPair(
                        dict(q),
                        m_1,
                        m_2,
                        a_1,
                        a_2,
                        judge,
                        ref_answer=ref,
                        multi_turn=multi_turn,
                    )
                else:
                    match = MatchPair(
                        dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
                    )
                matches.append(match)
    return matches

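# Build single-answer grading matches, one per (question, model) pair
# (used by --mode single).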
def make_match_single(
    questions,
    models,
    model_answers,
    judge,
    baseline_model=None,
    ref_answers=None,
    multi_turn=False,
):
    matches = []
    for q in questions:
        if multi_turn and len(q["turns"]) != 2:
            continue
        for i in range(len(models)):
            q_id = q["question_id"]
            m = models[i]
            a = model_answers[m][q_id]
            if ref_answers is not None:
                ref = ref_answers[judge.model_name][q_id]
                matches.append(
                    MatchSingle(
                        dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn
                    )
                )
            else:
                matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn))
    return matches

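# Judges for the pairwise modes: a default and a math prompt, each with a
# single-turn and a multi-turn variant; the math judges are reference-based.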
def make_judge_pairwise(judge_model, judge_prompts):
    judges = {}
    judges["default"] = Judge(judge_model, judge_prompts["pair-v2"])
    judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True)
    judges["default-mt"] = Judge(
        judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True
    )
    judges["math-mt"] = Judge(
        judge_model,
        judge_prompts["pair-math-v1-multi-turn"],
        ref_based=True,
        multi_turn=True,
    )
    return judges

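# Judges for single-answer grading, mirroring the pairwise variants above.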
def make_judge_single(judge_model, judge_prompts):
    judges = {}
    judges["default"] = Judge(judge_model, judge_prompts["single-v1"])
    judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True)
    judges["default-mt"] = Judge(
        judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
    )
    judges["math-mt"] = Judge(
        judge_model,
        judge_prompts["single-math-v1-multi-turn"],
        ref_based=True,
        multi_turn=True,
    )
    return judges

| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--bench-name", | |
| type=str, | |
| default="mt_bench", | |
| help="The name of the benchmark question set.", | |
| ) | |
| parser.add_argument( | |
| "--judge-file", | |
| type=str, | |
| default="data/judge_prompts.jsonl", | |
| help="The file of judge prompts.", | |
| ) | |
| parser.add_argument("--judge-model", type=str, default="gpt-4") | |
| parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") | |
| parser.add_argument( | |
| "--mode", | |
| type=str, | |
| default="single", | |
| choices=["pairwise-baseline", "pairwise-all", "single"], | |
| help=( | |
| "Evaluation mode. " | |
| "`pairwise-baseline` runs pairwise comparision against a baseline. " | |
| "`pairwise-all` runs pairwise comparision between all pairs. " | |
| "`single` runs single answer grading." | |
| ), | |
| ) | |
| parser.add_argument( | |
| "--model-list", | |
| type=str, | |
| nargs="+", | |
| default=None, | |
| help="A list of models to be evaluated", | |
| ) | |
| parser.add_argument( | |
| "--parallel", type=int, default=1, help="The number of concurrent API calls." | |
| ) | |
| parser.add_argument( | |
| "--first-n", type=int, help="A debug option. Only run the first `n` judgments." | |
| ) | |
| args = parser.parse_args() | |
| question_file = f"data/{args.bench_name}/question.jsonl" | |
| answer_dir = f"data/{args.bench_name}/model_answer" | |
| ref_answer_dir = f"data/{args.bench_name}/reference_answer" | |
| # Load questions | |
| questions = load_questions(question_file, None, None) | |
| # Load answers | |
| model_answers = load_model_answers(answer_dir) | |
| ref_answers = load_model_answers(ref_answer_dir) | |
| # Load judge | |
| judge_prompts = load_judge_prompts(args.judge_file) | |
| if args.first_n: | |
| questions = questions[: args.first_n] | |
| if args.model_list is None: | |
| models = get_model_list(answer_dir) | |
| else: | |
| models = args.model_list | |
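    # Select the judge prompts, match builder, match player, and output file
    # according to the evaluation mode.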
| if args.mode == "single": | |
| judges = make_judge_single(args.judge_model, judge_prompts) | |
| play_a_match_func = play_a_match_single | |
| output_file = ( | |
| f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" | |
| ) | |
| make_match_func = make_match_single | |
| baseline_model = None | |
| else: | |
| judges = make_judge_pairwise(args.judge_model, judge_prompts) | |
| play_a_match_func = play_a_match_pair | |
| output_file = ( | |
| f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" | |
| ) | |
| if args.mode == "pairwise-all": | |
| make_match_func = make_match_all_pairs | |
| baseline_model = None | |
| else: | |
| make_match_func = make_match | |
| baseline_model = args.baseline_model | |
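    # Questions in categories listed in NEED_REF_CATS are graded against a
    # reference answer; all other questions use the default judge prompts.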
    check_data(questions, model_answers, ref_answers, models, judges)

    question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
    question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]

    # Make matches
    matches = []
    matches += make_match_func(
        question_default, models, model_answers, judges["default"], baseline_model
    )
    matches += make_match_func(
        question_math,
        models,
        model_answers,
        judges["math"],
        baseline_model,
        ref_answers,
    )
    matches += make_match_func(
        question_default,
        models,
        model_answers,
        judges["default-mt"],
        baseline_model,
        multi_turn=True,
    )
    matches += make_match_func(
        question_math,
        models,
        model_answers,
        judges["math-mt"],
        baseline_model,
        ref_answers,
        multi_turn=True,
    )

    match_stat = {}
    match_stat["bench_name"] = args.bench_name
    match_stat["mode"] = args.mode
    match_stat["judge"] = args.judge_model
    match_stat["baseline"] = baseline_model
    match_stat["model_list"] = models
    match_stat["total_num_questions"] = len(questions)
    match_stat["total_num_matches"] = len(matches)
    match_stat["output_path"] = output_file

    # Show match stats and prompt enter to continue
    print("Stats:")
    print(json.dumps(match_stat, indent=4))
    input("Press Enter to confirm...")

    # Play matches
    if args.parallel == 1:
        for match in tqdm(matches):
            play_a_match_func(match, output_file=output_file)
    else:

        def play_a_match_wrapper(match):
            play_a_match_func(match, output_file=output_file)

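        # Randomize the match order before dispatching to the thread pool; the
        # fixed seed keeps the shuffled order reproducible across runs.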
        np.random.seed(0)
        np.random.shuffle(matches)

        with ThreadPoolExecutor(args.parallel) as executor:
            for match in tqdm(
                executor.map(play_a_match_wrapper, matches), total=len(matches)
            ):
                pass