mirror of https://github.com/InternLM/InternLM

fix

parent f68f34234d
commit d9c9f7c9ee
@@ -100,8 +100,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
 
-        self._comm_bcast_stream = torch.cuda.Stream()
-
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,

@@ -859,16 +857,21 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
-                handle = dist.broadcast(
-                    fp16_param,
-                    src=g_rank,
-                    group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
-                    async_op=True,
-                )
 
                 if self._overlap_sync_param:
+                    handle = dict()
+                    handle["tensor"] = fp16_param
+                    handle["src"] = g_rank
+                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
+                    handle["async_op"] = True
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
+                    handle = dist.broadcast(
+                        fp16_param,
+                        src=g_rank,
+                        group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        async_op=True,
+                    )
                    handles.append(handle)
 
         for handle in handles:

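With overlap_sync_param enabled, the hunk above no longer issues dist.broadcast eagerly: the call's arguments are packed into a plain dict (keyed to match dist.broadcast's parameter names) and handed to the ParamBcastSyncHandler, which replays them later. Below is a minimal, self-contained sketch of that defer-and-replay pattern, not the optimizer code itself, assuming two processes launched with torchrun and the gloo backend.

# sketch: record dist.broadcast kwargs now, launch the collective later
import torch
import torch.distributed as dist


def main():
    # assumption: started as `torchrun --nproc_per_node=2 defer_broadcast.py`
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    params = [torch.full((4,), float(rank)) for _ in range(3)]

    # Record broadcast metadata instead of launching right away. The key
    # "tensor" matches dist.broadcast's first parameter name, so the dict can
    # later be unpacked with ** (which is what dist.broadcast(**handle_meta)
    # in the handler relies on).
    pending = []
    for p in params:
        pending.append({"tensor": p, "src": 0, "group": dist.group.WORLD, "async_op": True})

    # Later (in InternLM: from the ParamBcastSyncHandler's forward pre-hook)
    # the stored metadata becomes real async handles, which are then waited on.
    handles = [dist.broadcast(**meta) for meta in pending]
    for handle in handles:
        handle.wait()

    assert all(torch.equal(p, torch.zeros(4)) for p in params)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
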
@@ -803,6 +803,8 @@ class ParamBcastSyncHandler:
         self._param_to_rank = dict()  # <key: param> <value: rank)>
         self._block_to_rank = dict()  # <key: nn.Module> <value: rank)>
         self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles))>
+        self._block_next_block = dict()  # <key: nn.Module> <value: nn.Module>
+        self._block_to_handles = dict()  # <key: nn.Module> <value: list(bcast handles)>
 
         zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
         total_param_num = sum(p.numel() for p in model.parameters())

@@ -824,10 +826,18 @@ class ParamBcastSyncHandler:
                 for _, block in enumerate(children):
                     # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                     self._block_to_param[block] = list(block.parameters())
+                    key_list = list(self._block_to_param.keys())
+                    if len(key_list) > 1:
+                        up_layer = key_list[-2]
+                        self._block_next_block[up_layer] = key_list[-1]
             else:
                 # record the block that a parameter belongs to
                 # self._block_to_param[name] = list(children.parameters())
                 self._block_to_param[children] = list(children.parameters())
+                key_list = list(self._block_to_param.keys())
+                if len(key_list) > 1:
+                    up_layer = key_list[-2]
+                    self._block_next_block[up_layer] = key_list[-1]
 
         alloc_num = 0
         rank_to_go = 0

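Together with the two dicts added in the previous hunk, this builds a simple successor chain: blocks enter _block_to_param in model order, and each newly appended block becomes the _block_next_block successor of the block registered just before it. A toy illustration of that chaining logic (stand-alone, with made-up layers rather than the real handler):

# toy version of the _block_next_block chaining added in this hunk
import torch.nn as nn

block_to_param = {}    # <key: nn.Module> <value: list of parameters>
block_next_block = {}  # <key: nn.Module> <value: nn.Module registered right after it>

blocks = [nn.Linear(4, 4), nn.Linear(4, 4), nn.LayerNorm(4)]
for block in blocks:
    block_to_param[block] = list(block.parameters())
    key_list = list(block_to_param.keys())
    if len(key_list) > 1:
        # link the previously registered block to the one just added
        block_next_block[key_list[-2]] = key_list[-1]

# the chain follows registration (i.e. forward) order
assert block_next_block[blocks[0]] is blocks[1]
assert block_next_block[blocks[1]] is blocks[2]
assert blocks[2] not in block_next_block  # the last block has no successor
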
@@ -857,16 +867,35 @@ class ParamBcastSyncHandler:
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         self._register_sync_parameters_hook()
 
+    def _launch_handle(self, layer):
+        handle_metas = []
+        for rank in self._block_to_rank[layer]:
+            handle_metas.extend(self._bcast_handles[rank])
+            # need to clear _bcast_handles since they would be processed later
+            self._bcast_handles[rank] = []
+        # wait all required broadcast handles to be completed
+        handles = []
+        for handle_meta in handle_metas:
+            handle = dist.broadcast(**handle_meta)
+            handles.append(handle)
+        self._block_to_handles[layer] = handles
+
     def _register_sync_parameters_hook(self) -> None:
         def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
-            bcast_handles = []
-            # gather all required broadcast hanles into a list
-            for rank in self._block_to_rank[model]:
-                bcast_handles.extend(self._bcast_handles[rank])
-                # need to clear _bcast_handles since they would be processed later
-                self._bcast_handles[rank] = []
-            # wait all required broadcast handles to be completed
-            for handle in bcast_handles:
+            current_layer = model
+            next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
+
+            # if this is the first layer
+            # launch broadcast for current layer
+            if current_layer == list(self._block_to_param.keys())[0]:
+                self._launch_handle(current_layer)
+
+            # if this is not the last layer
+            # launch broadcast for next layer
+            if next_layer:
+                self._launch_handle(next_layer)
+
+            for handle in self._block_to_handles[current_layer]:
                 handle.wait()
 
         # register_forward_pre_hook for transformer/embeding/norm/xxx block

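The new _launch_handle plus the rewritten _pre_forward_hook turn the parameter broadcast into a one-block-ahead prefetch: the first block launches its own broadcasts, every block launches the broadcasts of its successor, and each block then waits only on its own handles. The sketch below mimics that control flow with stub handles instead of real collectives so the launch/wait ordering is easy to see; names like PrefetchSketch and FakeHandle are illustrative, not from the repo.

# stand-alone sketch of the prefetch-next-block flow in the hunk above
import torch
import torch.nn as nn


class FakeHandle:
    """Stands in for the async work handle returned by dist.broadcast(async_op=True)."""

    def __init__(self, tag):
        self.tag = tag

    def wait(self):
        print(f"wait   block {self.tag}")


class PrefetchSketch:
    def __init__(self, blocks):
        self.blocks = list(blocks)
        self.index = {b: i for i, b in enumerate(self.blocks)}
        # link each block to the one that runs right after it (cf. _block_next_block)
        self.next_block = dict(zip(self.blocks, self.blocks[1:]))
        self.block_to_handles = {}
        for b in self.blocks:
            b.register_forward_pre_hook(self._pre_forward_hook)

    def _launch(self, block):
        # the real handler issues async dist.broadcast calls here
        print(f"launch block {self.index[block]}")
        self.block_to_handles[block] = [FakeHandle(self.index[block])]

    def _pre_forward_hook(self, module, inputs):
        if module is self.blocks[0]:
            self._launch(module)              # first block: nothing prefetched yet
        nxt = self.next_block.get(module)
        if nxt is not None:
            self._launch(nxt)                 # prefetch the next block's broadcast
        for handle in self.block_to_handles[module]:
            handle.wait()                     # block only on what this layer needs


blocks = [nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)]
model = nn.Sequential(*blocks)
PrefetchSketch(blocks)
model(torch.randn(2, 4))
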
@@ -5,8 +5,6 @@ import time
 
 import torch
 
-from internlm.core.context import global_context as gpc
-
 
 class _Timer:
     """Timer."""

@@ -25,16 +23,14 @@ class _Timer:
         megatron_timer.reset()
 
         assert not self.started_, "timer has already been started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True
 
     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False
 
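The timer hunks drop the gpc import and the overlap_sync_param check, so start() and stop() now always synchronize the CUDA stream before reading the wall clock. A small stand-alone sketch of that synchronize-before-timestamp idea (StreamTimer is a hypothetical name, not the repo's _Timer; it degrades to a no-op sync on CPU-only machines):

# sketch: synchronize the CUDA stream before timestamping, so async kernels
# that are still queued do not make time.time() under-report
import time

import torch


class StreamTimer:
    def __init__(self):
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = 0.0
        self.stream = torch.cuda.current_stream() if torch.cuda.is_available() else None

    def _sync(self):
        if self.stream is not None:
            self.stream.synchronize()  # wait for queued GPU work before reading the clock

    def start(self):
        assert not self.started_, "timer has already been started"
        self._sync()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, "timer is not started"
        self._sync()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False


timer = StreamTimer()
timer.start()
x = torch.randn(256, 256)
y = x @ x  # on GPU this matmul would run asynchronously
timer.stop()
print(f"elapsed: {timer.elapsed_:.6f}s")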