mirror of https://github.com/hpcaitech/ColossalAI
131 lines
3.8 KiB
Python
131 lines
3.8 KiB
Python
|
import time
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from .WandbLog import TensorboardLog
|
||
|
|
||
|
# Process-wide state for this module. Both start out unset (None); the
# _set_* helpers fill them in and the get_* accessors hand them out.
_GLOBAL_TIMERS = None  # Timers registry once _set_timers() has run
_GLOBAL_TENSORBOARD_WRITER = None  # writer object; stays None on ranks other than 0
|
||
|
|
||
|
|
||
|
def set_global_variables(launch_time, tensorboard_path):
    """Initialize this module's global state: the timers and the tensorboard writer.

    Args:
        launch_time: run identifier; used as the sub-directory name under
            ``tensorboard_path`` and forwarded to the writer.
        tensorboard_path: root directory for tensorboard logs.
    """
    # Timers first, then the (rank-0 only) tensorboard writer.
    _set_timers()
    _set_tensorboard_writer(launch_time=launch_time, tensorboard_path=tensorboard_path)
|
||
|
|
||
|
|
||
|
def _set_timers():
    """Create the global Timers registry; asserts it has not been created before."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    new_timers = Timers()
    _GLOBAL_TIMERS = new_timers
|
||
|
|
||
|
|
||
|
def _set_tensorboard_writer(launch_time, tensorboard_path):
    """Set tensorboard writer.

    Only rank 0 creates a ``TensorboardLog`` writing to
    ``<tensorboard_path>/<launch_time>``; every other rank leaves the global
    writer as None. When torch.distributed has not been initialized, the
    process is treated as rank 0 (a single-process run) instead of raising
    from ``torch.distributed.get_rank()``.

    Args:
        launch_time: run identifier, used as the log sub-directory name and
            passed through to ``TensorboardLog``.
        tensorboard_path: root log directory (``str`` or ``os.PathLike``).

    Raises:
        AssertionError: via ``_ensure_var_is_not_initialized`` if the writer
            has already been created (while assertions are enabled).
    """
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer')
    # Same tolerance for an uninitialized process group as Timers.log.
    is_rank_zero = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0
    if is_rank_zero:
        # An f-string yields the same text as ``path + '/...'`` for str input and
        # also accepts pathlib.Path roots, where ``+`` would raise TypeError.
        log_dir = f'{tensorboard_path}/{launch_time}'
        _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(log_dir, launch_time)
|
||
|
|
||
|
|
||
|
def get_timers():
    """Return the global Timers registry, asserting it has been created first."""
    timers = _GLOBAL_TIMERS
    _ensure_var_is_initialized(timers, 'timers')
    return timers
|
||
|
|
||
|
|
||
|
def get_tensorboard_writer():
    """Return the global tensorboard writer, which may legitimately be None.

    Ranks other than 0 never create a writer, so unlike get_timers() no
    initialization check is performed here.
    """
    writer = _GLOBAL_TENSORBOARD_WRITER
    return writer
|
||
|
|
||
|
|
||
|
def _ensure_var_is_initialized(var, name):
    """Assert that a global has been set (is not None) before it is handed out.

    Raises:
        AssertionError: ``'<name> is not initialized.'`` when ``var`` is None
            (only while assertions are enabled).
    """
    assert var is not None, f'{name} is not initialized.'
|
||
|
|
||
|
|
||
|
def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is still None, i.e. not yet initialized.

    Used by the one-time setters to guard against initializing a global twice.

    Raises:
        AssertionError: ``'<name> is already initialized.'`` when ``var`` is
            not None (only while assertions are enabled, i.e. not under ``-O``).
    """
    assert var is None, '{} is already initialized.'.format(name)
|
||
|
|
||
|
|
||
|
class _Timer:
    """Timer: a named, accumulating stopwatch.

    Each start()/stop() pair adds the interval between them to ``elapsed_``
    (seconds). When CUDA is available, both calls first wait for outstanding
    GPU work (``torch.cuda.synchronize()``) so asynchronously launched kernels
    are charged to the interval that issued them. On CPU-only hosts the
    synchronization is skipped instead of failing.
    """

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0  # accumulated seconds since the last reset()
        self.started_ = False
        # perf_counter() is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted in the middle of a measurement.
        self.start_time = time.perf_counter()

    @staticmethod
    def _synchronize():
        """Block until queued CUDA work has finished; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self):
        """Start the timer.

        Calling start() on a running timer is allowed: it restarts the current
        interval, discarding the time accrued since the previous start().
        """
        self._synchronize()
        self.start_time = time.perf_counter()
        self.started_ = True

    def stop(self):
        """Stop the timer and add the interval since start() to ``elapsed_``.

        Raises:
            AssertionError: if the timer is not running.
        """
        assert self.started_, 'timer is not started'
        self._synchronize()
        self.elapsed_ += time.perf_counter() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer: zero the accumulated time and mark it stopped."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Return the accumulated time in seconds.

        A running timer is stopped so the in-progress interval is included,
        and restarted afterwards. With ``reset=True`` (the default) the
        accumulated time is zeroed after it has been read.
        """
        was_running = self.started_
        if was_running:
            self.stop()
        total = self.elapsed_
        if reset:
            self.reset()
        if was_running:
            self.start()
        return total
|
||
|
|
||
|
|
||
|
class Timers:
    """Group of timers: a registry of named ``_Timer`` objects created on first use."""

    def __init__(self):
        # Maps timer name -> _Timer.
        self.timers = {}

    def __call__(self, name):
        """Return the timer called ``name``, creating it if it does not exist yet."""
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer.

        Args:
            names: names of timers to record; each must already be registered.
            writer: object exposing ``add_scalar(tag, value, step)``; must not be None.
            iteration: step value forwarded to ``add_scalar``.
            normalizer: positive divisor applied to each elapsed time (seconds).
            reset: whether reading a timer also zeroes it (default False).
        """
        # Scalars are added one at a time rather than through add_scalars:
        # torch.utils' add_scalars makes each timer its own run, which
        # pollutes the runs list.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers as one line of elapsed milliseconds.

        The line looks like ``time (ms) | fwd: 12.34 | bwd: 5.67``. It is
        printed by the last rank when torch.distributed is initialized, and
        unconditionally otherwise.

        Args:
            names: names of timers to report; each must already be registered.
            normalizer: positive divisor applied to each elapsed time.
            reset: whether reading a timer also zeroes it (default True).
        """
        assert normalizer > 0.0
        parts = ['time (ms)']
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            parts.append('{}: {:.2f}'.format(name, elapsed_time))
        string = ' | '.join(parts)
        if torch.distributed.is_initialized():
            should_print = torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
        else:
            should_print = True
        if should_print:
            print(string, flush=True)
|