mirror of https://github.com/InternLM/InternLM

fix

parent f68f34234d
commit d9c9f7c9ee

@@ -100,8 +100,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size

-        self._comm_bcast_stream = torch.cuda.Stream()
-
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,

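For context, the DynamicGradScaler constructed above implements dynamic loss scaling for the mixed-precision optimizer. The sketch below shows only the usual growth/backoff scheme; it is not InternLM's DynamicGradScaler, and the class name and default values are illustrative assumptions.

# Minimal dynamic loss-scaling sketch (illustrative, not InternLM's DynamicGradScaler).
class SimpleDynamicScaler:
    def __init__(self, initial_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000, min_scale=1.0):
        self.scale = initial_scale            # current loss scale
        self.growth_factor = growth_factor    # grow scale after enough clean steps
        self.backoff_factor = backoff_factor  # shrink scale when grads overflow
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self._good_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            # the step is skipped and the scale is reduced so fp16 grads stop overflowing
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor


scaler = SimpleDynamicScaler(initial_scale=2.0**16)
scaler.update(found_overflow=True)
print(scaler.scale)  # 32768.0
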
@@ -859,16 +857,21 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
-                handle = dist.broadcast(
-                    fp16_param,
-                    src=g_rank,
-                    group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
-                    async_op=True,
-                )

                 if self._overlap_sync_param:
+                    handle = dict()
+                    handle["tensor"] = fp16_param
+                    handle["src"] = g_rank
+                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
+                    handle["async_op"] = True
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
+                    handle = dist.broadcast(
+                        fp16_param,
+                        src=g_rank,
+                        group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        async_op=True,
+                    )
                     handles.append(handle)

         for handle in handles:

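The change above stops issuing dist.broadcast eagerly when overlap_sync_param is enabled: the broadcast arguments are recorded as a plain dict of kwargs and handed to the ParamBcastSyncHandler, which can launch the collective later with dist.broadcast(**handle_meta). The snippet below is a minimal, runnable sketch of that deferred-launch pattern using a single-process gloo group; the variable names (pending, meta) are illustrative, not InternLM API.

# Defer a broadcast by recording its kwargs, then launch it later as an async op.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

fp16_param = torch.ones(4)

# 1) Record the collective instead of launching it (what the overlap branch above does).
pending = []
meta = {"tensor": fp16_param, "src": 0, "group": dist.group.WORLD, "async_op": True}
pending.append(meta)

# 2) Later (e.g. from a forward pre-hook), launch the recorded broadcasts asynchronously.
handles = [dist.broadcast(**m) for m in pending]
pending.clear()

# 3) Wait only when the broadcast result is actually needed.
for handle in handles:
    handle.wait()

dist.destroy_process_group()
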
@@ -803,6 +803,8 @@ class ParamBcastSyncHandler:
         self._param_to_rank = dict()  # <key: param> <value: rank>
         self._block_to_rank = dict()  # <key: nn.Module> <value: rank>
         self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles)>
+        self._block_next_block = dict()  # <key: nn.Module> <value: nn.Module>
+        self._block_to_handles = dict()  # <key: nn.Module> <value: list(bcast handles)>

         zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
         total_param_num = sum(p.numel() for p in model.parameters())

@@ -824,10 +826,18 @@ class ParamBcastSyncHandler:
                 for _, block in enumerate(children):
                     # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                     self._block_to_param[block] = list(block.parameters())
+                    key_list = list(self._block_to_param.keys())
+                    if len(key_list) > 1:
+                        up_layer = key_list[-2]
+                        self._block_next_block[up_layer] = key_list[-1]
             else:
                 # record the block that a parameter belongs to
                 # self._block_to_param[name] = list(children.parameters())
                 self._block_to_param[children] = list(children.parameters())
+                key_list = list(self._block_to_param.keys())
+                if len(key_list) > 1:
+                    up_layer = key_list[-2]
+                    self._block_next_block[up_layer] = key_list[-1]

         alloc_num = 0
         rank_to_go = 0

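The key_list bookkeeping added above chains blocks in registration order: each time a new block is appended to _block_to_param, the previously appended block records it as its successor in _block_next_block, so every block later knows whose broadcasts to prefetch. A standalone toy illustration of that chaining follows; the model below is an assumption, not the real InternLM model.

# Toy illustration of the block-chaining logic built in __init__.
import torch.nn as nn

block_to_param = {}    # <key: nn.Module> <value: list of parameters>
block_next_block = {}  # <key: nn.Module> <value: block registered right after it>

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))

for block in model.children():
    block_to_param[block] = list(block.parameters())
    key_list = list(block_to_param.keys())
    if len(key_list) > 1:
        # the block registered just before the current one points to it
        block_next_block[key_list[-2]] = key_list[-1]

# Embedding -> Linear -> LayerNorm: each block knows which block runs next,
# so its parameter broadcasts can be launched one step ahead.
for cur, nxt in block_next_block.items():
    print(type(cur).__name__, "->", type(nxt).__name__)
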
@@ -857,16 +867,35 @@ class ParamBcastSyncHandler:
         # register_forward_pre_hook for transformer/embedding/norm/xxx block
         self._register_sync_parameters_hook()

+    def _launch_handle(self, layer):
+        handle_metas = []
+        for rank in self._block_to_rank[layer]:
+            handle_metas.extend(self._bcast_handles[rank])
+            # need to clear _bcast_handles since they would be processed later
+            self._bcast_handles[rank] = []
+        # launch an async broadcast for every recorded handle meta
+        handles = []
+        for handle_meta in handle_metas:
+            handle = dist.broadcast(**handle_meta)
+            handles.append(handle)
+        self._block_to_handles[layer] = handles
+
     def _register_sync_parameters_hook(self) -> None:
         def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
-            bcast_handles = []
-            # gather all required broadcast handles into a list
-            for rank in self._block_to_rank[model]:
-                bcast_handles.extend(self._bcast_handles[rank])
-                # need to clear _bcast_handles since they would be processed later
-                self._bcast_handles[rank] = []
-            # wait all required broadcast handles to be completed
-            for handle in bcast_handles:
+            current_layer = model
+            next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
+
+            # if this is the first layer,
+            # launch the broadcast for the current layer
+            if current_layer == list(self._block_to_param.keys())[0]:
+                self._launch_handle(current_layer)
+
+            # if this is not the last layer,
+            # launch the broadcast for the next layer
+            if next_layer:
+                self._launch_handle(next_layer)
+
+            for handle in self._block_to_handles[current_layer]:
                 handle.wait()

         # register_forward_pre_hook for transformer/embedding/norm/xxx block

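Taken together, the new hook turns parameter broadcasting into a one-block prefetch pipeline: the hook of the first block launches its own broadcasts, every hook launches the broadcasts for the next block, and the forward pass waits only on the handles of the block that is about to run. The schematic below shows the same pattern with register_forward_pre_hook, using a dummy handle class instead of dist.broadcast so it runs without a process group; all names in it are illustrative.

# Schematic of prefetching per-block work from a forward pre-hook (no real collectives).
import torch
import torch.nn as nn


class FakeHandle:
    """Stand-in for an async dist.broadcast work handle."""

    def __init__(self, name):
        self.name = name

    def wait(self):
        print(f"wait   broadcast({self.name})")


blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)])
names = {block: f"block{i}" for i, block in enumerate(blocks)}
next_block = {blocks[i]: blocks[i + 1] for i in range(len(blocks) - 1)}
block_to_handles = {}


def launch(block):
    print(f"launch broadcast({names[block]})")
    block_to_handles[block] = [FakeHandle(names[block])]


def pre_forward_hook(module, inputs):
    # the first block has nothing prefetched yet, so launch its own broadcasts
    if module is blocks[0]:
        launch(module)
    # prefetch the next block's broadcasts while this block computes
    nxt = next_block.get(module)
    if nxt is not None:
        launch(nxt)
    # only wait on the handles of the block that is about to run
    for handle in block_to_handles[module]:
        handle.wait()


for block in blocks:
    block.register_forward_pre_hook(pre_forward_hook)

x = torch.randn(2, 4)
for block in blocks:
    x = block(x)
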
@@ -5,8 +5,6 @@ import time

 import torch

-from internlm.core.context import global_context as gpc
-

 class _Timer:
     """Timer."""

@@ -25,16 +23,14 @@ class _Timer:
             megatron_timer.reset()

         assert not self.started_, "timer has already been started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True

     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False

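The timer change drops the overlap_sync_param guard so that start() and stop() always synchronize the cached CUDA stream before reading the wall clock; without that synchronization, time.time() is taken while kernels are still in flight and the measured interval only covers the kernel launch. A minimal demonstration of the difference (assumes a CUDA device is available; the matrix size is arbitrary):

# Why a stream synchronize is needed before wall-clock timing of GPU work.
import time

import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")

    # warm-up so later timings exclude one-time CUDA/cuBLAS initialization
    (x @ x).sum().item()

    # Without a synchronize, only the asynchronous kernel launch is measured.
    t0 = time.time()
    y = x @ x
    launch_only = time.time() - t0

    # Synchronizing the current stream before and after measures the actual compute.
    torch.cuda.current_stream().synchronize()
    t0 = time.time()
    y = x @ x
    torch.cuda.current_stream().synchronize()
    synchronized = time.time() - t0

    print(f"launch-only: {launch_only * 1e3:.3f} ms, synchronized: {synchronized * 1e3:.3f} ms")
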