from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, List

from colossalai.core import global_context as gpc


class BaseProfiler(ABC):
    """Abstract base class for profilers managed by :class:`ProfilerContext`.

    Args:
        profiler_name (str): Name used for log files and display.
        priority (int): Profilers with lower priority values are enabled first.
    """

    def __init__(self, profiler_name: str, priority: int):
        self.name = profiler_name
        self.priority = priority

    @abstractmethod
    def enable(self):
        pass

    @abstractmethod
    def disable(self):
        pass

    @abstractmethod
    def to_tensorboard(self, writer):
        pass

    @abstractmethod
    def to_file(self, filename: Path):
        pass

    @abstractmethod
    def show(self):
        pass


class ProfilerContext(object):
    """Profiler context manager.

    Usage:
    ```python
    world_size = 4
    inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    cc_prof = CommProfiler()

    with ProfilerContext([cc_prof]) as prof:
        op = dist.all_reduce(inputs, async_op=True)
        dist.all_gather(outputs_list, inputs)
        op.wait()
        dist.reduce_scatter(inputs, outputs_list)
        dist.broadcast(inputs, 0)
        dist.reduce(inputs, 0)

    prof.show()
    ```
    """

    def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
        self.enable = enable
        # Guard against the default ``None`` so that sorting does not fail.
        self.profilers = sorted(profilers or [], key=lambda prof: prof.priority)

    def __enter__(self):
        if self.enable:
            for prof in self.profilers:
                prof.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            for prof in self.profilers:
                prof.disable()

    def to_tensorboard(self, writer):
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter), \
            f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'

        for prof in self.profilers:
            prof.to_tensorboard(writer)

    def to_file(self, log_dir: Union[str, Path]):
        if isinstance(log_dir, str):
            log_dir = Path(log_dir)

        if not log_dir.exists():
            log_dir.mkdir(parents=True, exist_ok=True)
        for prof in self.profilers:
            log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
            prof.to_file(log_file)

    def show(self):
        for prof in self.profilers:
            prof.show()
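

# A minimal sketch of a concrete profiler built only on the BaseProfiler
# interface defined above. The class name `ElapsedTimeProfiler` and its
# wall-clock timing logic are illustrative assumptions, not part of the
# library; a real subclass (e.g. CommProfiler) would collect its own metrics.
import time


class ElapsedTimeProfiler(BaseProfiler):
    """Hypothetical profiler that records wall-clock time spent inside the context."""

    def __init__(self):
        super().__init__(profiler_name='elapsed_time', priority=10)
        self._start = None
        self.elapsed = 0.0

    def enable(self):
        # Called by ProfilerContext.__enter__
        self._start = time.time()

    def disable(self):
        # Called by ProfilerContext.__exit__
        self.elapsed += time.time() - self._start

    def to_tensorboard(self, writer):
        writer.add_scalar('profiler/elapsed_time', self.elapsed)

    def to_file(self, filename: Path):
        filename.write_text(f'elapsed time: {self.elapsed:.6f} s\n')

    def show(self):
        print(f'elapsed time: {self.elapsed:.6f} s')


# Example usage (sketch): profile an arbitrary block of work and print the result.
#
#     with ProfilerContext([ElapsedTimeProfiler()]) as prof:
#         run_workload()
#     prof.show()
#     prof.to_file('./profiler_logs')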