fix(model/overlap_handler.py): fix last block hook when pp with activation

pull/436/head
huangting4201 2023-11-02 16:05:07 +08:00
parent 4851291356
commit 5a18b3b651
1 changed file with 33 additions and 31 deletions

@@ -6,6 +6,7 @@ from typing import Any, Union
 import torch
 from torch import nn
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.core.scheduler import SchedulerHook
@@ -32,6 +33,7 @@ class FSTPOverlapHandler:
         self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
+        self.last_block = None
         self.head = []
         self.embedding = []
         self.model_checkpoint = gpc.config.model.checkpoint
@ -54,6 +56,7 @@ class FSTPOverlapHandler:
elif isinstance(children, Embedding1D):
self.embedding.append(children)
elif isinstance(children, nn.ModuleList):
self.last_block = children[len(children) - 1]
for idx, block in enumerate(children):
self.index_to_fstp_modules[idx] = []
for _sub_name, sub in block.named_children():
@@ -150,39 +153,23 @@ class FSTPOverlapHandler:
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

     def get_reduce_scatter_memory(self, key):
-        return_idx = 0
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
             self.reduce_scatter_memory_pool[key] = []
-        # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]) == 0:
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
-        else: # if not empty
-            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
-                if mem_item.idle is True:
-                    self.reduce_scatter_memory_pool[key][index].idle = False
-                    return_idx = index
-                    return self.reduce_scatter_memory_pool[key][return_idx]
-            # if the memory pool is all used
-            cur_len = len(self.reduce_scatter_memory_pool[key])
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
-            return_idx = cur_len
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
+        for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+            if mem_item.idle is True:
+                self.reduce_scatter_memory_pool[key][index].idle = False
+                return self.reduce_scatter_memory_pool[key][index]
+
+        # if the memory pool is all used
+        cur_len = len(self.reduce_scatter_memory_pool[key])
+        self.reduce_scatter_memory_pool[key].append(
+            torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
+        )
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "index", cur_len)
+        return self.reduce_scatter_memory_pool[key][cur_len]

     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
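
For context, the pool above hands out reusable tensors keyed by shape and tracks availability with an idle flag set directly on each tensor. A minimal, self-contained sketch of that pattern (the class and method names here are illustrative, not part of the repository):

import torch


class SimpleTensorPool:
    """Illustrative stand-in for the reduce-scatter memory pool above."""

    def __init__(self, dtype=torch.half, device="cpu"):
        self.dtype = dtype
        self.device = device
        self.pool = {}  # key: tensor shape; value: list of pooled tensors

    def get(self, key):
        bufs = self.pool.setdefault(key, [])
        # reuse the first idle buffer if one exists
        for buf in bufs:
            if buf.idle:
                buf.idle = False
                return buf
        # otherwise grow the pool by one buffer and hand it out
        buf = torch.zeros(key, dtype=self.dtype, device=self.device).contiguous()
        buf.idle = False
        buf.index = len(bufs)
        bufs.append(buf)
        return buf

    def release(self, key, index):
        # mark the buffer reusable; the memory itself stays allocated
        self.pool[key][index].idle = True

A caller grabs a buffer with pool.get((4096, 4096)), runs its reduce-scatter into it, and calls pool.release((4096, 4096), buf.index) once the result has been consumed, so the allocation is reused rather than freed.
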
@@ -242,6 +229,18 @@ class FSTPOverlapHandler:
                 self.fstp_global_handle[module] = weight_handle
             weight_handle.wait()

+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):  # pylint: disable=W0613
+            fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1]
+            for module in fstp_modules:
+                weight_handle = all_gather_raw_memory_pool(
+                    module.weight,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.fstp_global_handle[module] = weight_handle
+                weight_handle.wait()
+
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
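
The hook pair above follows a launch/record/release pattern: a forward pre-hook issues an asynchronous all-gather for a module's sharded weight and stores the returned handle, and the post-forward hook drops the handle once the module has consumed the gathered weight. A runnable single-process sketch of that lifecycle, assuming a trivial world size of 1 and a plain gloo group (helper and variable names here are illustrative):

import torch
import torch.distributed as dist
from torch import nn

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

handles = {}  # key: module; value: pending all-gather work object


def pre_forward(module, inputs):
    # launch the all-gather asynchronously, record the handle, wait before compute
    shard = module.weight.data
    gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    work = dist.all_gather(gathered, shard, async_op=True)
    handles[module] = work
    work.wait()


def post_forward(module, inputs, output):
    # the gathered weight has been consumed; drop the handle so buffers can be reused
    handles.pop(module, None)


block = nn.Linear(8, 8)
block.register_forward_pre_hook(pre_forward)
block.register_forward_hook(post_forward)
block(torch.randn(2, 8))
dist.destroy_process_group()
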
@@ -301,8 +300,11 @@ class FSTPOverlapHandler:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)

         if self.model_checkpoint:
-            for head in self.head:
-                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
+                for head in self.head:
+                    head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            else:
+                self.last_block.register_forward_pre_hook(_pre_forward_hook_for_block)

         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
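
A plausible reading of the else branch above: non-last pipeline stages have no head to attach the backward pre-hook to, so the weight prefetch for the last block is attached to that block's forward pre-hook instead; with activation checkpointing enabled, torch.utils.checkpoint re-runs the block's forward during backward, so the same pre-hook fires again at recompute time. A small sketch of that re-fire behaviour (module and variable names are illustrative):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

calls = []

block = nn.Linear(4, 4)
# stand-in for the weight-prefetch pre-hook registered on the last block
block.register_forward_pre_hook(lambda mod, inp: calls.append("prefetch"))

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(calls)  # ['prefetch', 'prefetch'] -- once in forward, once at recompute
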