mirror of https://github.com/InternLM/InternLM
feat(logger): add tensorboard key value buffer (#549)
* feat(logger): add tensorboard key value buffer
* fix
pull/570/head
parent
d418eba094
commit
c39d758a8a
|
@ -176,6 +176,9 @@ monitor = dict(
|
|||
light_monitor_address=None, # light_monitor address to send heartbeat
|
||||
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
|
||||
),
|
||||
tensorboard=dict(
|
||||
queue_max_length=10,
|
||||
),
|
||||
)
|
||||
|
||||
# metric_dtype can be "fp32" or other string
|
||||
|
|
|
@ -332,6 +332,9 @@ def args_sanity_check():
|
|||
"alert_file_path": None,
|
||||
}
|
||||
},
|
||||
"tensorboard": {
|
||||
"queue_max_length": 1,
|
||||
},
|
||||
}
|
||||
|
||||
for key, value in monitor_default_config.items():
|
||||
|
|
|
@ -114,6 +114,8 @@ class Writer:
|
|||
config: str = None,
|
||||
logger: logging.Logger = None,
|
||||
enable_tb: bool = True,
|
||||
queue_max_length: int = 1,
|
||||
total_steps: int = 100,
|
||||
) -> None:
|
||||
self.enable_tb = enable_tb
|
||||
self.tb_writer, self.tb_logdir = init_tb_writer(
|
||||
|
@ -126,21 +128,49 @@ class Writer:
|
|||
config=config,
|
||||
logger=logger,
|
||||
)
|
||||
self.queue_max_length = queue_max_length
|
||||
self.total_steps = total_steps
|
||||
self.add_scalars_buffer = []
|
||||
self.add_scalar_buffer = []
|
||||
self.add_scalar_step_counter = 0
|
||||
self.add_scalars_step_counter = 0
|
||||
self.add_scalar_last_step = -1
|
||||
self.add_scalars_last_step = -1
|
||||
|
||||
def add_scalar(self, key, value, step):
    """Buffer one scalar and flush the buffer to the TensorBoard writer in batches.

    Each call appends ``(key, value, step)`` to ``self.add_scalar_buffer``.
    A step is counted the first time a ``step`` larger than any seen before
    arrives. The whole buffer is flushed when the number of counted steps
    reaches ``self.queue_max_length`` or when ``step >= self.total_steps``.

    Args:
        key: Tag forwarded to ``tb_writer.add_scalar`` as ``tag``.
        value: Value forwarded as ``scalar_value``.
        step: Step forwarded as ``global_step``; also drives the flush policy.

    Writer failures are printed with ``traceback.print_exc()`` and never
    propagate to the caller; the failed entry is dropped.
    """
    self.add_scalar_buffer.append((key, value, step))

    # Count distinct, increasing steps, not individual calls, so that several
    # keys logged at the same step only advance the counter once.
    if step > self.add_scalar_last_step:
        self.add_scalar_step_counter += 1
        self.add_scalar_last_step = step

    # ``>=`` instead of ``==``: identical for queue_max_length >= 1, and a
    # value <= 0 degrades to "flush on every call" instead of never flushing.
    queue_full = self.add_scalar_step_counter >= self.queue_max_length
    if not queue_full and step < self.total_steps:
        return

    # Take a snapshot and empty the buffer in place (avoids O(n) pop(0) per item).
    pending = list(self.add_scalar_buffer)
    self.add_scalar_buffer.clear()
    self.add_scalar_step_counter = 0

    for tag, scalar_value, global_step in pending:
        try:
            if self.enable_tb and self.tb_writer is not None:
                self.tb_writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
        except Exception:
            # Best-effort logging: one bad entry must not interrupt the caller.
            traceback.print_exc()
|
||||
|
||||
def add_scalars(self, key, value, step):
    """Buffer one group of scalars and flush the buffer to the TensorBoard writer in batches.

    Each call appends ``(key, value, step)`` to ``self.add_scalars_buffer``.
    A step is counted the first time a ``step`` larger than any seen before
    arrives. The whole buffer is flushed when the number of counted steps
    reaches ``self.queue_max_length`` or when ``step >= self.total_steps``.

    Args:
        key: Tag forwarded to ``tb_writer.add_scalars`` as ``main_tag``.
        value: Dict forwarded as ``tag_scalar_dict``. A non-dict entry is
            reported at flush time and skipped.
        step: Step forwarded as ``global_step``; also drives the flush policy.

    Validation and writer failures are printed with ``traceback.print_exc()``
    and never propagate to the caller; the failed entry is dropped.
    """
    self.add_scalars_buffer.append((key, value, step))

    # Count distinct, increasing steps, not individual calls, so that several
    # keys logged at the same step only advance the counter once.
    if step > self.add_scalars_last_step:
        self.add_scalars_step_counter += 1
        self.add_scalars_last_step = step

    # ``>=`` instead of ``==``: identical for queue_max_length >= 1, and a
    # value <= 0 degrades to "flush on every call" instead of never flushing.
    queue_full = self.add_scalars_step_counter >= self.queue_max_length
    if not queue_full and step < self.total_steps:
        return

    # Take a snapshot and empty the buffer in place (avoids O(n) pop(0) per item).
    pending = list(self.add_scalars_buffer)
    self.add_scalars_buffer.clear()
    self.add_scalars_step_counter = 0

    for main_tag, tag_scalar_dict, global_step in pending:
        try:
            # Explicit check instead of ``assert`` so it still runs under ``python -O``.
            if not isinstance(tag_scalar_dict, dict):
                raise TypeError(
                    f"add_scalars expects a dict for key {main_tag!r}, got {type(tag_scalar_dict).__name__}"
                )
            if self.enable_tb and self.tb_writer is not None:
                self.tb_writer.add_scalars(
                    main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, global_step=global_step
                )
        except Exception:
            # Best-effort logging: one bad entry must not interrupt the caller.
            traceback.print_exc()
|
||||
|
||||
def add_text(self, key, value, step):
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue