Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import typing as tp | |
| import treetable as tt | |
| from .._base_explorers import BaseExplorer | |
class LMExplorer(BaseExplorer):
    """Explorer that reports 'train' and 'valid' stage metrics in the grid table."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages handled by this explorer."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        train_leaves = [
            tt.leaf('epoch'),
            tt.leaf('duration', '.1f'),  # duration in minutes
            tt.leaf('ping'),
            tt.leaf('ce', '.4f'),  # cross entropy
            tt.leaf("ppl", '.3f'),  # perplexity
        ]
        train_group = tt.group('train', train_leaves, align='>')

        valid_leaves = [
            tt.leaf('ce', '.4f'),
            tt.leaf('ppl', '.3f'),
            tt.leaf('best_ppl', '.3f'),
        ]
        valid_group = tt.group('valid', valid_leaves, align='>')

        return [train_group, valid_group]

    def process_sheep(self, sheep, history):
        """Add the best validation metrics seen in `history` to the processed parts.

        The parent implementation is called first to produce `parts`. Every
        entry of `history` that has a 'valid' section is then scanned for the
        tracked metrics, and the best value of each is stored under
        'best_<metric>' in `parts['valid']` (only when `parts` has a 'valid'
        section). Having the best value in the table makes it convenient to
        compare runs in the grid.

        Returns:
            The `parts` mapping returned by the parent, possibly updated.
        """
        parts = super().process_sheep(sheep, history)

        # Direction in which each tracked metric improves;
        # values should be in ['lower', 'higher'].
        track_by = {'ppl': 'lower'}

        # Seed each metric with the worst possible value for its direction.
        best_metrics = {}
        for name, mode in track_by.items():
            sign = 1 if mode == 'lower' else -1
            best_metrics[name] = sign * float('inf')

        def comparator(mode, a, b):
            # True when `a` is strictly better than `b` for the given mode.
            if mode == 'lower':
                return a < b
            return a > b

        for metrics in history:
            # Only the validation stage contributes to the best metrics.
            if 'valid' not in metrics:
                continue
            sub = metrics['valid']
            for name in track_by:
                if name not in sub:
                    continue
                candidate = sub[name]
                if comparator(track_by[name], candidate, best_metrics[name]):
                    best_metrics[name] = candidate

        if 'valid' in parts:
            best_entries = {}
            for name, value in best_metrics.items():
                best_entries[f'best_{name}'] = value
            parts['valid'].update(best_entries)
        return parts
class GenerationEvalExplorer(BaseExplorer):
    """Explorer that reports the metrics of the 'evaluate' stage in the grid table."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages handled by this explorer."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # (metric name, format spec) pairs; None means no explicit format is passed.
        columns = [
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping', None),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        ]
        leaves = []
        for name, fmt in columns:
            if fmt is None:
                leaves.append(tt.leaf(name))
            else:
                leaves.append(tt.leaf(name, fmt))
        return [tt.group('evaluate', leaves, align='>')]