Spaces:
Runtime error
Runtime error
| # log 画图 | |
| from datetime import datetime | |
| import numpy as np | |
| import pandas as pd | |
| from matplotlib import pyplot as plt | |
| import sys | |
| sys.path.extend(['.', '..']) | |
| from config import PROJECT_ROOT | |
def str_to_timestamp(string: str) -> float:
    '''
    Convert a log date string into a POSIX timestamp (seconds as a float).

    The expected layout is 'YYYY-MM-DD HH:MM:SS.fff'. Square brackets around
    the value (as written in the log files, e.g. '[2023-10-01 08:44:39.303]')
    and surrounding whitespace are ignored. The fractional-second part is
    optional.

    The string carries no time zone, so the resulting naive datetime is
    interpreted in the local time zone when converted to a timestamp.

    Raises:
        ValueError: if the string does not match the expected layout.
    '''
    cleaned = string.replace('[', '').replace(']', '').strip()
    # '%f' does not match an absent fraction, so choose the format from the input.
    date_fmt = '%Y-%m-%d %H:%M:%S.%f' if '.' in cleaned else '%Y-%m-%d %H:%M:%S'
    # Convert to a timestamp
    return datetime.strptime(cleaned, date_fmt).timestamp()
def _read_loss_records(log_file: str, start_timestamp: float, end_timestamp: float) -> list:
    '''
    Collect [epoch, step, loss, device] records from the training-loss lines
    of `log_file` whose timestamp lies within [start_timestamp, end_timestamp].
    '''
    records = []
    with open(log_file, 'r', encoding='utf-8') as f:
        for line in f:
            if 'training loss: epoch:' not in line:
                continue

            # Drop the line ending up front so the last field stays intact even
            # when the final line of the file has no trailing newline.
            fields = line.rstrip('\r\n').split(' ')

            try:
                # The first two space-separated fields are the date and the time.
                timestamp = str_to_timestamp(' '.join(fields[0: 2]))
            except ValueError:
                # Marker text on a line without a parsable timestamp: skip it.
                continue

            if timestamp < start_timestamp:
                continue
            if timestamp > end_timestamp:
                # Nothing after this point is read.
                break

            if len(fields) != 9:
                continue

            try:
                loss = float(fields[7][5:].rstrip(','))    # 'loss:0.11086619377136231,'
            except ValueError:
                # Unparsable loss value: skip the line instead of aborting the plot.
                continue

            epoch = fields[5][6:].rstrip(',')    # 'epoch:0,'
            step = fields[6][5:].rstrip(',')     # 'step:0,'
            device = fields[8][7:]               # text after the 7-character field prefix
            records.append([epoch, step, loss, device])

    return records


def plot_traing_loss(log_file: str, start_date: str, end_date: str, pic_save_to_file: str=None) -> None:
    '''
    Plot the training loss recorded in a log file, together with a polynomial
    fit of the curve, and optionally save the figure to a file.

    The log contains a lot of other output, so the start and end time of the
    loss records to plot must be given. Only lines containing
    'training loss: epoch:' are considered. Reading stops at the first such
    line whose timestamp is later than `end_date`, so the records are expected
    to be in chronological order.

    Args:
        log_file: path of the log file to read (UTF-8).
        start_date: start of the time window, with or without square brackets.
        end_date: end of the time window, with or without square brackets.
        pic_save_to_file: if not None, the figure is saved to this path before
            it is shown.

    Raises:
        ValueError: if `start_date` / `end_date` cannot be parsed, or if no
            loss record is found inside the time window.

    example:
    >>> plot_traing_loss('./logs/trainer.log', '[2023-10-01 08:44:39.303]', '[2023-10-01 11:29:12.376]')
    >>> plot_traing_loss('./logs/trainer.log', '2023-10-01 08:44:39.303', '2023-10-01 11:29:12.376')
    '''
    start_timestamp = str_to_timestamp(start_date)
    end_timestamp = str_to_timestamp(end_date)

    loss_list = _read_loss_records(log_file, start_timestamp, end_timestamp)
    if not loss_list:
        # Fail with a clear message instead of an obscure error from np.polyfit.
        raise ValueError(
            'no training loss records found in {} between {} and {}'.format(log_file, start_date, end_date)
        )

    df = pd.DataFrame(loss_list, columns=['epoch', 'step', 'loss', 'device'])
    n_points = len(df['loss'])

    # Polynomial fit (cubic; the degree is lowered when there are fewer than
    # four points so that the fit stays well-determined).
    x = np.arange(n_points)
    x_range = np.arange(0, n_points, step=0.005)
    degree = min(3, n_points - 1)
    p1d = np.poly1d(np.polyfit(x, df['loss'], degree))
    y_fit = p1d(x_range)

    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(df['loss'], 'g', label='loss')
    plt.plot(x_range, y_fit, 'r', label='fit loss')
    plt.ylabel('loss')
    plt.xlabel('sampling step')
    plt.legend()    # show the legend for the two labelled curves

    if pic_save_to_file is not None:
        plt.savefig(pic_save_to_file)

    plt.show()
if __name__ == '__main__':
    # Example invocations, left commented out: each call reads one log file under
    # PROJECT_ROOT/logs for the given time window and saves the figure under
    # PROJECT_ROOT/img. Uncomment and adjust the paths / time windows to use them.
    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231011.log', '[2023-10-11 11:04:53.960]', '[2023-10-18 01:41:40.540]', pic_save_to_file=PROJECT_ROOT + '/img/train_loss.png')
    # plot_traing_loss(PROJECT_ROOT + '/logs/chat_trainer-20231018.log', '[2023-10-18 02:06:28.137]', '[2023-10-18 18:03:35.230]', pic_save_to_file=PROJECT_ROOT + '/img/finetune_loss.png')
    pass