mirror of https://github.com/hpcaitech/ColossalAI

Added profiler communication operations
Fixed bug for learning rate scheduler

Branch: pull/394/head
Parent: d275b98b7d
Commit: 73bff11288
@@ -1,17 +1,26 @@
 from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward,
-                  send_backward_recv_forward, send_backward,
-                  send_backward_recv_backward, send_forward_recv_backward,
-                  send_forward_backward_recv_forward_backward, recv_forward,
-                  recv_backward)
+from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
+                  send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
+                  recv_forward, recv_backward)
 from .ring import ring_forward
 from .utils import send_tensor_meta, recv_tensor_meta

 __all__ = [
-    'all_gather', 'reduce_scatter', 'all_reduce', 'broadcast', 'reduce',
-    'send_forward', 'send_forward_recv_forward',
-    'send_forward_backward_recv_forward_backward', 'send_backward',
-    'send_backward_recv_backward', 'send_backward_recv_forward',
-    'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
+    'all_gather',
+    'reduce_scatter',
+    'all_reduce',
+    'broadcast',
+    'reduce',
+    'send_forward',
+    'send_forward_recv_forward',
+    'send_forward_backward_recv_forward_backward',
+    'send_backward',
+    'send_backward_recv_backward',
+    'send_backward_recv_forward',
+    'send_forward_recv_backward',
+    'recv_backward',
+    'recv_forward',
+    'ring_forward',
+    'send_tensor_meta',
+    'recv_tensor_meta',
 ]
@@ -29,6 +29,7 @@ class LRSchedulerHook(MetricHook):
        self.store_lr_in_state = store_lr_in_state

    def after_hook_is_attached(self, trainer):
        self._check_metric_states_initialization(trainer)
        trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
                                                                      initial_lr=self.lr_scheduler.get_last_lr()[0])
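The fix seeds the LR metric with self.lr_scheduler.get_last_lr()[0]. For context, get_last_lr() is the standard accessor on PyTorch LR schedulers and returns one learning rate per parameter group; a minimal standalone sketch (the SGD/StepLR choice below is illustrative only, not taken from this commit):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# Illustrative parameter; any iterable of parameters works.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# get_last_lr() returns a list with one entry per parameter group,
# so [0] picks the first group, exactly as the hook does above.
print(scheduler.get_last_lr()[0])   # 0.1 before any scheduler.step()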
@@ -1,12 +1,9 @@
 from .activation_checkpoint import checkpoint

-from .common import (clip_grad_norm_fp32, conditional_context,
-                     copy_tensor_parallel_attributes, count_zeros_fp32,
-                     free_port, is_dp_rank_0, is_model_parallel_parameter,
-                     is_moe_parallel_parameter, is_no_pp_or_last_stage,
-                     is_tp_rank_0, is_using_ddp, is_using_pp,
-                     is_using_sequence, multi_tensor_applier,
-                     param_is_not_tensor_parallel_duplicate, print_rank_0,
+from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
+                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_moe_parallel_parameter,
+                     is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
+                     multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                      switch_virtual_pipeline_parallel_rank, sync_model_param)
 from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
@@ -0,0 +1 @@
+from .comm_profiler import enable_communication_prof, communication_prof_show
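These two functions are the entire public surface of the new profiler. A minimal usage sketch, assuming the processes have already been launched with colossalai.launch on an NCCL backend (the tensor shape is illustrative; the full multi-process flow is in the test added at the end of this commit):

import torch
import torch.distributed as dist

from colossalai.utils import get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show

# Patch the torch.distributed collectives with the profiled wrappers.
# depth controls how many extra stack frames are recorded per call site.
enable_communication_prof(depth=0)

tensor = torch.randn(1024, 1024, device=get_current_device())
dist.all_reduce(tensor)                             # profiled synchronously
handle = dist.broadcast(tensor, src=0, async_op=True)
handle.wait()                                       # CommHandler.wait() closes the profiler for the async op

if dist.get_rank() == 0:
    communication_prof_show()                       # per-call-site time, volume and bandwidth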
@@ -0,0 +1,302 @@
import inspect
import torch
from torch.autograd.profiler import profile
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from typing import List, Optional


def _get_code_location(depth: int):
    ret = ""
    length = len(inspect.stack())
    for i in range(3, min(length, depth + 1)):
        upper_frame = inspect.stack()[i]
        function_name = inspect.stack()[i - 1].function
        info = upper_frame.filename + "(" + str(upper_frame.lineno) + "): " + function_name + "\n"
        ret += info

    return ret


# copied from a newer version of PyTorch to support older versions
def _format_time(time_us):
    """Defines how to format time in FunctionEvent"""
    US_IN_SECOND = 1000.0 * 1000.0
    US_IN_MS = 1000.0
    if time_us >= US_IN_SECOND:
        return '{:.3f}s'.format(time_us / US_IN_SECOND)
    if time_us >= US_IN_MS:
        return '{:.3f}ms'.format(time_us / US_IN_MS)
    return '{:.3f}us'.format(time_us)


# copied from a newer version of PyTorch to support older versions
def _format_memory(nbytes):
    """Returns a formatted memory size string"""
    KB = 1024
    MB = 1024 * KB
    GB = 1024 * MB
    if abs(nbytes) >= GB:
        return '{:.2f} Gb'.format(nbytes * 1.0 / GB)
    elif abs(nbytes) >= MB:
        return '{:.2f} Mb'.format(nbytes * 1.0 / MB)
    elif abs(nbytes) >= KB:
        return '{:.2f} Kb'.format(nbytes * 1.0 / KB)
    else:
        return str(nbytes) + ' b'


def _format_bandwidth(volume: float, time_us: int):
    # convert bytes per microsecond into MB per second
    sec_div_mb = (1000.0 / 1024.0)**2
    mb_per_sec = volume / time_us * sec_div_mb

    if mb_per_sec >= 1024.0:
        return '{:.3f} Gb/s'.format(mb_per_sec / 1024.0)
    else:
        return '{:.3f} Mb/s'.format(mb_per_sec)


class CommEvent(object):
    """Communication event. Used for recording communication time and
    communication volume.
    """

    def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
        self.self_count = count
        self.self_comm_vol = comm_vol
        self.self_cuda_time = cuda_time

    def add(self, rhs):
        self.self_count += rhs.self_count
        self.self_comm_vol += rhs.self_comm_vol
        self.self_cuda_time += rhs.self_cuda_time


class CommProfiler(object):
    """Communication profiler. Records all communication events.
    """

    def __init__(self, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0, prof_depth: int = 3):
        super().__init__()
        self.total_count = total_count
        self.total_comm_vol = total_comm_vol
        self.total_cuda_time = total_cuda_time
        self.depth = prof_depth

        self.ops_record = dict()
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        self.warn_flag = False

    def reset(self):
        self.total_count = 0
        self.total_comm_vol = 0
        self.total_cuda_time = 0

        self.ops_record = dict()
        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None
        self.warn_flag = False

    def show(self):
        if self.warn_flag:
            print("Warning: multiple communication operations were in flight at the same time.\n"
                  "As a result, the profiling result is not accurate.")
        print("Collective communication profiling result:",
              "total cuda time: {}".format(_format_time(self.total_cuda_time)),
              "average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)),
              "total number of calls: {}".format(self.total_count),
              "All events:",
              sep='\n')

        show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
        for location, event in show_list:
            print(location,
                  "self cuda time: {}".format(_format_time(event.self_cuda_time)),
                  "{:.1f}% of total communication time".format(event.self_cuda_time / self.total_cuda_time * 100.0),
                  "self communication volume: {}".format(_format_memory(event.self_comm_vol)),
                  "average bandwidth: {}".format(_format_bandwidth(event.self_comm_vol, event.self_cuda_time)),
                  "number of calls: {}".format(event.self_count),
                  "--------------------",
                  sep='\n')

    @property
    def has_async_op(self):
        return self.pending_op is not None

    def activate_profiler(self, kn: str, vol: float):
        self.pending_metadata = (kn, _get_code_location(self.depth), vol)
        self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
        self.profiler.__enter__()

    def close_profiler(self, group=None):
        assert self.profiler is not None, "There is no running dist op"
        kernel_name, code_location, vol = self.pending_metadata
        self.profiler.__exit__(None, None, None)

        if self.profiler.enabled:
            assert_flag = 0
            current_comm_event = None
            events = self.profiler.function_events
            for event in events:
                if kernel_name in event.name:
                    assert assert_flag == 0, "Multiple dist ops have been called"
                    current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
                    assert_flag += 1

            assert current_comm_event is not None, "dist op has not been found"

            # take the minimum kernel time across ranks so stragglers do not inflate the result
            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
            torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
            current_comm_event.self_cuda_time = buffer.item()

            self.total_count += current_comm_event.self_count
            self.total_comm_vol += current_comm_event.self_comm_vol
            self.total_cuda_time += current_comm_event.self_cuda_time
            if code_location in self.ops_record:
                self.ops_record[code_location].add(current_comm_event)
            else:
                self.ops_record[code_location] = current_comm_event

        self.profiler = None
        self.pending_op = None
        self.pending_metadata = None

    def wait_async_op(self):
        if self.pending_op is not None:
            op = self.pending_op
            op.wait()
            self.close_profiler()


class CommHandler(object):
    """Communication handler. A dummy handler used to wait for async operations.
    """

    def __init__(self):
        super().__init__()
        self.prof = COL_COMM_PROF

    def wait(self):
        self.prof.wait_async_op()


COL_COMM_PROF = CommProfiler()
torch_all_reduce = dist.all_reduce
torch_all_gather = dist.all_gather
torch_reduce_scatter = dist.reduce_scatter
torch_broadcast = dist.broadcast
torch_reduce = dist.reduce


def enable_communication_prof(depth: int = 0):
    COL_COMM_PROF.depth = 3 + depth
    dist.all_reduce = all_reduce
    dist.all_gather = all_gather
    dist.reduce_scatter = reduce_scatter
    dist.broadcast = broadcast
    dist.reduce = reduce


def communication_prof_show():
    COL_COMM_PROF.show()


def async_check():
    if COL_COMM_PROF.pending_op is not None:
        COL_COMM_PROF.warn_flag = True
        COL_COMM_PROF.wait_async_op()


def all_reduce(tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               group=None,
               async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    correction = 2 * (comm_size - 1) / comm_size
    comm_vol = correction * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_AllReduce_", comm_vol)
    COL_COMM_PROF.pending_op = torch_all_reduce(tensor, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def reduce_scatter(output: torch.Tensor,
                   input_list: List[torch.Tensor],
                   op: ReduceOp = ReduceOp.SUM,
                   group=None,
                   async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    comm_vol = 0
    for tensor in input_list:
        comm_vol += tensor.element_size() * tensor.numel()
    comm_vol *= correction
    COL_COMM_PROF.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
    COL_COMM_PROF.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def all_gather(tensor_list: List[torch.Tensor],
               tensor: torch.Tensor,
               group=None,
               async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_size = dist.get_world_size(group)
    correction = (comm_size - 1) / comm_size
    comm_vol = 0
    for ten in tensor_list:
        comm_vol += ten.element_size() * ten.numel()
    comm_vol *= correction
    COL_COMM_PROF.activate_profiler("ncclKernel_AllGather_", comm_vol)
    COL_COMM_PROF.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def broadcast(tensor: torch.Tensor, src: int, group=None, async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_Broadcast_", comm_vol)
    COL_COMM_PROF.pending_op = torch_broadcast(tensor, src, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)


def reduce(tensor: torch.Tensor,
           dst: int,
           op: ReduceOp = ReduceOp.SUM,
           group=None,
           async_op: bool = False) -> Optional[CommHandler]:
    async_check()

    comm_vol = 1.0 * tensor.element_size() * tensor.numel()
    COL_COMM_PROF.activate_profiler("ncclKernel_Reduce_", comm_vol)
    COL_COMM_PROF.pending_op = torch_reduce(tensor, dst, op, group, async_op)

    if async_op:
        return CommHandler()

    COL_COMM_PROF.close_profiler(group)
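The comm_vol estimates above follow the usual ring-algorithm cost model: an all-reduce moves roughly 2(n-1)/n of the tensor per rank, while all-gather and reduce-scatter move roughly (n-1)/n of the total gathered payload. A small worked example with hypothetical sizes makes the correction factors concrete:

# Hypothetical example: 4 ranks, one 1024 x 1024 fp32 tensor (4 MiB) per rank.
world_size = 4
tensor_bytes = 1024 * 1024 * 4                 # numel * element_size

# all_reduce: each rank sends/receives about 2 * (n - 1) / n of its tensor
all_reduce_vol = 2 * (world_size - 1) / world_size * tensor_bytes
print(all_reduce_vol / 2**20)                  # 6.0 MiB

# all_gather / reduce_scatter: about (n - 1) / n of the total gathered payload
all_gather_vol = (world_size - 1) / world_size * (world_size * tensor_bytes)
print(all_gather_vol / 2**20)                  # 12.0 MiB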
@@ -0,0 +1,40 @@
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.utils.profiler import enable_communication_prof, communication_prof_show

BATCH_SIZE = 1024
D_MODEL = 1024
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=4)))


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    inputs = torch.randn(BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, BATCH_SIZE, D_MODEL, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    enable_communication_prof()

    op = dist.all_reduce(inputs, async_op=True)
    dist.all_gather(outputs_list, inputs)
    op.wait()
    dist.reduce_scatter(inputs, outputs_list)
    dist.broadcast(inputs, 0)
    dist.reduce(inputs, 0)

    if rank == 0:
        communication_prof_show()


def test_cc_prof():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_cc_prof()
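One design note: the profiled wrappers accumulate into the module-level COL_COMM_PROF singleton, so successive phases of a run are merged into one report. A hedged sketch of isolating phases with the reset() method defined above; reaching into the submodule for the singleton is my assumption here, it is not exported by the package __init__:

from colossalai.utils.profiler import comm_profiler, communication_prof_show

# ... run and profile the warm-up iterations ...
communication_prof_show()

# Clear the accumulated counters before profiling the steady-state iterations.
# COL_COMM_PROF is module-level state, not part of the public exports.
comm_profiler.COL_COMM_PROF.reset()

# ... run and profile the steady-state iterations ...
communication_prof_show()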