diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index db81150..ed0a8d2 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -6,6 +6,7 @@ from typing import Any, Union
 import torch
 from torch import nn
 
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.core.scheduler import SchedulerHook
@@ -32,6 +33,7 @@ class FSTPOverlapHandler:
         self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
+        self.last_block = None
         self.head = []
         self.embedding = []
         self.model_checkpoint = gpc.config.model.checkpoint
@@ -54,6 +56,7 @@ class FSTPOverlapHandler:
             elif isinstance(children, Embedding1D):
                 self.embedding.append(children)
             elif isinstance(children, nn.ModuleList):
+                self.last_block = children[-1]
                 for idx, block in enumerate(children):
                     self.index_to_fstp_modules[idx] = []
                     for _sub_name, sub in block.named_children():
@@ -150,39 +153,23 @@ class FSTPOverlapHandler:
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
 
     def get_reduce_scatter_memory(self, key):
-        return_idx = 0
-        # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
             self.reduce_scatter_memory_pool[key] = []
-        # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]) == 0:
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
-        else:  # if not empty
-            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
-                if mem_item.idle is True:
-                    self.reduce_scatter_memory_pool[key][index].idle = False
-                    return_idx = index
-                    return self.reduce_scatter_memory_pool[key][return_idx]
-            # if the memory pool is all used
-            cur_len = len(self.reduce_scatter_memory_pool[key])
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
-            return_idx = cur_len
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
+        for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+            if mem_item.idle is True:
+                self.reduce_scatter_memory_pool[key][index].idle = False
+                return self.reduce_scatter_memory_pool[key][index]
+
+        # if the memory pool is all used
+        cur_len = len(self.reduce_scatter_memory_pool[key])
+        self.reduce_scatter_memory_pool[key].append(
+            torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
+        )
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "index", cur_len)
+        return self.reduce_scatter_memory_pool[key][cur_len]
 
     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
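The rewritten `get_reduce_scatter_memory` folds the old empty/non-empty branches into one path: scan the pool for an idle buffer and reuse it, otherwise grow the pool by one. A key that has never been seen maps to an empty list, so the scan falls through to allocation naturally. Below is a minimal standalone sketch of that reuse-or-grow pattern; `BufferPool` is a hypothetical stand-in, with a fixed dtype and device replacing the `gpc` config lookups:

```python
import torch


class BufferPool:
    """Hypothetical stand-in for reduce_scatter_memory_pool: buffers are
    keyed by shape, reused when idle, and the pool grows only when every
    buffer for a key is busy."""

    def __init__(self, dtype=torch.half, device="cpu"):
        self.pool = {}  # key: shape tuple; value: list of tensors
        self.dtype = dtype
        self.device = device

    def get(self, key):
        items = self.pool.setdefault(key, [])
        # Reuse an idle buffer if one exists; a brand-new key gives an
        # empty list, so the loop falls through to allocation.
        for item in items:
            if item.idle:
                item.idle = False
                return item
        # Pool is fully in use (or empty): allocate one more buffer.
        buf = torch.zeros(key, dtype=self.dtype, device=self.device).contiguous()
        buf.idle = False        # torch.Tensor accepts ad-hoc attributes
        buf.index = len(items)  # slot number, used by release()
        items.append(buf)
        return buf

    def release(self, key, index):
        # Releasing only flips the flag; the memory stays allocated.
        self.pool[key][index].idle = True


pool = BufferPool()
a = pool.get((4, 8))           # allocates slot 0
b = pool.get((4, 8))           # slot 0 busy, allocates slot 1
pool.release((4, 8), a.index)  # slot 0 becomes idle
c = pool.get((4, 8))           # reuses slot 0 instead of growing
assert c is a and len(pool.pool[(4, 8)]) == 2
```

Attaching `idle`/`index` attributes directly to the tensors works because `torch.Tensor` instances accept arbitrary Python attributes; it lets `release_reduce_scatter_memory` flip a flag instead of freeing and reallocating device memory.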
@@ -242,6 +229,18 @@ class FSTPOverlapHandler:
             self.fstp_global_handle[module] = weight_handle
             weight_handle.wait()
 
+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):  # pylint: disable=W0613
+            fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1]
+            for module in fstp_modules:
+                weight_handle = all_gather_raw_memory_pool(
+                    module.weight,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.fstp_global_handle[module] = weight_handle
+                weight_handle.wait()
+
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
@@ -301,8 +300,11 @@ class FSTPOverlapHandler:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)
 
         if self.model_checkpoint:
-            for head in self.head:
-                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
+                for head in self.head:
+                    head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            else:
+                self.last_block.register_forward_pre_hook(_pre_forward_hook_for_block)
 
         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
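The registration change at the bottom accounts for pipeline parallelism when activation checkpointing is on: only the last pipeline stage owns `head`, so earlier stages would never fire the head's backward pre-hook; they instead prefetch the last block's FSTP weights from a forward pre-hook on that block. A small sketch of the two PyTorch hook types involved, with toy modules and illustrative names (`register_full_backward_pre_hook` requires PyTorch 2.0+):

```python
import torch
from torch import nn

block = nn.Linear(4, 4)  # toy stand-in for the last transformer block
head = nn.Linear(4, 2)   # toy stand-in for the output head


def _prefetch_hook(module, inputs):
    # Runs just before module.forward; in the handler this is where the
    # last block's weights would be all-gathered asynchronously.
    print("forward pre-hook fired")


def _head_backward_hook(module, grad_output):
    # Runs before gradients propagate through the module; only useful on
    # the rank that actually owns the head (the last pipeline stage).
    print("backward pre-hook fired")


block.register_forward_pre_hook(_prefetch_hook)
head.register_full_backward_pre_hook(_head_backward_hook)

out = head(block(torch.randn(2, 4)))  # triggers the forward pre-hook
out.sum().backward()                  # triggers the backward pre-hook
```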