From c39d758a8aeaaf9c56aac017904d0a0ae2284299 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:23:47 +0800 Subject: [PATCH] feat(logger): add tensorboard key value buffer (#549) * feat(logger): add tensorboard key value buffer * fix --- configs/7B_sft.py | 3 ++ internlm/initialize/launch.py | 3 ++ internlm/utils/writer.py | 52 +++++++++++++++++++++++++++-------- train.py | 2 ++ 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index c0a9bc8..9d7f722 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -176,6 +176,9 @@ monitor = dict( light_monitor_address=None, # light_monitor address to send heartbeat alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", ), + tensorboard=dict( + queue_max_length=10, + ), ) # metric_dtype can be "fp32" or other string diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 491e2b0..7d6badc 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -332,6 +332,9 @@ def args_sanity_check(): "alert_file_path": None, } }, + "tensorboard": { + "queue_max_length": 1, + }, } for key, value in monitor_default_config.items(): diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py index 018917a..776d130 100644 --- a/internlm/utils/writer.py +++ b/internlm/utils/writer.py @@ -114,6 +114,8 @@ class Writer: config: str = None, logger: logging.Logger = None, enable_tb: bool = True, + queue_max_length: int = 1, + total_steps: int = 100, ) -> None: self.enable_tb = enable_tb self.tb_writer, self.tb_logdir = init_tb_writer( @@ -126,21 +128,49 @@ class Writer: config=config, logger=logger, ) + self.queue_max_length = queue_max_length + self.total_steps = total_steps + self.add_scalars_buffer = [] + self.add_scalar_buffer = [] + self.add_scalar_step_counter = 0 + self.add_scalars_step_counter = 0 + self.add_scalar_last_step = -1 + self.add_scalars_last_step = -1 def add_scalar(self, key, value, step): - try: - if self.enable_tb and self.tb_writer is not None: - self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step) - except Exception: - traceback.print_exc() + self.add_scalar_buffer.append((key, value, step)) + if step > self.add_scalar_last_step: + self.add_scalar_step_counter += 1 + self.add_scalar_last_step = step + + if self.add_scalar_step_counter == self.queue_max_length or step >= self.total_steps: + while len(self.add_scalar_buffer) > 0: + key, value, step = self.add_scalar_buffer.pop(0) + try: + if self.enable_tb and self.tb_writer is not None: + self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step) + except Exception: + traceback.print_exc() + + self.add_scalar_step_counter = 0 def add_scalars(self, key, value, step): - try: - assert isinstance(value, dict) - if self.enable_tb and self.tb_writer is not None: - self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step) - except Exception: - traceback.print_exc() + self.add_scalars_buffer.append((key, value, step)) + if step > self.add_scalars_last_step: + self.add_scalars_step_counter += 1 + self.add_scalars_last_step = step + + if self.add_scalars_step_counter == self.queue_max_length or step >= self.total_steps: + while len(self.add_scalars_buffer) > 0: + key, value, step = self.add_scalars_buffer.pop(0) + try: + assert isinstance(value, dict) + if self.enable_tb and self.tb_writer is not None: + self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step) + except Exception: + traceback.print_exc() + + self.add_scalars_step_counter = 0 def add_text(self, key, value, step): try: diff --git a/train.py b/train.py index 6874f9e..9610cc5 100644 --- a/train.py +++ b/train.py @@ -138,6 +138,8 @@ def main(args): config=config_lines, logger=logger, enable_tb=gpc.config.enable_tb, + queue_max_length=gpc.config.tensorboard.queue_max_length, + total_steps=total_steps, ) # initialize metric for calculating accuracy and perplexity