diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 0058e04..09af7f4 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -152,19 +152,19 @@ zero1 parallel (dict): 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. - 2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp', - the sequence_parallel should be True. + 2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'], + defaults to 'none', which means sequence parallel will be disabled. + 3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the 'intern' sp mode, + defaults to False. pipeline parallel (dict): 1. size: int, the size of pipeline parallel. 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, defaults to False. -sequence parallel (bool): enable/disable sequence parallel, defaults to False. """ parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=8, sp="intern", intern_overlap=True), + tensor=dict(size=8, sp="megatron", intern_overlap=True), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=True, ) cudnn_deterministic = False diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index f0caf05..56661d8 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler): _output, _loss, _moe_loss = self._train_one_batch( _data, _label, engine, forward_only, return_loss, self._grad_accum_size ) - engine.optimizer.reset_reduce_bucket() if return_loss: loss += _loss diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 80611fe..0e74f76 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -306,15 +306,20 @@ def args_sanity_check(): ), "sequence parallel does not support use_flash_attn=False" if isinstance(gpc.config.parallel["tensor"], int): - gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp") - - if gpc.config.parallel["tensor"].get("mode", None) is None: - gpc.config.parallel["tensor"]["mode"] = "origin_tp" - - if gpc.config.parallel["tensor"].get("mode", None) == "fstp": - assert ( - gpc.config.parallel.sequence_parallel is True - ), "when the tp_mode is fstp, the sequence_parallel should be True."
+ gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False) + if gpc.config.parallel["tensor"].get("sp", None) is None: + gpc.config.parallel["tensor"]["sp"] = "none" + if gpc.config.parallel["tensor"].get("intern_overlap", None) is None: + gpc.config.parallel["tensor"]["intern_overlap"] = False + assert gpc.config.parallel["tensor"].get("sp", None) in [ + "none", + "megatron", + "flash-attn", + "intern", + ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported" + # adapt to old version's sequence parallel config + if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]: + gpc.config.parallel.sequence_parallel = True # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 8f57a02..2bbb941 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -451,49 +451,33 @@ class CoarseGrainedFSTPAllGatherSyncHandler: self.index_to_fsdp_modules[idx].append(child) self.module_name_index[child] = index index = index + 1 - + _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") if child.bias is not None: setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - # _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - # setattr(child.weight, "_fstp_all_reduce_str", f"{_full_name}.weight") - # if child.bias is not None: - # setattr(child.bias, "_fstp_all_reduce_str", f"{_full_name}.bias") else: continue elif isinstance(children, ScaleColumnParallelLinear): self.head.append(children) elif isinstance(children, Embedding1D): self.embedding.append(children) - - def get_zero_by_shape(self, size:tuple, dtype, device) -> torch.Tensor: - if size not in self.zero_const_pool: + + def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: + if size not in self.zero_const_pool: self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() - + return self.zero_const_pool[size] - - def _all_gather_block_weight(self, block_index: int): - #block = self.index_to_block[block_index] - fsdp_modules = self.index_to_fsdp_modules[block_index] - # self.block_handles[block] = [] - for module in fsdp_modules: - total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - self.FSTP_global_weights[module] = total_weight - self.FSTP_global_handle[module] = weight_handle - # self.block_handles[block].append(weight_handle) - def _all_gather_block_weight_memory_pool(self, block_index: int): fsdp_modules = self.index_to_fsdp_modules[block_index] - # self.block_handles[block] = [] for module in fsdp_modules: module_index = self.module_name_index[module] name = self.module_name[module_index] - weight_handle = all_gather_raw_memory_pool(module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name) - # self.FSTP_global_weights[module] = total_weight + weight_handle = all_gather_raw_memory_pool( + module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name + ) self.FSTP_global_handle[module] = weight_handle - # self.block_handles[block].append(weight_handle) def _register_sync_parameters_hook(self) -> None: """ @@ -510,41 +494,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler: block_index = 
self.module_to_index[module] # start the all-gather for next block if block_index + 1 < gpc.config.NUM_LAYER: - # self._all_gather_block_weight(block_index + 1) self._all_gather_block_weight_memory_pool(block_index + 1) - def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): - block_index = self.block_to_index[block] - if block_index == 0: - # all gather weight for block 0 - fsdp_modules = self.index_to_fsdp_modules[block_index] - for module in fsdp_modules: - total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handle.wait() - self.FSTP_global_weights[module] = total_weight - else: - # wait handle for current block - handles = self.block_handles[block] - for handle in handles: - handle.wait() - - def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): - # self._all_gather_block_weight(0) + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): self._all_gather_block_weight_memory_pool(0) - - def _post_forward_hook_for_block(block: nn.Module, input, output): - block_index = self.block_to_index[block] - fsdp_modules = self.index_to_fsdp_modules[block_index] - if block in self.block_handles: - del self.block_handles[block] - for module in fsdp_modules: - del self.FSTP_global_weights[module] - - def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,): - block_index = self.module_to_index[module] - handler = self.FSTP_global_handle[module] - handler.wait() + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): + handle = self.FSTP_global_handle[module] + handle.wait() def _post_forward_hook_for_module(module: nn.Module, input, output): if module in self.FSTP_global_weights: @@ -552,67 +509,44 @@ class CoarseGrainedFSTPAllGatherSyncHandler: if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] - def _pre_backward_hook_for_block(block: nn.Module, grad_output): - # import pdb; pdb.set_trace() - block_index = self.block_to_index[block] - # if block_index == gpc.config.NUM_LAYER - 1: - # # all gather weight for the last block - # fsdp_modules = self.index_to_fsdp_modules[block_index] - # for module in fsdp_modules: - # total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - # weight_handle.wait() - # self.FSTP_global_weights[module] = total_weight - # else: - # # wait handle for current block - # handles = self.block_handles[block] - # for handle in handles: - # handle.wait() - # if block_index == gpc.config.NUM_LAYER - 1: - # self._all_gather_block_weight(block_index) - # start the all-gather for next block - if block_index - 1 >= 0: - self._all_gather_block_weight(block_index - 1) - def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): first_module = self.block_module[gpc.config.NUM_LAYER - 1][4] total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True) self.FSTP_global_handle[first_module] = weight_handler self.FSTP_global_weights[first_module] = total_weight - def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): - block_index = self.block_to_index[block] - fsdp_modules = self.index_to_fsdp_modules[block_index] - if block in self.block_handles: - del self.block_handles[block] - for module in fsdp_modules: - del self.FSTP_global_weights[module] - def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output): block_index = self.module_to_index[module] name_index = self.module_name_index[module] - + if 
name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) weight_handler = self.FSTP_global_handle[module] weight_handler.wait() - # self.FSTP_global_weights[module] = total_weight # start the all-gather for next module next_module = self.block_module[block_index][name_index - 1] next_name = self.module_name[name_index - 1] weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=next_name + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=next_name, ) self.FSTP_global_handle[next_module] = weights_handler elif name_index == 0: handler = self.FSTP_global_handle[module] handler.wait() - + if block_index - 1 >= 0: next_module = self.block_module[block_index - 1][4] name = self.module_name[4] weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index - 1, module_name=name, + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index - 1, + module_name=name, ) self.FSTP_global_handle[next_module] = weights_handler else: @@ -625,76 +559,24 @@ class CoarseGrainedFSTPAllGatherSyncHandler: next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name ) self.FSTP_global_handle[next_module] = weights_handler - # if module in self.FSTP_global_handle: - # handler = self.FSTP_global_handle[module] - # handler.wait() - - def _pre_backward_hook_for_module(module: nn.Module, grad_output): - block_index = self.module_to_index[module] - name_index = self.module_name_index[module] - - if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler = self.FSTP_global_handle[module] - weight_handler.wait() - # self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - elif name_index == 0: - handler = self.FSTP_global_handle[module] - handler.wait() - - if block_index - 1 >= 0: - next_module = self.block_module[block_index - 1][4] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - else: - handler = self.FSTP_global_handle[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - # if module in self.FSTP_global_handle: - # handler = self.FSTP_global_handle[module] - # handler.wait() def _post_backward_hook_for_module(module, grad_input, grad_output): if module in self.FSTP_global_weights: del self.FSTP_global_weights[module] if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] - + for embedding in self.embedding: - embedding.register_forward_hook(_pre_forward_hook_for_embedding) - + embedding.register_forward_hook(_post_forward_hook_for_embedding) + 
for head in self.head: head.register_full_backward_hook(_post_backward_hook_for_head) - # for block in self.FSTP_blocks: - # block.register_forward_pre_hook(_pre_forward_hook_for_block) - # block.register_forward_hook(_post_forward_hook_for_block) - # block.register_full_backward_pre_hook(_pre_backward_hook_for_block) - # block.register_full_backward_hook(_post_backward_hook_for_block) - for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - - # for wqkv in self.FSTP_wqkvs: - # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv) for module in self.FSTP_modules: module.register_forward_pre_hook(_pre_forward_hook_for_module) module.register_forward_hook(_post_forward_hook_for_module) - # module.register_full_backward_pre_hook(_pre_backward_hook_for_module) module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool) module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 99d540f..3ed78d7 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -396,7 +396,7 @@ class PackedFlashInternLm1D(nn.Module): assert len(indexes) == 1 # The indexes are used to indicate the actual position IDs of each token in the packed input. indexes = indexes[0] - # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. + # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension. if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern": indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 6757906..b1894e9 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,20 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Optional, Union +from typing import Optional import fused_dense_lib as fused_dense_cuda import torch import torch.nn.functional as F -from flash_attn.utils.distributed import all_reduce_raw #, reduce_scatter_raw +from flash_attn.utils.distributed import all_reduce_raw from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.utils.logger import get_logger from internlm.utils.common import get_current_device +from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -125,9 +125,20 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = ) return output, handle -def all_gather_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0, block_index: int = None, module_name: str = None): + +def all_gather_raw_memory_pool( + input_: Tensor, + process_group: ProcessGroup, + async_op: bool = False, + gather_dim: int = 0, + block_index: int = None, + module_name: str = None, +): handle = torch.distributed.all_gather_into_tensor( - gpc.config.block_memory[block_index % 2][module_name], input_.contiguous(), group=process_group, async_op=async_op + gpc.config.block_memory[block_index % 2][module_name], + input_.contiguous(), + group=process_group, + async_op=async_op, ) return handle @@ -142,23 +153,25 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): def reduce_scatter_raw(input_: 
Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 - output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], - dtype=input_.dtype, device=input_.device).contiguous() - handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), - group=process_group, - async_op=async_op) + output = torch.empty( + input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device + ).contiguous() + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) return output, handle + def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) index = check_reduce_scatter_memory_pool(size) - output = gpc.config.reduce_scatter_memory[size]['data'][index] + output = gpc.config.reduce_scatter_memory[size]["data"][index] setattr(output, "index", index) - handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), - group=process_group, - async_op=async_op) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) return output, handle @@ -444,7 +457,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None): + def forward( + ctx, + x, + weight, + bias, + return_residual=False, + process_group=None, + module=None, + overlap_handler=None, + block_index=None, + module_name=None, + ): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group @@ -506,7 +530,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): overlap_handler = ctx.overlap_handler block_index = ctx.block_index module_name = ctx.module_name - + if ctx.compute_weight_gradient: x, weight, bias = ctx.saved_tensors total_x = x @@ -540,7 +564,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function): overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) if grad_bias is not None: - grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True) + grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool( + grad_bias, process_group, async_op=True + ) assert hasattr(bias, "_fstp_reduce_scatter_str") overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) @@ -619,7 +645,9 @@ def fstp_fused_dense_func( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler, block_index, module_name) + 
return FSTPFusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, module, handler, block_index, module_name + ) else: assert process_group is None out = F.linear(x, weight, bias) @@ -666,36 +694,37 @@ def Silu(w1_o, w2_o): Silu = torch.jit.script(Silu) + def check_reduce_scatter_memory_pool(key): - return_idx = 0 - + # if key not in dict if key not in gpc.config.reduce_scatter_memory: - gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []} - + gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []} + # if the data is empty - if len(gpc.config.reduce_scatter_memory[key]['data']) == 0: - gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key, - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous()) - gpc.config.reduce_scatter_memory[key]['used'].append(True) + if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0: + gpc.config.reduce_scatter_memory[key]["data"].append( + torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() + ) + gpc.config.reduce_scatter_memory[key]["used"].append(True) return_idx = 0 return return_idx - else: # if not empty - for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']): - if used == False: - gpc.config.reduce_scatter_memory[key]['used'][index] = True + else: # if not empty + for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]): + if used is False: + gpc.config.reduce_scatter_memory[key]["used"][index] = True return_idx = index return return_idx # if the memory pool is all used - length = len(gpc.config.reduce_scatter_memory[key]['data']) - gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key, - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous()) - gpc.config.reduce_scatter_memory[key]['used'].append(True) + length = len(gpc.config.reduce_scatter_memory[key]["data"]) + gpc.config.reduce_scatter_memory[key]["data"].append( + torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() + ) + gpc.config.reduce_scatter_memory[key]["used"].append(True) return_idx = length return return_idx + def release_reduce_scatter_memory_pool(size, index): - gpc.config.reduce_scatter_memory[size]['used'][index] = False \ No newline at end of file + gpc.config.reduce_scatter_memory[size]["used"][index] = False diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 4de5c7c..0f536ec 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -3,6 +3,7 @@ import math from functools import partial +from typing import List, Optional import torch import torch.distributed as dist @@ -10,7 +11,7 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool +from internlm.model.utils import release_reduce_scatter_memory_pool from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ -65,7 +66,8 @@ class HybridZeroOptimizer(BaseOptimizer): hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True: + self._fstp_handler = None 
+ if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: self._fstp_handler = gpc.config.fstp_handler # Zero related args @@ -85,8 +87,8 @@ class HybridZeroOptimizer(BaseOptimizer): # it will not manage the tensors used by mixed precision training self._param_store = ParameterStore(ParallelMode.ZERO1) self._grad_store = GradientStore(ParallelMode.DATA) - self._bucket_store = [] - self._bucket_store_2 = [] + self._bucket_store: List[BucketStore] = [] + self._accum_grad_buckets: List[BucketStore] = [] self._bucket_in_progress = [] # fp16 and fp32 params for mixed precision training @@ -155,7 +157,7 @@ class HybridZeroOptimizer(BaseOptimizer): # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name self._broadcast_parallel_mode.append(zero_mode) self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"])) - self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"])) + self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"])) # assign parameters to ranks the params in the list are sorted params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group) @@ -301,9 +303,9 @@ class HybridZeroOptimizer(BaseOptimizer): param=param, reduce_rank=reduce_rank, ) - + reduce_scatter_checker = partial( - self._wait_reduce_scatter_and_accumulate_grad, + self._wait_reduce_scatter_and_accumulate_grads, param=param, reduce_rank=reduce_rank, ) @@ -312,7 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer): # NOT IMPORTANT BUT GOOD TO KNOW: # args here is not grad, but allow_unreacable and accumulate_grad def reduce_grad_hook(*args): # pylint: disable=W0613 - if gpc.config.fstp_handler is not None: + if self._fstp_handler is not None: reduce_scatter_checker() if self.skip_grad_reduce is False: @@ -336,56 +338,36 @@ class HybridZeroOptimizer(BaseOptimizer): group_id = getattr(param, "group_id") return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id]) - def reset_reduce_bucket(self) -> None: - for bucket in self._bucket_store_2: - for rank, params in bucket._params.items(): - for _param in params: - if not hasattr(_param, "_fstp_reduce_scatter_str"): - continue + def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None: + for _param in bucket.get_param(reduce_rank): + if not hasattr(_param, "_fstp_reduce_scatter_str"): + continue - key = getattr(_param, "_fstp_reduce_scatter_str") - comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key] - comm_handle.wait() - _param.grad.add_(_grad) - # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) - del self._fstp_handler.reduce_scatter_handlers[key] - self._fstp_handler.reduce_scatter_handlers[key] = None - assert key in self._fstp_handler.reduce_scatter_handlers + # wait for the reduce-scatter result and accumulate the gradient. + _key = getattr(_param, "_fstp_reduce_scatter_str") + _comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key] + _comm_handle.wait() + _param.grad.add_(_grad) + # release the reduce-scatter buffer back to the memory pool.
+ release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) + self._fstp_handler.reduce_scatter_handlers[_key] = None - bucket.reset_by_rank(rank) - - def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): + bucket.reset_by_rank(reduce_rank) + + def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None): param_size = param.numel() + group_id = getattr(param, "group_id") + current_bucket = self._accum_grad_buckets[group_id] + # check if the bucket is full # if full, will reduce the grads already in the bucket # after reduction, the bucket will be empty - group_id = getattr(param, "group_id") - current_bucket = self._bucket_store_2[group_id] + if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size: + self._accum_grads_store_in_bucket(current_bucket, reduce_rank) - if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024: - # wait reduce scatter communication - params = current_bucket.get_param(reduce_rank) - for _param in params: - if not hasattr(_param, "_fstp_reduce_scatter_str"): - continue - - key = getattr(_param, "_fstp_reduce_scatter_str") - comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key] - comm_handle.wait() - _param.grad.add_(_grad) - # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) - del self._fstp_handler.reduce_scatter_handlers[key] - self._fstp_handler.reduce_scatter_handlers[key] = None - assert key in self._fstp_handler.reduce_scatter_handlers - - current_bucket.reset_by_rank(reduce_rank) - + # otherwise, add the parameter into the bucket. current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) current_bucket.add_param(param, reduce_rank) @@ -612,6 +594,10 @@ class HybridZeroOptimizer(BaseOptimizer): for group_id in range(self.num_param_groups): self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True) + # we need to accumulate the gradients left in the accumulated gradient buckets + for group_id in range(self.num_param_groups): + self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None) + # compute norm for gradients in the before bucket groups_norms = [] for group_id in range(self.num_param_groups): @@ -773,7 +759,7 @@ class HybridZeroOptimizer(BaseOptimizer): torch.cuda.synchronize() with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() - + timer("step").stop() # update gradients may not be needed here, because the sync_params function is used in initialization, diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 228045e..f486cce 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -45,7 +45,7 @@ class BucketStore(BaseStore): def num_elements_in_bucket(self, reduce_rank: int = None): return self._num_elements_in_bucket[reduce_rank] - + def num_params_in_bucket(self, reduce_rank: int = None): return len(self._params[reduce_rank]) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 20592c2..53996b3 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -107,48 +107,60 @@ def initialize_model(): # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) - + gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["sp"] == "intern" and
gpc.config.parallel["tensor"]["intern_overlap"] == True: + if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() gpc.config.fstp_handler = handler - + # allocate memory pool - block_memory = {} # containing two groups of block weight + block_memory = {} # containing two groups of block weight hidden_size = gpc.config.HIDDEN_SIZE mlp_ratio = gpc.config.MLP_RATIO mlp_hidden_size = int(hidden_size * mlp_ratio) mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) world_size = gpc.get_world_size(ParallelMode.TENSOR) - size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)] - module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3'] + size_key = [ + (3 * hidden_size // world_size, hidden_size), + (mlp_hidden_size // world_size, hidden_size), + (hidden_size // world_size, mlp_hidden_size), + (hidden_size // world_size, hidden_size), + ] + module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] for i in range(2): weight = {} for name in module_name: - if name == 'Wqkv': - weight[name] = torch.zeros((3 * hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() - elif name == 'out_proj': - weight[name] = torch.zeros((hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() - elif name == 'w1' or name == 'w2': - weight[name] = torch.zeros((mlp_hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() + if name == "Wqkv": + weight[name] = torch.zeros( + (3 * hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "out_proj": + weight[name] = torch.zeros( + (hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "w1" or name == "w2": + weight[name] = torch.zeros( + (mlp_hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() else: - weight[name] = torch.zeros((hidden_size, mlp_hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() + weight[name] = torch.zeros( + (hidden_size, mlp_hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() block_memory[i] = weight reduce_scatter_memory = {} for key in size_key: - reduce_scatter_memory[key] = {'data': [], 'used': []} - + reduce_scatter_memory[key] = {"data": [], "used": []} + gpc.config.block_memory = block_memory gpc.config.reduce_scatter_memory = reduce_scatter_memory diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 968a1db..f708fa7 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape def switch_sequence_parallel_mode(): prev_mode = gpc.config.parallel.sequence_parallel try: - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == 
"intern": gpc.config.parallel.sequence_parallel = True else: gpc.config.parallel.sequence_parallel = False @@ -106,7 +106,7 @@ def evaluate_on_val_dls( total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == "intern": sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) tensor_shape = torch.Size( [ diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 52d9638..bf4cf1c 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50): # # import time # # time.sleep(10) # print(e, "rank = ", gpc.get_global_rank(), flush=True) - # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") # do empty_cache after the bench torch.cuda.empty_cache() diff --git a/train.py b/train.py index a917d12..02f2802 100644 --- a/train.py +++ b/train.py @@ -296,7 +296,7 @@ def main(args): memory_profiler.step() prof.step() - + if gpc.config.fstp_handler is not None: gpc.config.fstp_handler.zero_const_pool = {} gpc.config.fstp_handler.reduce_scatter_memory = {}