| """ | |
| Attack Logs to WandB | |
| ======================== | |
| """ | |
| from textattack.shared.utils import LazyLoader, html_table_from_rows | |
| from .logger import Logger | |
class WeightsAndBiasesLogger(Logger):
    """Logs attack results to Weights & Biases."""

    def __init__(self, **kwargs):
        """Start a W&B run and set up the accumulated results table.

        Args:
            **kwargs: Forwarded unchanged to ``wandb.init``. They are also
                stored on the instance so ``__setstate__`` can call
                ``wandb.init`` again with the same settings.
        """
        # Bind the module-level name ``wandb`` through LazyLoader so the other
        # methods can refer to it as a global.
        # NOTE(review): presumably this defers the real import until first
        # attribute access -- confirm against LazyLoader.
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        wandb.init(**kwargs)
        self.kwargs = kwargs
        self.project_name = wandb.run.project_name()
        # One row per logged attack result; re-rendered as a single HTML table
        # after every new result (see ``_log_result_table``).
        self._result_table_rows = []

    def __setstate__(self, state):
        """Restore instance state and re-initialise the W&B run.

        Args:
            state: Attribute dictionary that becomes ``self.__dict__``.
        """
        global wandb
        wandb = LazyLoader("wandb", globals(), "wandb")

        self.__dict__ = state
        # Build a single kwargs dict so a ``resume`` entry stored in
        # ``self.kwargs`` cannot collide with the explicit ``resume=True``
        # (passing both as keywords raises TypeError).
        init_kwargs = dict(self.kwargs)
        init_kwargs["resume"] = True
        wandb.init(**init_kwargs)

    def log_summary_rows(self, rows, title, window_id):
        """Log summary metrics as a W&B table and as run-summary entries.

        Each row is a ``(metric_name, metric_score)`` pair. String scores have
        every ``%`` removed and are converted to ``float``; other scores are
        logged as given.

        Args:
            rows: Iterable of two-element rows. The rows are not modified.
            title: Unused by this logger.
            window_id: Unused by this logger.

        Raises:
            ValueError: If a string score cannot be converted to ``float``.
        """
        table = wandb.Table(columns=["Attack Results", ""])
        for row in rows:
            metric_name, metric_score = row
            if isinstance(metric_score, str):
                # Convert into a local instead of writing back into ``row``:
                # the caller's list may be reused elsewhere, and rows may be
                # immutable tuples.
                try:
                    metric_score = float(metric_score.replace("%", ""))
                except ValueError as err:
                    raise ValueError(
                        f'Unable to convert row value "{metric_score}" for Attack Result "{metric_name}" into float'
                    ) from err
            table.add_data(metric_name, metric_score)
            wandb.run.summary[metric_name] = metric_score
        wandb.log({"attack_params": table})

    def _log_result_table(self):
        """Weights & Biases doesn't have a feature to automatically aggregate
        results across timesteps and display the full table.

        Therefore, we have to do it manually: the complete list of accumulated
        rows is rendered to HTML and logged under the ``"results"`` key.
        """
        result_table = html_table_from_rows(
            self._result_table_rows, header=["", "Original Input", "Perturbed Input"]
        )
        wandb.log({"results": wandb.Html(result_table)})

    def log_attack_result(self, result):
        """Log one attack result and refresh the accumulated results table.

        Args:
            result: Attack result object. This method uses
                ``result.diff_color(color_method="html")``,
                ``result.original_result.output`` and
                ``result.perturbed_result.output``.
        """
        original_text_colored, perturbed_text_colored = result.diff_color(
            color_method="html"
        )
        # Taken before the append, so result numbering starts at 0.
        result_num = len(self._result_table_rows)
        self._result_table_rows.append(
            [
                f"<b>Result {result_num}</b>",
                original_text_colored,
                perturbed_text_colored,
            ]
        )
        result_diff_table = html_table_from_rows(
            [[original_text_colored, perturbed_text_colored]]
        )
        result_diff_table = wandb.Html(result_diff_table)
        wandb.log(
            {
                "result": result_diff_table,
                "original_output": result.original_result.output,
                "perturbed_output": result.perturbed_result.output,
            }
        )
        self._log_result_table()

    def log_sep(self):
        """Write a 90-dash separator line to ``self.fout`` when it exists.

        Nothing in this class assigns ``fout``, so the write only happens if
        the attribute has been provided from elsewhere; otherwise this is a
        no-op instead of raising ``AttributeError``.
        """
        fout = getattr(self, "fout", None)
        if fout is not None:
            fout.write("-" * 90 + "\n")