File size: 2,431 Bytes
8274db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import json
import matplotlib.pyplot as plt

def load_roc_data(roc_files):
    """
    加载多个模型的ROC数据
    参数:
        roc_files (list): 包含多个模型的overall_roc_data.json文件路径的列表
    返回:
        roc_data_list (list): 包含每个模型的ROC数据字典的列表
    """
    roc_data_list = []
    for roc_file in roc_files:
        with open(roc_file, "r") as f:
            roc_data = json.load(f)
            roc_data_list.append(roc_data)
    return roc_data_list

def plot_combined_roc(roc_data_list, model_names, output_path):
    """
    绘制多个模型的ROC曲线到同一张图中
    参数:
        roc_data_list (list): 包含每个模型的ROC数据字典的列表
        model_names (list): 每个模型的名称列表
        output_path (str): 保存ROC曲线图的路径
    """
    plt.figure(figsize=(10, 8))
    
    for roc_data, model_name in zip(roc_data_list, model_names):
        fpr = roc_data["fpr"]
        tpr = roc_data["tpr"]
        auc_value = roc_data["auc"]
        plt.plot(fpr, tpr, label=f"{model_name} (AUC = {auc_value:.4f})")
    
    # 绘制对角线
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
    
    # 图形设置
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
    plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    plt.title('Combined ROC Curves for Multiple Models', fontsize=14)
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    
    # 保存图像
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"ROC曲线图已保存至: {output_path}")

def main():
    # 定义存放多个模型ROC数据的目录
    roc_data_dir = "./roc_result"  # 替换为实际路径
    output_path = "./combined_roc_curve.png"  # 保存最终ROC曲线图的路径
    
    # 获取所有模型的overall_roc_data.json文件路径
    roc_files = [os.path.join(roc_data_dir, f) for f in os.listdir(roc_data_dir)]
    
    # 模型名称(从文件名提取)
    model_names = [os.path.basename(f).replace(".json", "") for f in roc_files]
    
    # 加载所有模型的ROC数据
    roc_data_list = load_roc_data(roc_files)
    
    # 绘制并保存组合ROC曲线
    plot_combined_roc(roc_data_list, model_names, output_path)

if __name__ == "__main__":
    main()