mirror of https://github.com/hpcaitech/ColossalAI
Added Profiler Context to manage all profilers (#340)
parent d0ae0f2215
commit 425bb0df3f
colossalai/utils/profiler/__init__.py
@@ -1 +1,2 @@
-from .comm_profiler import enable_communication_prof, communication_prof_show
+from .comm_profiler import CommProfiler
+from .prof_utils import ProfilerContext
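The package's public surface thus moves from free functions to classes. A minimal sketch of the new usage, assuming the ProfilerContext workflow documented in prof_utils.py below (train() is a hypothetical workload, not part of this commit):

    from colossalai.utils.profiler import CommProfiler, ProfilerContext

    cc_prof = CommProfiler()
    with ProfilerContext([cc_prof]) as prof:
        train()  # hypothetical workload issuing torch.distributed collectives
    prof.show()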
colossalai/utils/profiler/comm_profiler.py
@@ -1,9 +1,13 @@
 import inspect
+from pathlib import Path
+from functools import partial
 import torch
 from torch.autograd.profiler import profile
 import torch.distributed as dist
 from torch.distributed import ReduceOp
+import torch.utils.tensorboard as tb
 from colossalai.utils import get_current_device
+from .prof_utils import BaseProfiler
 from typing import List, Optional
 
 
@@ -57,6 +61,13 @@ def _format_bandwith(volme: float, time_us: int):
     return '{:.3f} MB/s'.format(mb_per_sec)
 
 
+torch_all_reduce = dist.all_reduce
+torch_all_gather = dist.all_gather
+torch_reduce_scatter = dist.reduce_scatter
+torch_broadcast = dist.broadcast
+torch_reduce = dist.reduce
+
+
 class CommEvent(object):
     """Communication Event. Used for communication time and communication
     volume recording.
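These module-level aliases capture the unpatched torch.distributed functions once at import time, so the wrappers below can delegate to the real collectives and disable() can restore them. A sketch of the pattern in isolation (traced_all_reduce is a hypothetical stand-in for the wrappers in this file):

    import torch.distributed as dist
    from functools import partial

    _orig_all_reduce = dist.all_reduce  # captured before any patching

    def traced_all_reduce(tensor, *args, profiler=None, **kwargs):
        # record metadata here, then delegate to the real collective
        return _orig_all_reduce(tensor, *args, **kwargs)

    dist.all_reduce = partial(traced_all_reduce, profiler=None)  # enable: install wrapper
    dist.all_reduce = _orig_all_reduce                           # disable: restore original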
@@ -73,16 +84,16 @@
         self.self_cuda_time += rhs.self_cuda_time
 
 
-class CommProfiler(object):
+class CommProfiler(BaseProfiler):
     """Communication profiler. Records all communication events.
     """
 
-    def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
-        super().__init__()
+    def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
+        super().__init__(profiler_name="Collective_Communication", priority=0)
+        self.depth = 3 + depth
         self.total_count = total_count
         self.total_comm_vol = total_comm_vol
         self.total_cuda_time = total_cuda_time
-        self.depth = prof_depth
 
         self.ops_record = dict()
         self.profiler = None
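The constructor folds the old prof_depth default into `3 + depth`. A plausible reading, not confirmed by this diff, is that three stack frames of the profiler's own machinery (the patched collective, the functools.partial wrapper, and the bookkeeping call) sit between the user's call site and wherever the module's `inspect` lookup runs, so callers only specify extra frames beyond that. A hypothetical sketch of such a lookup:

    import inspect

    def call_site(depth: int) -> str:
        # hypothetical helper: skip the profiler's own frames and report
        # the user-code location that issued the collective
        frame = inspect.stack()[depth]
        return f"{frame.filename}:{frame.lineno} ({frame.function})"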
@@ -101,27 +112,58 @@
         self.pending_metadata = None
         self.warn_flag = False
 
+    def enable(self):
+        dist.all_reduce = partial(all_reduce, profiler=self)
+        dist.all_gather = partial(all_gather, profiler=self)
+        dist.reduce_scatter = partial(reduce_scatter, profiler=self)
+        dist.broadcast = partial(broadcast, profiler=self)
+        dist.reduce = partial(reduce, profiler=self)
+
+    def disable(self):
+        dist.all_reduce = torch_all_reduce
+        dist.all_gather = torch_all_gather
+        dist.reduce_scatter = torch_reduce_scatter
+        dist.broadcast = torch_broadcast
+        dist.reduce = torch_reduce
+
+    def to_tensorboard(self, writer: tb.writer):
+        writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))
+
+    def to_file(self, filename: Path):
+        with open(filename, "w") as f:
+            f.write(self.result_list())
+
     def show(self):
+        print(self.result_list())
+
+    def result_list(self, sep: str = "\n"):
+        res = []
+
+        def append(s: str):
+            res.append(s)
+            res.append(sep)
+
         if self.warn_flag:
-            print("Warnning: there exists multiple communication operations in the same time.\n"
-                  "As a result, the profiling result is not accurate.")
-        print("Collective communication profiling result:",
-              "total cuda time: {}".format(_format_time(self.total_cuda_time)),
-              "average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)),
-              "total number of calls: {}".format(self.total_count),
-              "All events:",
-              sep='\n')
+            append("Warning: multiple communication operations were issued at the same time. As a result, "
+                   "the profiling result is not accurate.")
+
+        append("Collective communication profiling result:")
+        append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
+        append("average bandwidth: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
+        append("total number of calls: {}".format(self.total_count))
+        append("All events:\n----------------------------------------")
 
         show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
        for location, event in show_list:
-            print(location,
-                  "self cuda time: {}".format(_format_time(event.self_cuda_time)),
-                  "{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
-                  "self communication volme: {}".format(_format_memory(event.self_comm_vol)),
-                  "average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)),
-                  "number of calls: {}".format(event.self_count),
-                  "--------------------",
-                  sep='\n')
+            append(location)
+            append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
+            append("{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0))
+            append("self communication volume: {}".format(_format_memory(event.self_comm_vol)))
+            append("average bandwidth: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
+            append("number of calls: {}".format(event.self_count))
+            append("----------------------------------------")
+
+        return ''.join(res)
 
     @property
     def has_aync_op(self):
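result_list() now builds the report text once and every sink reuses it; to_tensorboard passes "\n\n" as the separator, presumably because SummaryWriter.add_text renders markdown, where a blank line is needed to start a new paragraph. A usage sketch (paths are placeholders):

    from pathlib import Path
    from torch.utils.tensorboard import SummaryWriter
    from colossalai.utils.profiler import CommProfiler

    prof = CommProfiler()
    prof.enable()
    # ... run torch.distributed collectives ...
    prof.disable()

    prof.show()                                    # print to stdout, "\n"-separated
    prof.to_file(Path('./comm_prof.log'))          # same text, written to disk
    prof.to_tensorboard(SummaryWriter('tb/path'))  # "\n\n"-separated markdown text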
@@ -176,65 +218,46 @@ class CommHandler(object):
     """Communication handler. A dummy handler to wait for async operations.
     """
 
-    def __init__(self):
+    def __init__(self, profiler: CommProfiler):
         super().__init__()
-        self.prof = COL_COMM_PROF
+        self.prof = profiler
 
     def wait(self):
         self.prof.wait_async_op()
 
 
-COL_COMM_PROF = CommProfiler()
-torch_all_reduce = dist.all_reduce
-torch_all_gather = dist.all_gather
-torch_reduce_scatter = dist.reduce_scatter
-torch_broadcast = dist.broadcast
-torch_reduce = dist.reduce
-
-
-def enable_communication_prof(depth: int = 0):
-    COL_COMM_PROF.depth = 3 + depth
-    dist.all_reduce = all_reduce
-    dist.all_gather = all_gather
-    dist.reduce_scatter = reduce_scatter
-    dist.broadcast = broadcast
-    dist.reduce = reduce
-
-
-def communication_prof_show():
-    COL_COMM_PROF.show()
-
-
-def async_check():
-    if COL_COMM_PROF.pending_op is not None:
-        COL_COMM_PROF.warn_flag = True
-        COL_COMM_PROF.wait_async_op()
+def async_check(profiler: CommProfiler):
+    if profiler.pending_op is not None:
+        profiler.warn_flag = True
+        profiler.wait_async_op()
 
 
 def all_reduce(tensor: torch.Tensor,
                op: ReduceOp = ReduceOp.SUM,
                group=None,
-               async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+               async_op: bool = False,
+               profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = 2 * (comm_size - 1) / comm_size
     comm_vol = correction * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
+    profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def reduce_scatter(output: torch.Tensor,
                    input_list: List[torch.Tensor],
                    op: ReduceOp = ReduceOp.SUM,
                    group=None,
-                   async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+                   async_op: bool = False,
+                   profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = (comm_size - 1) / comm_size
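With async_op=True the patched collective hands back a CommHandler bound to the profiler rather than torch's work handle, so calling wait() also finalizes the timing record; async_check() flags the case where a new collective starts while another is still pending, which is what sets warn_flag above. A sketch of the resulting call pattern (the tensor and process group setup are assumed):

    import torch
    import torch.distributed as dist

    tensor = torch.ones(1024, device='cuda')         # assumes an initialized NCCL group
    handle = dist.all_reduce(tensor, async_op=True)  # patched: returns a CommHandler
    # ... overlap independent computation here ...
    handle.wait()  # waits on the real op and closes its profiler region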
@@ -242,20 +265,21 @@ def reduce_scatter(output: torch.Tensor,
     for tensor in input_list:
         comm_vol += tensor.element_size() * tensor.numel()
     comm_vol *= correction
-    COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
+    profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def all_gather(tensor_list: List[torch.Tensor],
                tensor: torch.Tensor,
                group=None,
-               async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+               async_op: bool = False,
+               profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_size = dist.get_world_size(group)
     correction = (comm_size - 1) / comm_size
@@ -263,40 +287,45 @@ def all_gather(tensor_list: List[torch.Tensor],
     for ten in tensor_list:
         comm_vol += ten.element_size() * ten.numel()
     comm_vol *= correction
-    COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
+    profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
+    profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
-def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+def broadcast(tensor: torch.Tensor,
+              src: int,
+              group=None,
+              async_op: bool = False,
+              profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_vol = 1.0 * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)
+    profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
+    profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
+    profiler.close_profiler(group)
 
 
 def reduce(tensor: torch.Tensor,
            dst: int,
            op: ReduceOp = ReduceOp.SUM,
           group=None,
-           async_op: bool = False) -> Optional[CommHandler]:
-    async_check()
+           async_op: bool = False,
+           profiler: CommProfiler = None) -> Optional[CommHandler]:
+    async_check(profiler)
 
     comm_vol = 1.0 * tensor.element_size() * tensor.numel()
-    COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
-    COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)
+    profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
+    profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
 
     if async_op:
-        return CommHandler()
+        return CommHandler(profiler)
 
-    COL_COMM_PROF.close_profiler(group)
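The correction factors match the per-rank traffic of ring algorithms: all-reduce moves 2(n-1)/n of the payload per rank, reduce-scatter and all-gather (n-1)/n, while broadcast and reduce are counted as the full tensor. A worked example with assumed sizes:

    # Worked example (assumed sizes): 4 ranks, a 1024 x 1024 fp32 tensor.
    n = 4
    payload = 1024 * 1024 * 4                       # element_size() * numel() = 4 MiB

    all_reduce_vol = 2 * (n - 1) / n * payload      # 6 MiB per rank
    reduce_scatter_vol = (n - 1) / n * payload      # 3 MiB per rank (all_gather is the same)
    broadcast_vol = 1.0 * payload                   # 4 MiB, the full tensor (reduce is the same)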
colossalai/utils/profiler/prof_utils.py (new file)
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Union, List
+from colossalai.core import global_context as gpc
+
+
+class BaseProfiler(ABC):
+
+    def __init__(self, profiler_name: str, priority: int):
+        self.name = profiler_name
+        self.priority = priority
+
+    @abstractmethod
+    def enable(self):
+        pass
+
+    @abstractmethod
+    def disable(self):
+        pass
+
+    @abstractmethod
+    def to_tensorboard(self, writer):
+        pass
+
+    @abstractmethod
+    def to_file(self, filename: Path):
+        pass
+
+    @abstractmethod
+    def show(self):
+        pass
+
+
+class ProfilerContext(object):
+    """
+    Profiler context manager
+    Usage:
+    from colossalai.utils.profiler import CommProfiler, ProfilerContext
+    from torch.utils.tensorboard import SummaryWriter
+    cc_prof = CommProfiler()
+    with ProfilerContext([cc_prof]) as prof:
+        train()
+    writer = SummaryWriter('tb/path')
+    prof.to_tensorboard(writer)
+    prof.to_file('./prof_logs/')
+    prof.show()
+    """
+
+    def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
+        self.enable = enable
+        self.profilers = sorted(profilers, key=lambda prof: prof.priority)
+
+    def __enter__(self):
+        if self.enable:
+            for prof in self.profilers:
+                prof.enable()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.enable:
+            for prof in self.profilers:
+                prof.disable()
+
+    def to_tensorboard(self, writer):
+        from torch.utils.tensorboard import SummaryWriter
+
+        assert isinstance(writer, SummaryWriter), \
+            f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
+
+        for prof in self.profilers:
+            prof.to_tensorboard(writer)
+
+    def to_file(self, log_dir: Union[str, Path]):
+        if isinstance(log_dir, str):
+            log_dir = Path(log_dir)
+
+        if not log_dir.exists():
+            log_dir.mkdir(parents=True, exist_ok=True)
+        for prof in self.profilers:
+            log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
+            prof.to_file(log_file)
+
+    def show(self):
+        for prof in self.profilers:
+            prof.show()
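ProfilerContext sorts its profilers by ascending priority, so lower values are enabled first; CommProfiler registers itself with priority=0. A sketch with a second, purely hypothetical profiler to illustrate the ordering:

    from colossalai.utils.profiler import CommProfiler, ProfilerContext
    from colossalai.utils.profiler.prof_utils import BaseProfiler

    class MemProfiler(BaseProfiler):  # hypothetical subclass for illustration
        def __init__(self):
            super().__init__(profiler_name='Memory', priority=10)

        def enable(self): ...
        def disable(self): ...
        def to_tensorboard(self, writer): ...
        def to_file(self, filename): ...
        def show(self): ...

    with ProfilerContext([MemProfiler(), CommProfiler()]) as prof:
        train()  # hypothetical workload; CommProfiler (priority 0) is enabled before MemProfiler (priority 10)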
(deleted test for the old free-function API)
@@ -1,40 +0,0 @@
-from functools import partial
-import torch
-import torch.multiprocessing as mp
-import torch.distributed as dist
-import colossalai
-from colossalai.utils import free_port, get_current_device
-from colossalai.utils.profiler import enable_communication_prof, communication_prof_show
-
-BATCH_SIZE = 1024
-D_MODEL = 1024
-CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))
-
-
-def run_test(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
-    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
-    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
-
-    enable_communication_prof()
-
-    op = dist.all_reduce(inputs, async_op=True)
-    dist.all_gather(outputs_list, inputs)
-    op.wait()
-    dist.reduce_scatter(inputs, outputs_list)
-    dist.broadcast(inputs, 0)
-    dist.reduce(inputs, 0)
-
-    if rank == 0:
-        communication_prof_show()
-
-
-def test_cc_prof():
-    world_size = 4
-    run_func = partial(run_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_cc_prof()
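The deleted test exercised the removed free functions and is not replaced in this commit. A hedged sketch of the same check ported to the new class-based API (not part of this commit):

    from functools import partial
    import torch
    import torch.multiprocessing as mp
    import torch.distributed as dist
    import colossalai
    from colossalai.utils import free_port, get_current_device
    from colossalai.utils.profiler import CommProfiler

    BATCH_SIZE = 1024
    D_MODEL = 1024
    CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))


    def run_test(rank, world_size, port):
        colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')
        inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())

        prof = CommProfiler()
        prof.enable()            # patches the torch.distributed collectives
        dist.all_reduce(inputs)
        dist.broadcast(inputs, 0)
        prof.disable()           # restores the originals

        if rank == 0:
            prof.show()


    def test_cc_prof():
        world_size = 4
        mp.spawn(partial(run_test, world_size=world_size, port=free_port()), nprocs=world_size)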