Added Profiler Context to manage all profilers (#340)

pull/394/head
HELSON 2022-03-09 16:12:41 +08:00 committed by Frank Lee
parent d0ae0f2215
commit 425bb0df3f
4 changed files with 193 additions and 118 deletions

View File

@@ -1 +1,2 @@
from .comm_profiler import enable_communication_prof, communication_prof_show
from .comm_profiler import CommProfiler
from .prof_utils import ProfilerContext

View File

@@ -1,9 +1,13 @@
import inspect
from pathlib import Path
from functools import partial
import torch
from torch.autograd.profiler import profile
import torch.distributed as dist
from torch.distributed import ReduceOp
import torch.utils.tensorboard as tb
from colossalai.utils import get_current_device
from .prof_utils import BaseProfiler
from typing import List, Optional
@@ -57,6 +61,13 @@ def _format_bandwith(volme: float, time_us: int):
return '{:.3f} MB/s'.format(mb_per_sec)
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
class CommEvent(object):
"""Communication Event. Used for communication time and communication
volume recording.
@@ -73,16 +84,16 @@ class CommEvent(object):
self.self_cuda_time += rhs.self_cuda_time
class CommProfiler(object):
class CommProfiler(BaseProfiler):
"""Communication profiler. Records all communication events.
"""
def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
super().__init__()
def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
super().__init__(profiler_name="Collective_Communication", priority=0)
self.depth = 3 + depth
self.total_count = total_count
self.total_comm_vol = total_comm_vol
self.total_cuda_time = total_cuda_time
self.depth = prof_depth
self.ops_record = dict()
self.profiler = None
@@ -101,27 +112,58 @@ class CommProfiler(object):
self.pending_metadata = None
self.warn_flag = False
def enable(self):
dist.all_reduce = partial(all_reduce, profiler=self)
dist.all_gather = partial(all_gather, profiler=self)
dist.reduce_scatter = partial(reduce_scatter, profiler=self)
dist.broadcast = partial(broadcast, profiler=self)
dist.reduce = partial(reduce, profiler=self)
def disable(self):
dist.all_reduce = torch_all_reduce
dist.all_gather = torch_all_gather
dist.reduce_scatter = torch_reduce_scatter
dist.broadcast = torch_broadcast
dist.reduce = torch_reduce
def to_tensorboard(self, writer: tb.writer):
writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))
def to_file(self, filename: Path):
with open(filename, "w") as f:
f.write(self.result_list())
def show(self):
print(self.result_list())
def result_list(self, sep: str = "\n"):
res = []
def append(s: str):
res.append(s)
res.append(sep)
if self.warn_flag:
print("Warnning: there exists multiple communication operations in the same time.\n"
"As a result, the profiling result is not accurate.")
print("Collective communication profiling result:",
"total cuda time: {}".format(_format_time(self.total_cuda_time)),
"average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)),
"total number of calls: {}".format(self.total_count),
"All events:",
sep='\n')
append("Warnning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.")
append("Collective communication profiling result:")
append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
append("average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
append("total number of calls: {}".format(self.total_count))
append("All events:\n----------------------------------------")
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list:
print(location,
"self cuda time: {}".format(_format_time(event.self_cuda_time)),
"{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
"self communication volme: {}".format(_format_memory(event.self_comm_vol)),
"average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)),
"number of calls: {}".format(event.self_count),
"--------------------",
sep='\n')
append(location)
append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0))
append("self communication volme: {}".format(_format_memory(event.self_comm_vol)))
append("average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
append("number of calls: {}".format(event.self_count))
append("----------------------------------------")
return ''.join(res)
@property
def has_aync_op(self):
@@ -176,65 +218,46 @@ class CommHandler(object):
"""Communication handler. A dummy handler to wait aync operations.
"""
def __init__(self):
def __init__(self, profiler: CommProfiler):
super().__init__()
self.prof = COL_COMM_PROF
self.prof = profiler
def wait(self):
self.prof.wait_async_op()
COL_COMM_PROF = CommProfiler()
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce
def enable_communication_prof(depth: int = 0):
COL_COMM_PROF.depth = 3 + depth
dist.all_reduce = all_reduce
dist.all_gather = all_gather
dist.reduce_scatter = reduce_scatter
dist.broadcast = broadcast
dist.reduce = reduce
def communication_prof_show():
COL_COMM_PROF.show()
def async_check():
if COL_COMM_PROF.pending_op is not None:
COL_COMM_PROF.warn_flag = True
COL_COMM_PROF.wait_async_op()
def async_check(profiler: CommProfiler):
if profiler.pending_op is not None:
profiler.warn_flag = True
profiler.wait_async_op()
def all_reduce(tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
correction = 2 * (comm_size - 1) / comm_size
comm_vol = correction * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)
profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
if async_op:
return CommHandler()
return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group)
profiler.close_profiler(group)
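The correction factor of 2 * (comm_size - 1) / comm_size above matches ring all-reduce traffic, where each rank sends and receives the tensor data twice minus one chunk (a reduce-scatter phase followed by an all-gather phase). A small illustrative check, not part of the patch, using an assumed 4 MB float32 tensor on 4 ranks:

comm_size = 4
numel, element_size = 1024 * 1024, 4            # assumed 4 MB float32 tensor
correction = 2 * (comm_size - 1) / comm_size    # ring all-reduce factor: 1.5
comm_vol = correction * element_size * numel    # 6 MB counted for this call
print('{:.1f} MB'.format(comm_vol / 2**20))     # -> 6.0 MB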
def reduce_scatter(output: torch.Tensor,
input_list: List[torch.Tensor],
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
correction = (comm_size - 1) / comm_size
@@ -242,20 +265,21 @@ def reduce_scatter(output: torch.Tensor,
for tensor in input_list:
comm_vol += tensor.element_size() * tensor.numel()
comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
if async_op:
return CommHandler()
return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group)
profiler.close_profiler(group)
def all_gather(tensor_list: List[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_size = dist.get_world_size(group)
correction = (comm_size - 1) / comm_size
@@ -263,40 +287,45 @@ def all_gather(tensor_list: List[torch.Tensor],
for ten in tensor_list:
comm_vol += ten.element_size() * ten.numel()
comm_vol *= correction
COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
if async_op:
return CommHandler()
return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group)
profiler.close_profiler(group)
def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
async_check()
def broadcast(tensor: torch.Tensor,
src: int,
group=None,
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)
profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
if async_op:
return CommHandler()
return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group)
profiler.close_profiler(group)
def reduce(tensor: torch.Tensor,
dst: int,
op: ReduceOp = ReduceOp.SUM,
group=None,
async_op: bool = False) -> Optional[CommHandler]:
async_check()
async_op: bool = False,
profiler: CommProfiler = None) -> Optional[CommHandler]:
async_check(profiler)
comm_vol = 1.0 * tensor.element_size() * tensor.numel()
COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)
profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
if async_op:
return CommHandler()
return CommHandler(profiler)
COL_COMM_PROF.close_profiler(group)
profiler.close_profiler(group)
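With this change, CommProfiler installs itself in enable() by monkey-patching the torch.distributed collectives with functools.partial wrappers that carry the profiler instance, and disable() restores the saved originals, replacing the old module-level COL_COMM_PROF singleton. A minimal, self-contained sketch of that interception pattern (the SimpleNamespace stand-in for dist and the MiniCommProfiler name are illustrative, not part of the patch):

from functools import partial
from types import SimpleNamespace

# Stand-in for the torch.distributed module so the sketch runs without a process group.
dist = SimpleNamespace(all_reduce=lambda tensor, async_op=False: tensor)

# Keep a reference to the original collective so disable() can restore it,
# mirroring torch_all_reduce = dist.all_reduce above.
torch_all_reduce = dist.all_reduce

class MiniCommProfiler:
    """Counts how many collective calls are routed through the wrapper."""

    def __init__(self):
        self.total_count = 0

    def enable(self):
        # Bind this profiler into the wrapper via functools.partial, then swap it in.
        dist.all_reduce = partial(profiled_all_reduce, profiler=self)

    def disable(self):
        dist.all_reduce = torch_all_reduce

def profiled_all_reduce(tensor, async_op=False, profiler=None):
    profiler.total_count += 1
    return torch_all_reduce(tensor, async_op=async_op)

prof = MiniCommProfiler()
prof.enable()
dist.all_reduce([1.0, 2.0])    # intercepted and counted
prof.disable()
dist.all_reduce([1.0, 2.0])    # original path, not counted
assert prof.total_count == 1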

View File

@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, List
from colossalai.core import global_context as gpc
class BaseProfiler(ABC):
def __init__(self, profiler_name: str, priority: int):
self.name = profiler_name
self.priority = priority
@abstractmethod
def enable(self):
pass
@abstractmethod
def disable(self):
pass
@abstractmethod
def to_tensorboard(self, writer):
pass
@abstractmethod
def to_file(self, filename: Path):
pass
@abstractmethod
def show(self):
pass
class ProfilerContext(object):
"""
Profiler context manager
Usage:
from colossalai.utils.profiler import CommProfiler, ProfilerContext
from torch.utils.tensorboard import SummaryWriter
cc_prof = CommProfiler()
with ProfilerContext([cc_prof]) as prof:
train()
writer = SummaryWriter('tb/path')
prof.to_tensorboard(writer)
prof.to_file('./prof_logs/')
prof.show()
"""
def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
self.enable = enable
self.profilers = sorted(profilers, key=lambda prof: prof.priority)
def __enter__(self):
if self.enable:
for prof in self.profilers:
prof.enable()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.enable:
for prof in self.profilers:
prof.disable()
def to_tensorboard(self, writer):
from torch.utils.tensorboard import SummaryWriter
assert isinstance(writer, SummaryWriter), \
f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
for prof in self.profilers:
prof.to_tensorboard(writer)
def to_file(self, log_dir: Union[str, Path]):
if isinstance(log_dir, str):
log_dir = Path(log_dir)
if not log_dir.exists():
log_dir.mkdir(parents=True, exist_ok=True)
for prof in self.profilers:
log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
prof.to_file(log_file)
def show(self):
for prof in self.profilers:
prof.show()
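Because ProfilerContext only relies on the BaseProfiler interface, additional profilers can run alongside CommProfiler under one context. A hedged sketch of a custom profiler built on that interface; WallClockProfiler is illustrative, and the import path for BaseProfiler is assumed from the module shown above (only CommProfiler and ProfilerContext are re-exported by the updated __init__.py):

import time
from pathlib import Path

from colossalai.utils.profiler import ProfilerContext
from colossalai.utils.profiler.prof_utils import BaseProfiler   # assumed module path

class WallClockProfiler(BaseProfiler):
    """Illustrative profiler: records wall-clock time spent inside the context."""

    def __init__(self):
        super().__init__(profiler_name="Wall_Clock", priority=10)
        self.start = None
        self.elapsed = 0.0

    def enable(self):
        self.start = time.time()

    def disable(self):
        self.elapsed = time.time() - self.start

    def to_tensorboard(self, writer):
        writer.add_text(tag="Wall Clock", text_string=self.result_str())

    def to_file(self, filename: Path):
        with open(filename, "w") as f:
            f.write(self.result_str())

    def show(self):
        print(self.result_str())

    def result_str(self) -> str:
        return "wall-clock time inside context: {:.3f} s".format(self.elapsed)

with ProfilerContext([WallClockProfiler()]) as prof:
    time.sleep(0.1)    # stand-in for real profiled work
prof.show()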

View File

@@ -1,40 +0,0 @@
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
def run_test(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
enable_communication_prof()
op = dist.all_reduce(inputs, async_op=True)
dist.all_gather(outputs_list, inputs)
op.wait()
dist.reduce_scatter(inputs, outputs_list)
dist.broadcast(inputs, 0)
dist.reduce(inputs, 0)
if rank == 0:
communication_prof_show()
def test_cc_prof():
world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_cc_prof()
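This standalone test is removed by the commit and its replacement is not shown in this diff. A hedged sketch of how the same collectives could be exercised through the new API (the imports follow the updated __init__.py above; the test layout, NCCL backend, and 4-GPU setup mirror the deleted file and are otherwise illustrative):

from functools import partial

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import CommProfiler, ProfilerContext

BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))

def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    # All collectives issued inside the context are intercepted by CommProfiler.
    with ProfilerContext([CommProfiler()]) as prof:
        op = dist.all_reduce(inputs, async_op=True)
        dist.all_gather(outputs_list, inputs)
        op.wait()
        dist.reduce_scatter(inputs, outputs_list)
        dist.broadcast(inputs, 0)
        dist.reduce(inputs, 0)

    if rank == 0:
        prof.show()

def test_cc_prof():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    test_cc_prof()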