pull/540/head
lijiaxing 2023-12-18 21:37:17 +08:00
parent f68f34234d
commit d9c9f7c9ee
3 changed files with 50 additions and 22 deletions


@@ -100,8 +100,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._comm_bcast_stream = torch.cuda.Stream()
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
@@ -859,16 +857,21 @@ class HybridZeroOptimizer(BaseOptimizer):
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
async_op=True,
)
if self._overlap_sync_param:
handle = dict()
handle["tensor"] = fp16_param
handle["src"] = g_rank
handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
handle["async_op"] = True
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
async_op=True,
)
handles.append(handle)
for handle in handles:
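
A minimal sketch of the record-and-replay pattern introduced above: when _overlap_sync_param is enabled, the broadcast arguments are stored as a kwargs dict and handed to the ParamBcastSyncHandler via add_bcast_handle instead of being launched immediately; the handler later replays them with dist.broadcast(**kwargs). The helper names below (defer_broadcast, launch_pending) are illustrative and do not appear in the commit.

import torch.distributed as dist

def defer_broadcast(pending, rank, tensor, src, group):
    # Record the broadcast arguments instead of launching the collective now.
    pending.setdefault(rank, []).append(
        {"tensor": tensor, "src": src, "group": group, "async_op": True}
    )

def launch_pending(pending, rank):
    # Replay the recorded kwargs; each async broadcast returns a work handle.
    return [dist.broadcast(**meta) for meta in pending.pop(rank, [])]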


@@ -803,6 +803,8 @@ class ParamBcastSyncHandler:
self._param_to_rank = dict() # <key: param> <value: rank>
self._block_to_rank = dict() # <key: nn.Module> <value: rank>
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles)>
self._block_next_block = dict() # <key: nn.Module> <value: nn.Module>
self._block_to_handles = dict() # <key: nn.Module> <value: list(bcast handles)>
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
total_param_num = sum(p.numel() for p in model.parameters())
@@ -824,10 +826,18 @@ class ParamBcastSyncHandler:
for _, block in enumerate(children):
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
self._block_to_param[block] = list(block.parameters())
key_list = list(self._block_to_param.keys())
if len(key_list) > 1:
up_layer = key_list[-2]
self._block_next_block[up_layer] = key_list[-1]
else:
# record the block that a parameter belongs to
# self._block_to_param[name] = list(children.parameters())
self._block_to_param[children] = list(children.parameters())
key_list = list(self._block_to_param.keys())
if len(key_list) > 1:
up_layer = key_list[-2]
self._block_next_block[up_layer] = key_list[-1]
alloc_num = 0
rank_to_go = 0
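
The key_list bookkeeping added above links each newly registered block to the previously registered one, producing a forward-order chain (_block_next_block) that is later used to decide which block's parameters to prefetch. A compact sketch of the same chaining idea, with illustrative names that are not part of the commit:

blocks = []            # blocks in registration (forward) order
block_next_block = {}  # maps a block to the block registered right after it

def register_block(block):
    # Link the previously registered block to this one, building a forward chain.
    if blocks:
        block_next_block[blocks[-1]] = block
    blocks.append(block)
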
@@ -857,16 +867,35 @@ class ParamBcastSyncHandler:
# register_forward_pre_hook for transformer/embedding/norm/xxx block
self._register_sync_parameters_hook()
def _launch_handle(self, layer):
handle_metas = []
for rank in self._block_to_rank[layer]:
handle_metas.extend(self._bcast_handles[rank])
# clear _bcast_handles for this rank since its broadcasts are launched below
self._bcast_handles[rank] = []
# launch the recorded broadcasts and keep their handles so the forward pre-hook can wait on them
handles = []
for handle_meta in handle_metas:
handle = dist.broadcast(**handle_meta)
handles.append(handle)
self._block_to_handles[layer] = handles
def _register_sync_parameters_hook(self) -> None:
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
bcast_handles = []
# gather all required broadcast handles into a list
for rank in self._block_to_rank[model]:
bcast_handles.extend(self._bcast_handles[rank])
# need to clear _bcast_handles since they would be processed later
self._bcast_handles[rank] = []
# wait for all required broadcast handles to complete
for handle in bcast_handles:
current_layer = model
next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
# if this is the first layer
# launch broadcast for current layer
if current_layer == list(self._block_to_param.keys())[0]:
self._launch_handle(current_layer)
# if this is not the last layer
# launch broadcast for next layer
if next_layer:
self._launch_handle(next_layer)
for handle in self._block_to_handles[current_layer]:
handle.wait()
# register_forward_pre_hook for transformer/embedding/norm/xxx block
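
Taken together, _launch_handle and the rewritten _pre_forward_hook implement a prefetch schedule: before block i runs its forward pass, the broadcasts for block i+1 are launched asynchronously and only block i's own handles are waited on, so the parameter broadcast of the next block overlaps with compute of the current one. A self-contained sketch of that schedule, assuming launch(block) starts the async broadcasts for a block and handles_of(block) returns them (both callables are illustrative stand-ins, not the commit's API):

def attach_prefetch_hooks(blocks, launch, handles_of):
    # blocks: list of nn.Module in forward order; launch/handles_of stand in
    # for _launch_handle and _block_to_handles in the commit.
    def make_hook(i):
        def hook(module, inputs):  # signature expected by register_forward_pre_hook
            if i == 0:
                launch(blocks[0])       # first block: nothing could have overlapped earlier
            if i + 1 < len(blocks):
                launch(blocks[i + 1])   # prefetch the next block's parameters
            for handle in handles_of(blocks[i]):
                handle.wait()           # this block's params must be ready now
        return hook

    for i, block in enumerate(blocks):
        block.register_forward_pre_hook(make_hook(i))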


@@ -5,8 +5,6 @@ import time
import torch
from internlm.core.context import global_context as gpc
class _Timer:
"""Timer."""
@@ -25,16 +23,14 @@ class _Timer:
megatron_timer.reset()
assert not self.started_, "timer has already been started"
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
self.stream.synchronize()
self.stream.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
self.stream.synchronize()
self.stream.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
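
Since the gpc.config.hybrid_zero_optimizer.overlap_sync_param guard (and the gpc import) is removed here, the timer now synchronizes its stream unconditionally before taking either timestamp, so measured intervals always include work still queued on that stream. A stripped-down sketch of the resulting timing path, assuming the timer holds the current CUDA stream (the stream's construction is not shown in this hunk, so that attribute is an assumption):

import time
import torch

class StreamTimer:
    # Illustrative stand-in for _Timer, keeping only the timing path.
    def __init__(self):
        self.stream = torch.cuda.current_stream()  # assumption: _Timer holds a CUDA stream like this
        self.start_time = None
        self.elapsed_ = 0.0

    def start(self):
        self.stream.synchronize()  # always drain queued GPU work before timing
        self.start_time = time.time()

    def stop(self):
        self.stream.synchronize()
        self.elapsed_ += time.time() - self.start_time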