From 32664328e7c71a9f9ce5483449dd1e65d5a96038 Mon Sep 17 00:00:00 2001
From: Sun Peng
Date: Wed, 23 Aug 2023 16:59:59 +0800
Subject: [PATCH] Feat/overlap_bcast_forward (#218)

* feat/support bcast forward overlap

* feat/optimize the bcast call

* feat/optimize the bcast call

* feat/optimize the bcast call

* fix lint

* fix lint

* fix lint

* fix lint

* add torch.cuda.synchronize in save_checkpoint

---------

Co-authored-by: sunpeng
---
 .../solver/optimizer/hybrid_zero_optim.py |  93 ++++++++-----
 internlm/solver/optimizer/utils.py        | 130 ++++++++++++++++--
 internlm/utils/megatron_timers.py         |   5 +-
 internlm/utils/model_checkpoint.py        |   1 +
 train.py                                  |  15 +-
 5 files changed, 193 insertions(+), 51 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 55fad5f..298f90a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -3,6 +3,7 @@
 
 import math
 from functools import partial
+from itertools import product
 
 import torch
 import torch.distributed as dist
@@ -19,6 +20,7 @@ from internlm.solver.optimizer.store import (
 )
 from internlm.solver.optimizer.utils import (
     DynamicGradScaler,
+    ParamBcastSyncHandler,
     flatten,
     get_grad_accumulate_object,
     has_inf_or_nan,
@@ -87,9 +89,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         self,
         optimizer: Optimizer,
         cpu_offload=False,
-        overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        param_bcast_sync_handler: ParamBcastSyncHandler = None,
     ):
         # DynamicGradScaler related args
         if gpc.config.model.dtype is torch.float32:
@@ -158,7 +160,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             + f"zo-{self._zero_local_rank}.pt"
         )
         self.params_per_rank_id_dict = []
-        self.overlap_broadcast = overlap_broadcast
+        self._param_bcast_sync_handler = param_bcast_sync_handler
+        if self._overlap_communication:
+            assert self._param_bcast_sync_handler is not None
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -230,6 +234,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # communication-computation overlapping
         if self._overlap_communication:
             self._comm_stream = torch.cuda.Stream()
+        else:
+            self._comm_stream = torch.cuda.current_stream()
 
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
@@ -267,8 +273,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_id = str(i)
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
-
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
+            if self._overlap_communication:
+                rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
+            else:
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
             params_per_rank[rank_to_go].append(param)
             self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
             numel_per_rank[rank_to_go] += param.numel()
@@ -299,7 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                     self._grad_store.add_accumulate_grad_object(accum_grad_obj)
 
                     reduction_func = partial(
-                        self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
+                        self._store_and_try_reduce_grads_by_bucket,
+                        param=param,
+                        reduce_rank=reduce_rank,
                     )
 
                     # define hook
@@ -385,16 +395,16 @@ class HybridZeroOptimizer(BaseOptimizer):
 
     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
         if self._overlap_communication:
-            stream = self._comm_stream
-            stream.synchronize()
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
-        else:
-            stream = torch.cuda.current_stream()
 
-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(self._comm_stream):
             flat = bucket.flatten()
             reduced_flat = reduce_tensor(
-                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+                tensor=flat,
+                dtype=self.dtype,
+                dst_rank=reduce_rank,
+                parallel_mode=ParallelMode.DATA,
             )
 
             # update the reduced tensor
@@ -532,7 +542,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             total_norms.append(
                 self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    group_id=group_id,
+                    last_bucket=True,
+                    last_stage=True,
+                    previous_norm=groups_norms[group_id],
                 )
             )
 
@@ -562,7 +575,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
-                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
+                send_alert_message(
+                    address=gpc.config.alert_address,
+                    message="Overflow occurs, please check it.",
+                )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
@@ -625,35 +641,40 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
 
-        # TODO: support broadcast overlap
-        self.broadcast_params(overlap=False)
+        with torch.cuda.stream(self._comm_stream):
+            self.broadcast_params()
 
         timer("step").stop()
+
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         return True, [global_norm / loss_scale for global_norm in global_norm_groups]
 
-    def broadcast_params(self, overlap=False):
+    def broadcast_params(self):
         handles = []
 
-        for group_id in range(self.num_param_groups):
-            for rank in range(self._zero_world_size):
-                # The following operations are performed only on the rank to which parameters are assigned.
-                if rank not in self.param_group_no_params_ranks[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
-                    g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
-                    handle = dist.broadcast(
-                        fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
-                    )
-                    handles.append(handle)
+        for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
+            # The following operations are performed only on the rank to which parameters are assigned.
+            if rank in self.param_group_no_params_ranks[group_id]:
+                continue
+            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
+            # assert grank == rank, f"{grank} == {rank}"
+            g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
+            handle = dist.broadcast(
+                fp16_param,
+                src=g_rank,
+                group=gpc.get_group(ParallelMode.ZERO1),
+                async_op=True,
+            )
 
-        if not overlap:
-            for handle in handles:
-                handle.wait()
-        else:
-            return handles
+            if self._overlap_communication:
+                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+            else:
+                handles.append(handle)
+
+        for handle in handles:
+            handle.wait()
 
     ##################
     # FP16 Utilities #
     ##################
@@ -671,7 +692,11 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if avg_grad is not None and has_inf_or_nan(avg_grad):
                     self._found_overflow.fill_(1.0)
                     break
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
+        dist.all_reduce(
+            self._found_overflow,
+            op=dist.ReduceOp.MAX,
+            group=gpc.get_group(ParallelMode.GLOBAL),
+        )
 
         return self._found_overflow.item() > 0
 
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 5a752ef..38e4560 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -3,15 +3,18 @@
 
 import math
 from abc import ABC, abstractmethod
-from typing import Dict, Optional
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
-from torch import Tensor
+from torch import Tensor, nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_model_parallel_parameter
@@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):
 
 
 def split_half_float_double(tensor_list):
-    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
-    buckets = []
-    for _, dtype in enumerate(dtypes):
-        bucket = [t for t in tensor_list if t.type() == dtype]
-        if bucket:
-            buckets.append(bucket)
+    dtype_buckets = {
+        "torch.cuda.HalfTensor": [],
+        "torch.cuda.FloatTensor": [],
+        "torch.cuda.DoubleTensor": [],
+        "torch.cuda.BFloat16Tensor": [],
+    }
+
+    for t in tensor_list:
+        dtype = t.type()
+        if dtype in dtype_buckets:
+            dtype_buckets[dtype].append(t)
+
+    buckets = [bucket for bucket in dtype_buckets.values() if bucket]
     return buckets
 
 
@@ -184,7 +194,10 @@ def calc_l2_norm(grads):
     if APEX_AVAILABLE:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
         norm, _ = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
+            amp_C.multi_tensor_l2norm,
+            dummy_overflow_buf,
+            [grads],
+            False,  # no per-parameter norm
         )
     else:
         norm, _ = multi_tensor_l2norm_torch(grads, False)
@@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
 
         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(
+                total_norm_cuda,
+                op=dist.ReduceOp.MAX,
+                group=gpc.get_group(ParallelMode.MODEL),
+            )
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []
@@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
 
     # Sum across all model-parallel GPUs.
     if gpc.is_initialized(ParallelMode.MODEL):
-        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
+        dist.all_reduce(
+            total_norm,
+            op=dist.ReduceOp.SUM,
+            group=gpc.get_group(ParallelMode.MODEL),
+        )
 
     # This is because we use zero1, so we need to use this reduction.
     # TODO: Check zero group to be a subset of dp group.
@@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
         self._scale = self._scale.fill_(state_dict["_scale"])
         self._growth_step = state_dict["_growth_step"]
         self._hysteresis_step = state_dict["_hysteresis_step"]
+
+
+class ParamBcastSyncHandler:
+    """
+    Model Partition Handler for overlap broadcast with forward
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
+        self._block_to_param = OrderedDict()  # block -> the parameters it contains
+        self._param_to_rank = dict()  # param -> the ZERO1 rank it is assigned to
+        self._block_to_rank = dict()  # block -> the ZERO1 rank(s) its params are assigned to
+        self._bcast_handles = dict()  # rank -> list of pending broadcast handles
+
+        zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
+        total_param_num = sum(p.numel() for p in model.parameters())
+        avg_param_num = total_param_num * 1.0 // zero1_size
+
+        # we just want to share the same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        # record the parameters to transformer/embedding/head/norm block
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(children, nn.ModuleList):
+                    # record the block that a parameter belongs to
+                    for _, block in enumerate(children):
+                        # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
+                        self._block_to_param[block] = list(block.parameters())
+                else:
+                    # record the block that a parameter belongs to
+                    # self._block_to_param[name] = list(children.parameters())
+                    self._block_to_param[children] = list(children.parameters())
+
+        alloc_num = 0
+        rank_to_go = 0
+
+        # process the parameters in block_to_param sequentially,
+        # allocate each parameter to a local rank of ParallelMode.ZERO1,
+        # NOTE that we do NOT consider the following scenarios:
+        # 1) whether a parameter is trainable;
+        # 2) parameters may be in different optimizer groups
+        for block, params in self._block_to_param.items():
+            # allocate a model block to a local rank of ParallelMode.ZERO1
+            self._block_to_rank[block] = [rank_to_go]
+            for p in params:
+                alloc_num = alloc_num + p.numel()
+                # in this case, allocate the param to next rank if possible
+                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
+                    rank_to_go = rank_to_go + 1
+                    alloc_num = 0
+                    self._block_to_rank[block].append(rank_to_go)
+                # allocate a parameter to a local rank of ParallelMode.ZERO1
+                self._param_to_rank[p] = rank_to_go
+
+        # initialize an empty list for _bcast_handles of each rank
+        for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
+            self._bcast_handles[rank] = []
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        self._register_sync_parameters_hook()
+
+    def _register_sync_parameters_hook(self) -> None:
+        def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
+            bcast_handles = []
+            # gather all required broadcast handles into a list
+            for rank in self._block_to_rank[model]:
+                bcast_handles.extend(self._bcast_handles[rank])
+                # need to clear _bcast_handles since they would be processed later
+                self._bcast_handles[rank] = []
+            # wait for all required broadcast handles to be completed
+            for handle in bcast_handles:
+                handle.wait()
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        for block, _ in self._block_to_rank.items():
+            block.register_forward_pre_hook(partial(_pre_forward_hook))
+
+    def get_rank_by_param(self, param) -> int:
+        return self._param_to_rank[param]
+
+    def add_bcast_handle(self, rank, handle) -> None:
+        self._bcast_handles[rank].append(handle)
diff --git a/internlm/utils/megatron_timers.py b/internlm/utils/megatron_timers.py
index 6c4ed11..e319a80 100644
--- a/internlm/utils/megatron_timers.py
+++ b/internlm/utils/megatron_timers.py
@@ -14,18 +14,19 @@ class _Timer:
         self.elapsed_ = 0.0
         self.started_ = False
         self.start_time = time.time()
+        self.stream = torch.cuda.current_stream()
 
     def start(self):
         """Start the timer."""
         assert not self.started_, "timer has already been started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True
 
     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False
 
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 3dca7c5..08d9db7 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -565,6 +565,7 @@ set load_ckpt_folder or use default value \
 
         start = time.time()
         self.set_save_folder(folder, train_state.step_count)
+        torch.cuda.synchronize()
         torch.distributed.barrier()
         if gpc.is_rank_for_log():
             logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
diff --git a/train.py b/train.py
index 31e8567..aa48208 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import Iterable
+from typing import Iterable, Union
 
 import numpy as np
 import torch
@@ -36,6 +36,7 @@ from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import (
     BatchSkipper,
     DummyProfile,
@@ -291,7 +292,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
     return batch, train_iter
 
 
-def initialize_optimizer(model: nn.Module):
+def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
 
     Args:
         model (torch.nn.Module): Your model instance to be trained or evaluated.
 
     Returns:
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
""" + param_bcast_sync_handler = ParamBcastSyncHandler(model) adam_cfg = gpc.config.adam naive_optimizer = torch.optim.AdamW( params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}], @@ -309,7 +311,10 @@ def initialize_optimizer(model: nn.Module): ) optimizer = HybridZeroOptimizer( - naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer + naive_optimizer, + grad_scal_cfg=gpc.config.grad_scaler, + zero_cfg=gpc.config.hybrid_zero_optimizer, + param_bcast_sync_handler=param_bcast_sync_handler, ) beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) @@ -599,6 +604,7 @@ def main(args): # do forward and backward timer("fwd-bwd").start() + _, _, loss = trainer.execute_schedule( batch, forward_only=False, return_loss=True, return_output_label=False ) @@ -661,7 +667,8 @@ def main(args): if memory_profiler is not None: memory_profiler.step() - prof.step() + if batch_count % 2 == 0: + prof.step() ckpt_manager.wait_async_upload_finish()