From 425bb0df3f58e361f816271809fd1ce3b532c6a1 Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Wed, 9 Mar 2022 16:12:41 +0800
Subject: [PATCH] Added Profiler Context to manage all profilers (#340)

---
 colossalai/utils/profiler/__init__.py      |   3 +-
 colossalai/utils/profiler/comm_profiler.py | 183 ++++++++++++---------
 colossalai/utils/profiler/prof_utils.py    |  85 ++++++++++
 tests/test_profiler/test_comm_prof.py      |  40 -----
 4 files changed, 193 insertions(+), 118 deletions(-)
 create mode 100644 colossalai/utils/profiler/prof_utils.py
 delete mode 100644 tests/test_profiler/test_comm_prof.py

diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py
index 9b216c437..6ec3a29c3 100644
--- a/colossalai/utils/profiler/__init__.py
+++ b/colossalai/utils/profiler/__init__.py
@@ -1 +1,2 @@
-from .comm_profiler import enable_communication_prof, communication_prof_show
+from .comm_profiler import CommProfiler
+from .prof_utils import ProfilerContext
diff --git a/colossalai/utils/profiler/comm_profiler.py b/colossalai/utils/profiler/comm_profiler.py
index 5ffef66c8..25538dff0 100644
--- a/colossalai/utils/profiler/comm_profiler.py
+++ b/colossalai/utils/profiler/comm_profiler.py
@@ -1,9 +1,13 @@
 import inspect
+from pathlib import Path
+from functools import partial
 import torch
 from torch.autograd.profiler import profile
 import torch.distributed as dist
 from torch.distributed import ReduceOp
+import torch.utils.tensorboard as tb
 from colossalai.utils import get_current_device
+from .prof_utils import BaseProfiler
 from typing import List, Optional
 
 
@@ -57,6 +61,13 @@ def _format_bandwith(volme: float, time_us: int):
     return '{:.3f} MB/s'.format(mb_per_sec)
 
 
+torch_all_reduce = dist.all_reduce
+torch_all_gather = dist.all_gather
+torch_reduce_scatter = dist.reduce_scatter
+torch_broadcast = dist.broadcast
+torch_reduce = dist.reduce
+
+
 class CommEvent(object):
     """Communication Event. Used for communication time and communication
     volume recording.
@@ -73,16 +84,16 @@ class CommEvent(object):
         self.self_cuda_time += rhs.self_cuda_time
 
 
-class CommProfiler(object):
+class CommProfiler(BaseProfiler):
     """Communication profiler. Records all communication events.
""" - def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3): - super().__init__() + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth self.total_count = total_count self.total_comm_vol = total_comm_vol self.total_cuda_time = total_cuda_time - self.depth = prof_depth self.ops_record = dict() self.profiler = None @@ -101,27 +112,58 @@ class CommProfiler(object): self.pending_metadata = None self.warn_flag = False + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer: tb.writer): + writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_list()) + def show(self): + print(self.result_list()) + + def result_list(self, sep: str = "\n"): + res = [] + + def append(s: str): + res.append(s) + res.append(sep) + if self.warn_flag: - print("Warnning: there exists multiple communication operations in the same time.\n" - "As a result, the profiling result is not accurate.") - print("Collective communication profiling result:", - "total cuda time: {}".format(_format_time(self.total_cuda_time)), - "average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)), - "total number of calls: {}".format(self.total_count), - "All events:", - sep='\n') + append("Warnning: there exists multiple communication operations in the same time. 
+                   "the profiling result is not accurate.")
+
+        append("Collective communication profiling result:")
+        append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
+        append("average bandwidth: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
+        append("total number of calls: {}".format(self.total_count))
+        append("All events:\n----------------------------------------")
 
         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
         for location, event in show_list:
-            print(location,
-                  "self cuda time: {}".format(_format_time(event.self_cuda_time)),
-                  "{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
-                  "self communication volme: {}".format(_format_memory(event.self_comm_vol)),
-                  "average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)),
-                  "number of calls: {}".format(event.self_count),
-                  "--------------------",
-                  sep='\n')
+            append(location)
+            append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
+            append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0))
+            append("self communication volume: {}".format(_format_memory(event.self_comm_vol)))
+            append("average bandwidth: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
+            append("number of calls: {}".format(event.self_count))
+            append("----------------------------------------")
+
+        return ''.join(res)
 
     @property
     def has_aync_op(self):
@@ -176,65 +218,46 @@ class CommHandler(object):
     """Communication handler. A dummy handler to wait aync operations.
     """
 
-    def __init__(self):
+    def __init__(self, profiler: CommProfiler):
         super().__init__()
-        self.prof = COL_COMM_PROF
+        self.prof = profiler
 
     def wait(self):
         self.prof.wait_async_op()
 
 
-COL_COMM_PROF = CommProfiler()
-torch_all_reduce = dist.all_reduce
-torch_all_gather = dist.all_gather
-torch_reduce_scatter = dist.reduce_scatter
-torch_broadcast = dist.broadcast
-torch_reduce = dist.reduce
-
-
-def enable_communication_prof(depth: int = 0):
-    COL_COMM_PROF.depth = 3 + depth
-    dist.all_reduce = all_reduce
-    dist.all_gather = all_gather
-    dist.reduce_scatter = reduce_scatter
-    dist.broadcast = broadcast
-    dist.reduce = reduce
-
-
-def communication_prof_show():
-    COL_COMM_PROF.show()
-
-
-def async_check():
-    if COL_COMM_PROF.pending_op is not None:
-        COL_COMM_PROF.warn_flag = True
-        COL_COMM_PROF.wait_async_op()
+def async_check(profiler: CommProfiler):
+    if profiler.pending_op is not None:
+        profiler.warn_flag = True
+        profiler.wait_async_op()
 
 
 def all_reduce(tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               group=None,
-               async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+               async_op: bool = False,
+               profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = 2 * (comm_size - 1) / comm_size
     comm_vol = correction * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
+    profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def reduce_scatter(output: torch.Tensor,
                    input_list: List[torch.Tensor],
                    op: ReduceOp = ReduceOp.SUM,
                    group=None,
-                   async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+                   async_op: bool = False,
+                   profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = (comm_size - 1) / comm_size
@@ -242,20 +265,21 @@ def reduce_scatter(output: torch.Tensor,
     for tensor in input_list:
         comm_vol += tensor.element_size() * tensor.numel()
     comm_vol *= correction
-    COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
+    profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def all_gather(tensor_list: List[torch.Tensor],
               tensor: torch.Tensor,
               group=None,
-               async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+               async_op: bool = False,
+               profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = (comm_size - 1) / comm_size
@@ -263,40 +287,45 @@ def all_gather(tensor_list: List[torch.Tensor],
     for ten in tensor_list:
         comm_vol += ten.element_size() * ten.numel()
     comm_vol *= correction
-    COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
+    profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
+    profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
-def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+def broadcast(tensor: torch.Tensor,
+              src: int,
+              group=None,
+              async_op: bool = False,
+              profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_vol = 1.0 * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)
+    profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
+    profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def reduce(tensor: torch.Tensor,
           dst: int,
           op: ReduceOp = ReduceOp.SUM,
           group=None,
-           async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+           async_op: bool = False,
+           profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_vol = 1.0 * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
+    profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
diff --git a/colossalai/utils/profiler/prof_utils.py b/colossalai/utils/profiler/prof_utils.py
new file mode 100644
index 000000000..5d9b23178
--- /dev/null
+++ b/colossalai/utils/profiler/prof_utils.py
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Union, List
+from colossalai.core import global_context as gpc
+
+
+class BaseProfiler(ABC):
+
+    def __init__(self, profiler_name: str, priority: int):
+        self.name = profiler_name
+        self.priority = priority
+
+    @abstractmethod
+    def enable(self):
+        pass
+
+    @abstractmethod
+    def disable(self):
+        pass
+
+    @abstractmethod
+    def to_tensorboard(self, writer):
+        pass
+
+    @abstractmethod
+    def to_file(self, filename: Path):
+        pass
+
+    @abstractmethod
+    def show(self):
+        pass
+
+
+class ProfilerContext(object):
+    """
+    Profiler context manager
+    Usage:
+    from colossalai.utils.profiler import CommProfiler, ProfilerContext
+    from torch.utils.tensorboard import SummaryWriter
+    cc_prof = CommProfiler()
+    with ProfilerContext([cc_prof]) as prof:
+        train()
+    writer = SummaryWriter('tb/path')
+    prof.to_tensorboard(writer)
+    prof.to_file('./prof_logs/')
+    prof.show()
+    """
+
+    def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
+        self.enable = enable
+        self.profilers = sorted(profilers, key=lambda prof: prof.priority)
+
+    def __enter__(self):
+        if self.enable:
+            for prof in self.profilers:
+                prof.enable()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.enable:
+            for prof in self.profilers:
+                prof.disable()
+
+    def to_tensorboard(self, writer):
+        from torch.utils.tensorboard import SummaryWriter
+
+        assert isinstance(writer, SummaryWriter), \
+            f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
+
+        for prof in self.profilers:
+            prof.to_tensorboard(writer)
+
+    def to_file(self, log_dir: Union[str, Path]):
+        if isinstance(log_dir, str):
+            log_dir = Path(log_dir)
+
+        if not log_dir.exists():
+            log_dir.mkdir(parents=True, exist_ok=True)
+        for prof in self.profilers:
+            log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
+            prof.to_file(log_file)
+
+    def show(self):
+        for prof in self.profilers:
+            prof.show()
diff --git a/tests/test_profiler/test_comm_prof.py b/tests/test_profiler/test_comm_prof.py
deleted file mode 100644
index 133ea0262..000000000
--- a/tests/test_profiler/test_comm_prof.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from functools import partial
-import torch
-import torch.multiprocessing as mp
-import torch.distributed as dist
-import colossalai
-from colossalai.utils import free_port, get_current_device
-from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
-
-BATCH_SIZE = 1024
-D_MODEL = 1024
-CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
-
-
-def run_test(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
-    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
-    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
-
-    enable_communication_prof()
-
-    op = dist.all_reduce(inputs, async_op=True)
-    dist.all_gather(outputs_list, inputs)
-    op.wait()
-    dist.reduce_scatter(inputs, outputs_list)
-    dist.broadcast(inputs, 0)
-    dist.reduce(inputs, 0)
-
-    if rank == 0:
-        communication_prof_show()
-
-
-def test_cc_prof():
-    world_size = 4
-    run_func = partial(run_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_cc_prof()
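
A minimal usage sketch of the CommProfiler/ProfilerContext API introduced by this patch. It assumes the distributed environment has already been initialized (for example via colossalai.launch with an nccl backend); the tensor shape and the 'tb_logs' / './prof_logs/' paths are illustrative only.

    import torch
    import torch.distributed as dist
    from torch.utils.tensorboard import SummaryWriter
    from colossalai.utils import get_current_device
    from colossalai.utils.profiler import CommProfiler, ProfilerContext

    def profile_communication():
        # While the context is active, CommProfiler.enable() swaps dist.all_reduce,
        # dist.all_gather, dist.reduce_scatter, dist.broadcast and dist.reduce for
        # profiled wrappers; leaving the context restores the original ops.
        cc_prof = CommProfiler()
        tensor = torch.randn(1024, 1024, device=get_current_device())  # illustrative shape
        with ProfilerContext([cc_prof]) as prof:
            dist.all_reduce(tensor)
            dist.broadcast(tensor, src=0)
        writer = SummaryWriter('tb_logs')    # illustrative TensorBoard directory
        prof.to_tensorboard(writer)          # one text entry per registered profiler
        prof.to_file('./prof_logs/')         # writes <name>_rank_<rank>.log per profiler
        prof.show()                          # prints the aggregated results to stdout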