Added Profiler Context to manage all profilers (#340)

pull/394/head
HELSON 2022-03-09 16:12:41 +08:00 committed by Frank Lee
parent d0ae0f2215
commit 425bb0df3f
4 changed files with 193 additions and 118 deletions

View File

@ -1 +1,2 @@
from .comm_profiler import enable_communication_prof, communication_prof_show from .comm_profiler import CommProfiler
from .prof_utils import ProfilerContext

View File

@ -1,9 +1,13 @@
import inspect import inspect
from pathlib import Path
from functools import partial
import torch import torch
from torch.autograd.profiler import profile from torch.autograd.profiler import profile
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
import torch.utils.tensorboard as tb
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .prof_utils import BaseProfiler
from typing import List, Optional from typing import List, Optional
@ -57,6 +61,13 @@ def _format_bandwith(volme: float, time_us: int):
return '{:.3f} MB/s'.format(mb_per_sec) return '{:.3f} MB/s'.format(mb_per_sec)
# References to the original torch.distributed collectives, captured at module
# level. CommProfiler.enable() rebinds the `dist.*` attributes to profiling
# wrappers; those wrappers call these saved functions to run the real
# operation, and CommProfiler.disable() assigns them back onto `dist`.
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
class CommEvent(object): class CommEvent(object):
"""Communication Event. Used for communication time and communication """Communication Event. Used for communication time and communication
volume recording. volume recording.
@ -73,16 +84,16 @@ class CommEvent(object):
self.self_cuda_time += rhs.self_cuda_time self.self_cuda_time += rhs.self_cuda_time
class CommProfiler(object): class CommProfiler(BaseProfiler):
"""Communication profiler. Records all communication events. """Communication profiler. Records all communication events.
""" """
def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3): def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
super().__init__() super().__init__(profiler_name="Collective_Communication", priority=0)
self.depth = 3 + depth
self.total_count = total_count self.total_count = total_count
self.total_comm_vol = total_comm_vol self.total_comm_vol = total_comm_vol
self.total_cuda_time = total_cuda_time self.total_cuda_time = total_cuda_time
self.depth = prof_depth
self.ops_record = dict() self.ops_record = dict()
self.profiler = None self.profiler = None
@ -101,27 +112,58 @@ class CommProfiler(object):
self.pending_metadata = None self.pending_metadata = None
self.warn_flag = False self.warn_flag = False
def enable(self):
    """Route the ``torch.distributed`` collectives through this profiler.

    ``all_reduce``, ``all_gather``, ``reduce_scatter``, ``broadcast`` and
    ``reduce`` on ``dist`` are each rebound to this module's profiling
    wrapper of the same name, with ``profiler=self`` pre-filled.
    """
    wrappers = (
        ("all_reduce", all_reduce),
        ("all_gather", all_gather),
        ("reduce_scatter", reduce_scatter),
        ("broadcast", broadcast),
        ("reduce", reduce),
    )
    for attr, wrapper in wrappers:
        setattr(dist, attr, partial(wrapper, profiler=self))
def disable(self):
    """Put the saved ``torch.distributed`` collectives back on ``dist``.

    Each of the five patched attributes is reassigned to the matching
    ``torch_*`` reference kept at module level.
    """
    originals = (
        ("all_reduce", torch_all_reduce),
        ("all_gather", torch_all_gather),
        ("reduce_scatter", torch_reduce_scatter),
        ("broadcast", torch_broadcast),
        ("reduce", torch_reduce),
    )
    for attr, original in originals:
        setattr(dist, attr, original)
def to_tensorboard(self, writer: tb.SummaryWriter):
    """Publish the profiling report as a text entry in TensorBoard.

    Args:
        writer: summary writer that receives the report under the
            "Collective Communication" tag.
    """
    # The report is joined with blank lines instead of single newlines;
    # presumably so every entry shows up as its own paragraph in the
    # TensorBoard text panel -- confirm against the rendered output.
    writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))
def to_file(self, filename: Path):
    """Write the profiling report to ``filename``, replacing any old content.

    Args:
        filename: destination of the report text.
    """
    with open(filename, mode="w") as log_file:
        report = self.result_list()
        log_file.write(report)
def show(self): def show(self):
print(self.result_list())
def result_list(self, sep: str = "\n"):
res = []
def append(s: str):
res.append(s)
res.append(sep)
if self.warn_flag: if self.warn_flag:
print("Warnning: there exists multiple communication operations in the same time.\n" append("Warnning: there exists multiple communication operations in the same time. As a result, "
"As a result, the profiling result is not accurate.") "the profiling result is not accurate.")
print("Collective communication profiling result:",
"total cuda time: {}".format(_format_time(self.total_cuda_time)), append("Collective communication profiling result:")
"average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)), append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
"total number of calls: {}".format(self.total_count), append("average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
"All events:", append("total number of calls: {}".format(self.total_count))
sep='\n') append("All events:\n----------------------------------------")
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list: for location, event in show_list:
print(location, append(location)
"self cuda time: {}".format(_format_time(event.self_cuda_time)), append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
"{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0), append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0))
"self communication volme: {}".format(_format_memory(event.self_comm_vol)), append("self communication volme: {}".format(_format_memory(event.self_comm_vol)))
"average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)), append("average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
"number of calls: {}".format(event.self_count), append("number of calls: {}".format(event.self_count))
"--------------------", append("----------------------------------------")
sep='\n')
return ''.join(res)
@property @property
def has_aync_op(self): def has_aync_op(self):
@ -176,65 +218,46 @@ class CommHandler(object):
"""Communication handler. A dummy handler to wait aync operations. """Communication handler. A dummy handler to wait aync operations.
""" """
def __init__(self): def __init__(self, profiler: CommProfiler):
super().__init__() super().__init__()
self.prof = COL_COMM_PROF self.prof = profiler
def wait(self): def wait(self):
self.prof.wait_async_op() self.prof.wait_async_op()
COL_COMM_PROF = CommProfiler() def async_check(profiler: CommProfiler):
torch_all_reduce = dist.all_reduce if profiler.pending_op is not None:
torch_all_gather = dist.all_gather profiler.warn_flag = True
torch_reduce_scatter = dist.reduce_scatter profiler.wait_async_op()
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
def enable_communication_prof(depth: int = 0):
COL_COMM_PROF.depth = 3 + depth
dist.all_reduce = all_reduce
dist.all_gather = all_gather
dist.reduce_scatter = reduce_scatter
dist.broadcast = broadcast
dist.reduce = reduce
def communication_prof_show():
COL_COMM_PROF.show()
def async_check():
if COL_COMM_PROF.pending_op is not None:
COL_COMM_PROF.warn_flag = True
COL_COMM_PROF.wait_async_op()
def all_reduce(tensor: torch.Tensor, def all_reduce(tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM, op: ReduceOp = ReduceOp.SUM,
group=None, group=None,
async_op: bool = False) -> Optional[CommHandler]: async_op: bool = False,
async_check() profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group) comm_size = dist.get_world_size(group)
correction = 2 * (comm_size - 1) / comm_size correction = 2 * (comm_size - 1) / comm_size
comm_vol = correction * tensor.element_size() * tensor.numel() comm_vol = correction * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol) profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op) profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
if async_op: if async_op:
return CommHandler() return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group) profiler.close_profiler(group)
def reduce_scatter(output: torch.Tensor, def reduce_scatter(output: torch.Tensor,
input_list: List[torch.Tensor], input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM, op: ReduceOp = ReduceOp.SUM,
group=None, group=None,
async_op: bool = False) -> Optional[CommHandler]: async_op: bool = False,
async_check() profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group) comm_size = dist.get_world_size(group)
correction = (comm_size - 1) / comm_size correction = (comm_size - 1) / comm_size
@ -242,20 +265,21 @@ def reduce_scatter(output: torch.Tensor,
for tensor in input_list: for tensor in input_list:
comm_vol += tensor.element_size() * tensor.numel() comm_vol += tensor.element_size() * tensor.numel()
comm_vol *= correction comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
if async_op: if async_op:
return CommHandler() return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group) profiler.close_profiler(group)
def all_gather(tensor_list: List[torch.Tensor], def all_gather(tensor_list: List[torch.Tensor],
tensor: torch.Tensor, tensor: torch.Tensor,
group=None, group=None,
async_op: bool = False) -> Optional[CommHandler]: async_op: bool = False,
async_check() profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group) comm_size = dist.get_world_size(group)
correction = (comm_size - 1) / comm_size correction = (comm_size - 1) / comm_size
@ -263,40 +287,45 @@ def all_gather(tensor_list: List[torch.Tensor],
for ten in tensor_list: for ten in tensor_list:
comm_vol += ten.element_size() * ten.numel() comm_vol += ten.element_size() * ten.numel()
comm_vol *= correction comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol) profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
if async_op: if async_op:
return CommHandler() return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group) profiler.close_profiler(group)
def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]: def broadcast(tensor: torch.Tensor,
async_check() src: int,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel() comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol) profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op) profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
if async_op: if async_op:
return CommHandler() return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group) profiler.close_profiler(group)
def reduce(tensor: torch.Tensor, def reduce(tensor: torch.Tensor,
dst: int, dst: int,
op: ReduceOp = ReduceOp.SUM, op: ReduceOp = ReduceOp.SUM,
group=None, group=None,
async_op: bool = False) -> Optional[CommHandler]: async_op: bool = False,
async_check() profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel() comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol) profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op) profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
if async_op: if async_op:
return CommHandler() return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group) profiler.close_profiler(group)

View File

@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, List
from colossalai.core import global_context as gpc
class BaseProfiler(ABC):
    """Abstract interface for a profiler that ``ProfilerContext`` can manage.

    Args:
        profiler_name: stored as ``name``; ``ProfilerContext.to_file`` uses it
            to build the log file name.
        priority: stored as ``priority``; ``ProfilerContext`` sorts its
            profilers by this value in ascending order.
    """

    def __init__(self, profiler_name: str, priority: int):
        self.name, self.priority = profiler_name, priority

    @abstractmethod
    def enable(self):
        """Start profiling; invoked when the profiler context is entered."""

    @abstractmethod
    def disable(self):
        """Stop profiling; invoked when the profiler context is exited."""

    @abstractmethod
    def to_tensorboard(self, writer):
        """Send the collected results to the given TensorBoard ``writer``."""

    @abstractmethod
    def to_file(self, filename: Path):
        """Write the collected results to ``filename``."""

    @abstractmethod
    def show(self):
        """Display the collected results."""
class ProfilerContext(object):
    """Context manager that drives a group of profilers together.

    Entering the context enables every managed profiler and leaving it
    disables every one, both in ascending ``priority`` order. Afterwards the
    results can be printed, written to one log file per profiler, or sent to
    TensorBoard.

    Usage::

        from colossalai.utils.profiler import CommProfiler, ProfilerContext
        from torch.utils.tensorboard import SummaryWriter

        cc_prof = CommProfiler()

        with ProfilerContext([cc_prof]) as prof:
            train()

        writer = SummaryWriter('tb/path')
        prof.to_tensorboard(writer)
        prof.to_file('./prof_logs/')
        prof.show()

    Args:
        profilers: profilers to manage; ``None`` (the default) means none.
        enable: when False, entering and leaving the context does not touch
            the profilers at all.
    """

    def __init__(self, profilers: Union[List["BaseProfiler"], None] = None, enable: bool = True):
        self.enable = enable
        # The default is None, which sorted() cannot iterate; treat it as an
        # empty collection so ``ProfilerContext()`` is usable.
        managed = [] if profilers is None else profilers
        self.profilers = sorted(managed, key=lambda prof: prof.priority)

    def __enter__(self):
        """Enable all profilers (if this context is enabled) and return self."""
        if self.enable:
            for prof in self.profilers:
                prof.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disable all profilers (if this context is enabled).

        Returns None, so an exception raised inside the ``with`` body is not
        suppressed.
        """
        if self.enable:
            for prof in self.profilers:
                prof.disable()

    def to_tensorboard(self, writer):
        """Forward ``writer`` to every profiler's ``to_tensorboard``.

        Args:
            writer: must be a ``torch.utils.tensorboard.SummaryWriter``.
        """
        # Imported here so that tensorboard is only required when this
        # method is actually used.
        from torch.utils.tensorboard import SummaryWriter

        assert isinstance(writer, SummaryWriter), \
            f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'

        for prof in self.profilers:
            prof.to_tensorboard(writer)

    def to_file(self, log_dir: Union[str, Path]):
        """Write one ``<name>_rank_<global rank>.log`` file per profiler.

        Args:
            log_dir: target directory; created (with parents) when missing.
        """
        log_dir = Path(log_dir)
        # exist_ok=True already covers an existing directory, so no separate
        # exists() check (which would be racy) is needed.
        log_dir.mkdir(parents=True, exist_ok=True)

        for prof in self.profilers:
            log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
            prof.to_file(log_file)

    def show(self):
        """Call ``show()`` on every managed profiler."""
        for prof in self.profilers:
            prof.show()

View File

@ -1,40 +0,0 @@
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
# Dimensions of the random tensor that run_test pushes through the collectives.
BATCH_SIZE = 1024
D_MODEL = 1024
# Launch configuration handed to colossalai.launch in run_test; the tensor
# parallel size of 4 equals the number of processes test_cc_prof spawns.
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
def run_test(rank, world_size, port):
    """Per-process body: run each profiled collective once, then report.

    Args:
        rank: rank of this process, passed on to ``colossalai.launch``.
        world_size: total number of processes.
        port: port on ``localhost`` passed on to ``colossalai.launch``.
    """
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    tensor_opts = dict(dtype=torch.float32, device=get_current_device())
    payload = torch.randn(BATCH_SIZE, D_MODEL, **tensor_opts)
    gathered = torch.empty(world_size, BATCH_SIZE, D_MODEL, **tensor_opts)
    # One chunk per rank along dim 0, used as the tensor list of the collectives.
    gathered_chunks = list(torch.chunk(gathered, chunks=world_size, dim=0))

    enable_communication_prof()

    # The all-reduce is started asynchronously and only waited on after the
    # all_gather has been issued, so two operations are in flight at once.
    pending = dist.all_reduce(payload, async_op=True)
    dist.all_gather(gathered_chunks, payload)
    pending.wait()

    dist.reduce_scatter(payload, gathered_chunks)
    dist.broadcast(payload, 0)
    dist.reduce(payload, 0)

    # Only rank 0 prints the profiling result.
    if rank == 0:
        communication_prof_show()
def test_cc_prof():
    """Spawn four worker processes that each execute ``run_test``."""
    nprocs = 4
    worker = partial(run_test, world_size=nprocs, port=free_port())
    mp.spawn(worker, nprocs=nprocs)


# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_cc_prof()