diff --git a/train_log_show.py b/train_log_show.py new file mode 100644 index 0000000..b9e3bc9 --- /dev/null +++ b/train_log_show.py @@ -0,0 +1,31 @@ +import json +import matplotlib.pyplot as plt + +# 从文件读取日志数据 +with open("../trainer_log.jsonl", "r") as f: + logs = [json.loads(line) for line in f.readlines()] + +# 提取关键信息 +steps = [log["current_steps"] for log in logs] +loss = [log["loss"] for log in logs] +learning_rate = [log["learning_rate"] for log in logs] + +# 创建一个新的图形 +plt.figure() + +# 画出损失曲线 +plt.subplot(211) +plt.plot(steps, loss, label='Loss') +plt.xlabel('Steps') +plt.ylabel('Loss') +plt.legend() + +# 画出学习率曲线 +plt.subplot(212) +plt.plot(steps, learning_rate, label='Learning Rate') +plt.xlabel('Steps') +plt.ylabel('Learning Rate') +plt.legend() + +# 显示图形 +plt.show()