diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index b3c8b8b..f7132c3 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -10,7 +10,7 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool
+from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
 from internlm.utils.common import get_current_device
 
 
@@ -25,6 +25,7 @@ class FSTPOverlapHandler:
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
+        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
@@ -76,49 +77,61 @@ class FSTPOverlapHandler:
             self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
 
         return self.zero_const_pool[size]
-
-    def _initialize_memory_pool(self) -> None:
-        # allocate memory pool
+
+    def _initialize_module_shape(self) -> None:
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
+
+        self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
+        self.module_shape["out_proj"] = (hidden_size, hidden_size)
+        self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
+        self.module_shape["w2"] = (mlp_hidden_size, hidden_size)
+        self.module_shape["w3"] = (hidden_size, mlp_hidden_size)
+
+    def _initialize_memory_pool(self) -> None:
+        # allocate memory pool
         self.all_gather_memory_pool = []
+        self.all_gather_bias_memory_pool = []
         self.reduce_scatter_memory_pool = {}
+        self.module_shape = {}
+
+        self._initialize_module_shape()
+        dtype = gpc.config.model.get("dtype", torch.half)
+        device = get_current_device()
 
         for _ in range(2):
             weight = {}
             for name in self.module_name:
-                if name == "Wqkv":
-                    weight[name] = torch.zeros(
-                        (3 * hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                elif name == "out_proj":
-                    weight[name] = torch.zeros(
-                        (hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                elif name == "w1" or name == "w2":
-                    weight[name] = torch.zeros(
-                        (mlp_hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                else:
-                    weight[name] = torch.zeros(
-                        (hidden_size, mlp_hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-
+                weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight
 
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
+
+    def get_bias_memory(self, module: nn.Module):
+        block_index = self.module_to_index[module]
+        # lazily allocate if the bias memory pool is empty or this module has no buffer in it yet
+        if len(self.all_gather_bias_memory_pool) == 0:
+            for _ in range(2):
+                weight = {}
+                weight[module._fstp_name] = torch.zeros(
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
+                self.all_gather_bias_memory_pool.append(weight)
+        elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
+            for i in range(2):
+                self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
+
+        return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
 
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
@@ -157,10 +170,19 @@ class FSTPOverlapHandler:
 
     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
         for module in fstp_modules:
+            if module.bias is not None:
+                bias_handle = all_gather_raw_bias_memory_pool(
+                    module.bias,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.bias_global_handle[module] = bias_handle
+
             weight_handle = all_gather_raw_memory_pool(
                 module.weight,
                 self.process_group,
@@ -186,6 +208,9 @@ class FSTPOverlapHandler:
         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
             handle = self.fstp_global_handle[module]
             handle.wait()
+            if module.bias is not None:
+                bias_handle = self.bias_global_handle[module]
+                bias_handle.wait()
 
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
             if module in self.fstp_global_handle:
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 8070cbd..8a1281e 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -140,6 +140,21 @@ def all_gather_raw_memory_pool(
     )
     return handle
 
+def all_gather_raw_bias_memory_pool(
+    input_: Tensor,
+    process_group: ProcessGroup,
+    async_op: bool = False,
+    gather_dim: int = 0,
+    module: nn.Module = None,
+):
+    handle = torch.distributed.all_gather_into_tensor(
+        gpc.fstp_handler.get_bias_memory(module=module),
+        input_.contiguous(),
+        group=process_group,
+        async_op=async_op,
+    )
+    return handle
+
 
 def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
     assert my_input.dtype == grad_output.dtype
@@ -486,8 +501,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 handle_weight.wait()
             # TODO memory pool for bias
             if bias is not None:
-                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-                handle_bias.wait()
+                if overlap_handler is not None:
+                    total_bias = gpc.fstp_handler.get_bias_memory(module=module)
+                else:
+                    total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                    handle_bias.wait()
             else:
                 total_bias = bias
         else:
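
Note for reviewers: the core of this change is the double-buffered, lazily allocated bias pool that get_bias_memory() manages and all_gather_raw_bias_memory_pool() gathers into; the same block_index % 2 scheme lets _all_gather_block_weight_memory_pool prefetch block i + 1's weight and bias while block i is still computing, and the pre-forward hook simply waits on the stored handles. The sketch below is only a standalone illustration of that allocation pattern under simplified assumptions; ToyBiasPool, the CPU tensors, and the shapes are hypothetical stand-ins and are not part of the patch.

# Standalone sketch (illustrative only) of the lazily allocated, double-buffered
# bias pool pattern. The real handler keys buffers by module._fstp_name, sizes
# them from self.module_shape, and allocates on the current device.
import torch


class ToyBiasPool:
    def __init__(self, module_shape, dtype=torch.float16, device="cpu"):
        # module name -> full (unsharded) weight shape; the gathered bias length is shape[0]
        self.module_shape = module_shape
        self.dtype = dtype
        self.device = device
        self.pool = []  # two slots, so block i and block i + 1 can prefetch without clobbering each other

    def get_bias_memory(self, name, block_index):
        if len(self.pool) == 0:
            # first request ever: create both slots, each holding a buffer for this module
            self.pool = [
                {name: torch.zeros(self.module_shape[name][0], dtype=self.dtype, device=self.device)}
                for _ in range(2)
            ]
        elif name not in self.pool[0]:
            # the pool exists but this module has no buffer yet: add one to both slots
            for slot in self.pool:
                slot[name] = torch.zeros(self.module_shape[name][0], dtype=self.dtype, device=self.device)
        # alternate between the two slots by transformer block index, as the handler does
        return self.pool[block_index % 2][name]


if __name__ == "__main__":
    pool = ToyBiasPool({"Wqkv": (3 * 4096, 4096), "out_proj": (4096, 4096)})
    b0 = pool.get_bias_memory("Wqkv", block_index=0)
    b1 = pool.get_bias_memory("Wqkv", block_index=1)
    b2 = pool.get_bias_memory("Wqkv", block_index=2)
    # even/odd blocks reuse the same buffer, adjacent blocks do not
    assert b0.data_ptr() == b2.data_ptr() and b0.data_ptr() != b1.data_ptr()
    assert tuple(b0.shape) == (3 * 4096,)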