Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import matplotlib.pyplot as plt | |
| def load_roc_data(roc_files): | |
| """ | |
| 加载多个模型的ROC数据 | |
| 参数: | |
| roc_files (list): 包含多个模型的overall_roc_data.json文件路径的列表 | |
| 返回: | |
| roc_data_list (list): 包含每个模型的ROC数据字典的列表 | |
| """ | |
| roc_data_list = [] | |
| for roc_file in roc_files: | |
| with open(roc_file, "r") as f: | |
| roc_data = json.load(f) | |
| roc_data_list.append(roc_data) | |
| return roc_data_list | |
| def plot_combined_roc(roc_data_list, model_names, output_path): | |
| """ | |
| 绘制多个模型的ROC曲线到同一张图中 | |
| 参数: | |
| roc_data_list (list): 包含每个模型的ROC数据字典的列表 | |
| model_names (list): 每个模型的名称列表 | |
| output_path (str): 保存ROC曲线图的路径 | |
| """ | |
| plt.figure(figsize=(10, 8)) | |
| for roc_data, model_name in zip(roc_data_list, model_names): | |
| fpr = roc_data["fpr"] | |
| tpr = roc_data["tpr"] | |
| auc_value = roc_data["auc"] | |
| plt.plot(fpr, tpr, label=f"{model_name} (AUC = {auc_value:.4f})") | |
| # 绘制对角线 | |
| plt.plot([0, 1], [0, 1], 'k--', label="Random Guess") | |
| # 图形设置 | |
| plt.xlim([0.0, 1.0]) | |
| plt.ylim([0.0, 1.05]) | |
| plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12) | |
| plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12) | |
| plt.title('Combined ROC Curves for Multiple Models', fontsize=14) | |
| plt.legend(loc="lower right", fontsize=10) | |
| plt.grid(alpha=0.3) | |
| plt.tight_layout() | |
| # 保存图像 | |
| plt.savefig(output_path, dpi=300, bbox_inches='tight') | |
| print(f"ROC曲线图已保存至: {output_path}") | |
| def main(): | |
| # 定义存放多个模型ROC数据的目录 | |
| roc_data_dir = "./roc_result" # 替换为实际路径 | |
| output_path = "./combined_roc_curve.png" # 保存最终ROC曲线图的路径 | |
| # 获取所有模型的overall_roc_data.json文件路径 | |
| roc_files = [os.path.join(roc_data_dir, f) for f in os.listdir(roc_data_dir)] | |
| # 模型名称(从文件名提取) | |
| model_names = [os.path.basename(f).replace(".json", "") for f in roc_files] | |
| # 加载所有模型的ROC数据 | |
| roc_data_list = load_roc_data(roc_files) | |
| # 绘制并保存组合ROC曲线 | |
| plot_combined_roc(roc_data_list, model_names, output_path) | |
| if __name__ == "__main__": | |
| main() |