feat(logger): add tensorboard key value buffer (#549)

* feat(logger): add tensorboard key value buffer

* fix
pull/570/head
Guoteng 2023-12-29 16:23:47 +08:00 committed by GitHub
parent d418eba094
commit c39d758a8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 11 deletions

View File

@ -176,6 +176,9 @@ monitor = dict(
light_monitor_address=None, # light_monitor address to send heartbeat
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
),
tensorboard=dict(
queue_max_length=10,
),
)
# metric_dtype can be "fp32" or other string

View File

@ -332,6 +332,9 @@ def args_sanity_check():
"alert_file_path": None,
}
},
"tensorboard": {
"queue_max_length": 1,
},
}
for key, value in monitor_default_config.items():

View File

@ -114,6 +114,8 @@ class Writer:
config: str = None,
logger: logging.Logger = None,
enable_tb: bool = True,
queue_max_length: int = 1,
total_steps: int = 100,
) -> None:
self.enable_tb = enable_tb
self.tb_writer, self.tb_logdir = init_tb_writer(
@ -126,21 +128,49 @@ class Writer:
config=config,
logger=logger,
)
self.queue_max_length = queue_max_length
self.total_steps = total_steps
self.add_scalars_buffer = []
self.add_scalar_buffer = []
self.add_scalar_step_counter = 0
self.add_scalars_step_counter = 0
self.add_scalar_last_step = -1
self.add_scalars_last_step = -1
def add_scalar(self, key, value, step):
try:
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
except Exception:
traceback.print_exc()
self.add_scalar_buffer.append((key, value, step))
if step > self.add_scalar_last_step:
self.add_scalar_step_counter += 1
self.add_scalar_last_step = step
if self.add_scalar_step_counter == self.queue_max_length or step >= self.total_steps:
while len(self.add_scalar_buffer) > 0:
key, value, step = self.add_scalar_buffer.pop(0)
try:
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
except Exception:
traceback.print_exc()
self.add_scalar_step_counter = 0
def add_scalars(self, key, value, step):
try:
assert isinstance(value, dict)
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
except Exception:
traceback.print_exc()
self.add_scalars_buffer.append((key, value, step))
if step > self.add_scalars_last_step:
self.add_scalars_step_counter += 1
self.add_scalars_last_step = step
if self.add_scalars_step_counter == self.queue_max_length or step >= self.total_steps:
while len(self.add_scalars_buffer) > 0:
key, value, step = self.add_scalars_buffer.pop(0)
try:
assert isinstance(value, dict)
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
except Exception:
traceback.print_exc()
self.add_scalars_step_counter = 0
def add_text(self, key, value, step):
try:

View File

@ -138,6 +138,8 @@ def main(args):
config=config_lines,
logger=logger,
enable_tb=gpc.config.enable_tb,
queue_max_length=gpc.config.tensorboard.queue_max_length,
total_steps=total_steps,
)
# initialize metric for calculating accuracy and perplexity