
293 lines
10 KiB
Raw Normal View History

import inspect
from pathlib import Path
from functools import partial
import torch
from torch.autograd.profiler import profile
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwith
from typing import List, Optional
def _get_code_location(depth: int):
ret = ""
length = len(inspect.stack())
for i in range(3, min(length, depth + 1)):
upper_frame = inspect.stack()[i]
function_name = inspect.stack()[i - 1].function
info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
ret += info
return ret
# Handles to the collectives as they exist in torch.distributed when this
# module is imported. CommProfiler.disable() assigns these back onto
# torch.distributed, and the profiling wrappers call them to do the real work.
(
    torch_all_reduce,
    torch_all_gather,
    torch_reduce_scatter,
    torch_broadcast,
    torch_reduce,
) = (
    dist.all_reduce,
    dist.all_gather,
    dist.reduce_scatter,
    dist.broadcast,
    dist.reduce,
)
class CommEvent(object):
    """Communication Event. Used for communication time and communication
    volume recording.

    Args:
        count: number of calls represented by this event.
        comm_vol: communication volume attributed to this event.
        cuda_time: CUDA time attributed to this event.
    """

    def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
        self.self_count = count
        self.self_comm_vol = comm_vol
        self.self_cuda_time = cuda_time

    def add(self, rhs: "CommEvent") -> None:
        """Accumulate the counters of ``rhs`` into this event in place."""
        self.self_count += rhs.self_count
        self.self_comm_vol += rhs.self_comm_vol
        self.self_cuda_time += rhs.self_cuda_time
class CommProfiler(BaseProfiler):
    """Communication profiler. Records all communication events.

    While enabled, ``all_reduce``, ``all_gather``, ``reduce_scatter``,
    ``broadcast`` and ``reduce`` on ``torch.distributed`` are replaced by the
    profiling wrappers of this module, bound to this profiler instance. Each
    wrapped call is measured with ``torch.autograd.profiler.profile`` and
    accumulated both into global totals and into a per-call-site record.

    Args:
        depth: extra stack depth used when describing the call site of a
            collective (added to a fixed offset of 3).
        total_count: initial number of recorded calls.
        total_comm_vol: initial accumulated communication volume.
        total_cuda_time: initial accumulated CUDA time.
    """

    def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
        super().__init__(profiler_name="Collective_Communication", priority=0)
        self.depth = 3 + depth
        self.total_count = total_count
        self.total_comm_vol = total_comm_vol
        self.total_cuda_time = total_cuda_time

        # call-site description -> accumulated CommEvent
        self.ops_record = dict()
        # State of the collective currently being measured (if any).
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        # Set when a collective was issued while another one was still pending.
        self.warn_flag = False

    def reset(self):
        """Drop all recorded statistics and any in-flight measurement state."""
        self.total_count = 0
        self.total_comm_vol = 0
        self.total_cuda_time = 0

        self.ops_record = dict()
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        self.warn_flag = False

    def enable(self):
        """Replace the torch.distributed collectives with profiling wrappers."""
        dist.all_reduce = partial(all_reduce, profiler=self)
        dist.all_gather = partial(all_gather, profiler=self)
        dist.reduce_scatter = partial(reduce_scatter, profiler=self)
        dist.broadcast = partial(broadcast, profiler=self)
        dist.reduce = partial(reduce, profiler=self)

    def disable(self):
        """Restore the torch.distributed collectives saved at import time."""
        dist.all_reduce = torch_all_reduce
        dist.all_gather = torch_all_gather
        dist.reduce_scatter = torch_reduce_scatter
        dist.broadcast = torch_broadcast
        dist.reduce = torch_reduce

    def to_tensorboard(self, writer):
        """Publish the textual report through ``writer.add_text``."""
        writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))

    def to_file(self, filename: Path):
        """Write the textual report to ``filename``, overwriting it."""
        with open(filename, "w") as f:
            f.write(self.result_list())

    def show(self):
        """Print the textual report to standard output."""
        print(self.result_list())

    def result_list(self, sep: str = "\n") -> str:
        """Build the textual profiling report.

        Args:
            sep: separator emitted after every report line.

        Returns:
            The report as one string; events are listed by descending CUDA
            time.
        """
        res = []

        def append(s: str) -> None:
            # Every report line is followed by the separator so the final
            # ''.join yields a fully delimited report.
            res.append(s)
            res.append(sep)

        if self.warn_flag:
            append("Warnning: there exists multiple communication operations in the same time. As a result, "
                   "the profiling result is not accurate.")

        append("Collective communication profiling result:")
        append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
        append("average bandwith: {}".format(_format_bandwith(self.total_comm_vol, self.total_cuda_time)))
        append("total number of calls: {}".format(self.total_count))
        append("All events:\n----------------------------------------")

        show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
        for location, event in show_list:
            # Guard the share computation: the total may legitimately be zero.
            if self.total_cuda_time:
                share = event.self_cuda_time / self.total_cuda_time * 100.0
            else:
                share = 0.0
            append(location)
            append("self cuda time: {}".format(_format_time(event.self_cuda_time)))
            append("{:.1f}% of total communication time".format(share))
            append("self communication volme: {}".format(_format_memory(event.self_comm_vol)))
            append("average bandwith: {}".format(_format_bandwith(event.self_comm_vol, event.self_cuda_time)))
            append("number of calls: {}".format(event.self_count))
            append("----------------------------------------")

        return ''.join(res)

    # NOTE(review): no decorator is visible on this method in the source;
    # confirm whether callers expect a plain method or a property.
    def has_aync_op(self):
        """Return True while an asynchronous collective is still pending."""
        return self.pending_op is not None

    def activate_profiler(self, kn: str, vol: float):
        """Start measuring one collective.

        Args:
            kn: substring identifying the collective's kernel among the
                profiler events.
            vol: communication volume to attribute to this call.
        """
        self.pending_metadata = (kn, _get_code_location(self.depth), vol)
        self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
        # close_profiler() calls __exit__, so the context must be entered here.
        self.profiler.__enter__()

    def close_profiler(self, group=None):
        """Stop the running measurement and fold it into the statistics.

        Args:
            group: process group the collective ran on; used for the world
                size check and for agreeing on the reported time.

        Raises:
            AssertionError: if no measurement is running, or if the profiler
                saw zero or more than one matching kernel event.
        """
        assert self.profiler is not None, "There is no running dist op"
        kernel_name, code_location, vol = self.pending_metadata
        self.profiler.__exit__(None, None, None)

        if self.profiler.enabled and dist.get_world_size(group) > 1:
            # NOTE(review): the membership test was truncated in the source;
            # matching against the event name is assumed -- confirm.
            matched = [event for event in self.profiler.function_events if kernel_name in event.name]
            assert len(matched) <= 1, "Multiple dist ops has been called "
            assert matched, "dist op has not been found"
            current_comm_event = CommEvent(1, vol, matched[0].self_cuda_time_total)

            # Use the minimum time across the ranks of the group as the
            # reported time of this call.
            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
            torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
            current_comm_event.self_cuda_time = buffer.item()

            self.total_count += current_comm_event.self_count
            self.total_comm_vol += current_comm_event.self_comm_vol
            self.total_cuda_time += current_comm_event.self_cuda_time

            # Accumulate per call site; create the record on first sight.
            recorded = self.ops_record.get(code_location)
            if recorded is None:
                self.ops_record[code_location] = current_comm_event
            else:
                recorded.add(current_comm_event)

        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None

    def wait_async_op(self):
        """Wait for the pending asynchronous collective, then close its measurement."""
        if self.pending_op is not None:
            op = self.pending_op
            op.wait()
            # NOTE(review): the process group of the pending op is not stored,
            # so the default group is used here -- confirm this is intended.
            self.close_profiler()
class CommHandler(object):
    """Communication handler. A dummy handler to wait aync operations.

    Returned by the profiling wrappers in place of the work handle of an
    asynchronous collective, so that waiting on it goes through the profiler.

    Args:
        profiler: the profiler that owns the pending asynchronous operation.
    """

    def __init__(self, profiler: "CommProfiler"):
        super().__init__()
        self.prof = profiler

    def wait(self):
        """Delegate to the profiler's ``wait_async_op``."""
        self.prof.wait_async_op()
def async_check(profiler: "CommProfiler"):
    """Flag and drain a still-pending asynchronous collective.

    If ``profiler`` has a pending operation, its ``warn_flag`` is set (the
    report then warns that results are inaccurate) and the pending operation
    is waited for, so its measurement is closed before a new one starts.

    Args:
        profiler: the profiler whose pending state is checked.
    """
    if profiler.pending_op is None:
        return
    profiler.warn_flag = True
    # Without draining, the next activate_profiler() would overwrite the
    # pending profiler and metadata of the unfinished operation.
    profiler.wait_async_op()
def all_reduce(tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               group=None,
               async_op: bool = False,
               profiler: "CommProfiler" = None) -> Optional["CommHandler"]:
    """Profiling wrapper around the saved ``torch.distributed.all_reduce``.

    Args:
        tensor: tensor to reduce across the group.
        op: reduction operator.
        group: process group to operate on; passed through to torch.
        async_op: whether to issue the collective asynchronously.
        profiler: the CommProfiler recording this call (bound by
            ``CommProfiler.enable`` via ``functools.partial``).

    Returns:
        A CommHandler to wait on when ``async_op`` is true, otherwise None.
    """
    async_check(profiler)

    comm_size = dist.get_world_size(group)
    # Volume scaling of 2 * (n - 1) / n per rank -- presumably the ring
    # all-reduce traffic factor.
    correction = 2 * (comm_size - 1) / comm_size
    comm_vol = correction * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
    profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)

    if async_op:
        return CommHandler(profiler)

    # Synchronous call: the measurement is finalized right away.
    profiler.close_profiler(group)
    return None
def reduce_scatter(output: torch.Tensor,
                   input_list: List[torch.Tensor],
                   op: ReduceOp = ReduceOp.SUM,
                   group=None,
                   async_op: bool = False,
                   profiler: "CommProfiler" = None) -> Optional["CommHandler"]:
    """Profiling wrapper around the saved ``torch.distributed.reduce_scatter``.

    Args:
        output: tensor receiving this rank's result.
        input_list: tensors to reduce and scatter.
        op: reduction operator.
        group: process group to operate on; passed through to torch.
        async_op: whether to issue the collective asynchronously.
        profiler: the CommProfiler recording this call (bound by
            ``CommProfiler.enable`` via ``functools.partial``).

    Returns:
        A CommHandler to wait on when ``async_op`` is true, otherwise None.
    """
    async_check(profiler)

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    # Total bytes of all inputs, scaled by (n - 1) / n.
    comm_vol = sum(t.element_size() * t.numel() for t in input_list)
    comm_vol *= correction
    profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
    profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)

    if async_op:
        return CommHandler(profiler)

    # Synchronous call: the measurement is finalized right away.
    profiler.close_profiler(group)
    return None
def all_gather(tensor_list: List[torch.Tensor],
               tensor: torch.Tensor,
               group=None,
               async_op: bool = False,
               profiler: "CommProfiler" = None) -> Optional["CommHandler"]:
    """Profiling wrapper around the saved ``torch.distributed.all_gather``.

    Args:
        tensor_list: output tensors, one per rank.
        tensor: this rank's contribution.
        group: process group to operate on; passed through to torch.
        async_op: whether to issue the collective asynchronously.
        profiler: the CommProfiler recording this call (bound by
            ``CommProfiler.enable`` via ``functools.partial``).

    Returns:
        A CommHandler to wait on when ``async_op`` is true, otherwise None.
    """
    async_check(profiler)

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    # Total bytes of all gathered outputs, scaled by (n - 1) / n.
    comm_vol = sum(t.element_size() * t.numel() for t in tensor_list)
    comm_vol *= correction
    profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
    profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)

    if async_op:
        return CommHandler(profiler)

    # Synchronous call: the measurement is finalized right away.
    profiler.close_profiler(group)
    return None
def broadcast(tensor: torch.Tensor,
              src: int,
              group=None,
              async_op: bool = False,
              profiler: "CommProfiler" = None) -> Optional["CommHandler"]:
    """Profiling wrapper around the saved ``torch.distributed.broadcast``.

    Args:
        tensor: tensor to send (on ``src``) or to receive into.
        src: source rank.
        group: process group to operate on; passed through to torch.
        async_op: whether to issue the collective asynchronously.
        profiler: the CommProfiler recording this call (bound by
            ``CommProfiler.enable`` via ``functools.partial``).

    Returns:
        A CommHandler to wait on when ``async_op`` is true, otherwise None.
    """
    async_check(profiler)

    # The full tensor size in bytes is attributed to the call.
    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
    profiler.pending_op = torch_broadcast(tensor, src, group, async_op)

    if async_op:
        return CommHandler(profiler)

    # Synchronous call: the measurement is finalized right away.
    profiler.close_profiler(group)
    return None
def reduce(tensor: torch.Tensor,
           dst: int,
           op: ReduceOp = ReduceOp.SUM,
           group=None,
           async_op: bool = False,
           profiler: "CommProfiler" = None) -> Optional["CommHandler"]:
    """Profiling wrapper around the saved ``torch.distributed.reduce``.

    Args:
        tensor: tensor to reduce; the result lands on ``dst``.
        dst: destination rank.
        op: reduction operator.
        group: process group to operate on; passed through to torch.
        async_op: whether to issue the collective asynchronously.
        profiler: the CommProfiler recording this call (bound by
            ``CommProfiler.enable`` via ``functools.partial``).

    Returns:
        A CommHandler to wait on when ``async_op`` is true, otherwise None.
    """
    async_check(profiler)

    # The full tensor size in bytes is attributed to the call.
    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
    profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)

    if async_op:
        return CommHandler(profiler)

    # Synchronous call: the measurement is finalized right away.
    profiler.close_profiler(group)
    return None