Spaces:
Sleeping
Sleeping
Update src/vis_utils.py
Browse files- src/vis_utils.py +14 -7
src/vis_utils.py
CHANGED
|
@@ -183,6 +183,8 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
|
|
| 183 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
| 184 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
|
| 185 |
|
|
|
|
|
|
|
| 186 |
# Set up the plot
|
| 187 |
sns.set(rc={'figure.figsize': (13.7, 18.27)})
|
| 188 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
@@ -214,23 +216,25 @@ def plot_family_results(methods_selected, dataset, metric, family_path="/tmp/fam
|
|
| 214 |
|
| 215 |
return filename
|
| 216 |
|
| 217 |
-
def plot_affinity_results(
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# Filter for selected methods
|
| 222 |
df = df[df['Method'].isin(method_names)]
|
| 223 |
|
| 224 |
# Gather columns related to the specified metric and validate
|
| 225 |
metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
|
| 226 |
-
if not metric_columns:
|
| 227 |
-
print(f"No columns found for metric '{metric}'.")
|
| 228 |
-
return None
|
| 229 |
|
| 230 |
# Reshape data for plotting
|
| 231 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
| 232 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
|
| 233 |
|
|
|
|
|
|
|
| 234 |
# Set up the plot
|
| 235 |
sns.set(rc={'figure.figsize': (13.7, 8.27)})
|
| 236 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
@@ -246,12 +250,15 @@ def plot_affinity_results(file_path, method_names, metric, save_path="./plot_ima
|
|
| 246 |
ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
|
| 247 |
|
| 248 |
# Apply custom color settings to y-axis labels
|
| 249 |
-
|
|
|
|
|
|
|
| 250 |
|
| 251 |
# Ensure save path exists
|
| 252 |
os.makedirs(save_path, exist_ok=True)
|
| 253 |
|
| 254 |
# Save the plot
|
|
|
|
| 255 |
filename = os.path.join(save_path, f"{metric}_affinity_results.png")
|
| 256 |
ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
|
| 257 |
plt.close() # Close the plot to free memory
|
|
|
|
| 183 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
| 184 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index
|
| 185 |
|
| 186 |
+
df = df.fillna(0)
|
| 187 |
+
|
| 188 |
# Set up the plot
|
| 189 |
sns.set(rc={'figure.figsize': (13.7, 18.27)})
|
| 190 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
|
|
| 216 |
|
| 217 |
return filename
|
| 218 |
|
| 219 |
+
def plot_affinity_results(method_names, metric, affinity_path="/tmp/affinity_results.csv"):
|
| 220 |
+
if not os.path.exists(affinity_path):
|
| 221 |
+
benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
|
| 222 |
+
download_from_hub(benchmark_types)
|
| 223 |
+
|
| 224 |
+
df = pd.read_csv(affinity_path)
|
| 225 |
|
| 226 |
# Filter for selected methods
|
| 227 |
df = df[df['Method'].isin(method_names)]
|
| 228 |
|
| 229 |
# Gather columns related to the specified metric and validate
|
| 230 |
metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
# Reshape data for plotting
|
| 233 |
df_long = pd.melt(df[['Method'] + metric_columns], id_vars=['Method'], var_name='Fold', value_name='Value')
|
| 234 |
df_long['Fold'] = df_long['Fold'].apply(lambda x: int(x.split('_')[-1])) # Extract fold index for sorting
|
| 235 |
|
| 236 |
+
df = df.fillna(0)
|
| 237 |
+
|
| 238 |
# Set up the plot
|
| 239 |
sns.set(rc={'figure.figsize': (13.7, 8.27)})
|
| 240 |
sns.set_theme(style="whitegrid", color_codes=True)
|
|
|
|
| 250 |
ax.grid(b=True, which='minor', color='whitesmoke', linewidth=0.5)
|
| 251 |
|
| 252 |
# Apply custom color settings to y-axis labels
|
| 253 |
+
for label in ax.get_yticklabels():
|
| 254 |
+
method = label.get_text()
|
| 255 |
+
label.set_color(get_method_color(method))
|
| 256 |
|
| 257 |
# Ensure save path exists
|
| 258 |
os.makedirs(save_path, exist_ok=True)
|
| 259 |
|
| 260 |
# Save the plot
|
| 261 |
+
save_path = "/tmp"
|
| 262 |
filename = os.path.join(save_path, f"{metric}_affinity_results.png")
|
| 263 |
ax.get_figure().savefig(filename, dpi=400, bbox_inches='tight')
|
| 264 |
plt.close() # Close the plot to free memory
|