import time

import torch

from .WandbLog import TensorboardLog

_GLOBAL_TIMERS = None
_GLOBAL_TENSORBOARD_WRITER = None


def set_global_variables(launch_time, tensorboard_path):
    _set_timers()
    _set_tensorboard_writer(launch_time, tensorboard_path)


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
    _GLOBAL_TIMERS = Timers()


def _set_tensorboard_writer(launch_time, tensorboard_path):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer")
    if torch.distributed.get_rank() == 0:
        _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f"/{launch_time}", launch_time)


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
    return _GLOBAL_TIMERS


def get_tensorboard_writer():
    """Return the tensorboard writer. It can be None, so there is no need to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, "{} is not initialized.".format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None (i.e. not yet initialized)."""
    assert var is None, "{} is already initialized.".format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        # assert not self.started_, 'timer has already been started'
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset the timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        # Create the timer on first use, then return it.
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently, add_scalars makes each timer its own run, which
        # pollutes the runs list, so we add each timer as an individual scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + "-time", value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += " | {}: {:.2f}".format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
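

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hypothetical driver showing how the Timers group above is meant
# to be used around a unit of GPU work. Assumptions: a CUDA device is
# available (start()/stop() call torch.cuda.synchronize()), and
# torch.distributed has NOT been initialized, so Timers.log() falls through
# to the plain-print branch. Timers.write() is not exercised here because it
# expects a writer object exposing add_scalar (e.g. a tensorboard
# SummaryWriter); whether TensorboardLog from .WandbLog provides that
# interface is not shown in this file.
if __name__ == "__main__":
    timers = Timers()

    # Time a small matrix multiply on the GPU.
    x = torch.randn(1024, 1024, device="cuda")
    timers("matmul").start()
    y = x @ x
    timers("matmul").stop()

    # Prints something like: "time (ms) | matmul: 1.23"
    timers.log(["matmul"])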