mirror of https://github.com/InternLM/InternLM

fix(model/overlap_handler.py): fix last block hook when pp with activation

parent 4851291356
commit 5a18b3b651
@@ -6,6 +6,7 @@ from typing import Any, Union
 import torch
 from torch import nn

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.core.scheduler import SchedulerHook
@@ -32,6 +33,7 @@ class FSTPOverlapHandler:
         self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
+        self.last_block = None
         self.head = []
         self.embedding = []
         self.model_checkpoint = gpc.config.model.checkpoint
@@ -54,6 +56,7 @@ class FSTPOverlapHandler:
                 elif isinstance(children, Embedding1D):
                     self.embedding.append(children)
                 elif isinstance(children, nn.ModuleList):
+                    self.last_block = children[len(children) - 1]
                     for idx, block in enumerate(children):
                         self.index_to_fstp_modules[idx] = []
                         for _sub_name, sub in block.named_children():
@@ -150,39 +153,23 @@ class FSTPOverlapHandler:
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

     def get_reduce_scatter_memory(self, key):
-        return_idx = 0
-
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
             self.reduce_scatter_memory_pool[key] = []

-        # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]) == 0:
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
-        else:  # if not empty
-            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
-                if mem_item.idle is True:
-                    self.reduce_scatter_memory_pool[key][index].idle = False
-                    return_idx = index
-                    return self.reduce_scatter_memory_pool[key][return_idx]
-            # if the memory pool is all used
-            cur_len = len(self.reduce_scatter_memory_pool[key])
-            self.reduce_scatter_memory_pool[key].append(
-                torch.zeros(
-                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
-                ).contiguous()
-            )
-            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
-            return_idx = cur_len
-            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
-            return self.reduce_scatter_memory_pool[key][return_idx]
+        for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+            if mem_item.idle is True:
+                self.reduce_scatter_memory_pool[key][index].idle = False
+                return self.reduce_scatter_memory_pool[key][index]
+
+        # if the memory pool is all used
+        cur_len = len(self.reduce_scatter_memory_pool[key])
+        self.reduce_scatter_memory_pool[key].append(
+            torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
+        )
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+        setattr(self.reduce_scatter_memory_pool[key][cur_len], "index", cur_len)
+        return self.reduce_scatter_memory_pool[key][cur_len]

     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
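For reference, a minimal runnable sketch of the idle-buffer pool pattern that the simplified get_reduce_scatter_memory follows: reuse an idle cached tensor of the requested size if one exists, otherwise grow the pool, and let release only flip the idle flag. The class and method names below are illustrative, not the InternLM API.

import torch

class BufferPool:
    # Illustrative stand-in for the reduce-scatter memory pool above, not the InternLM class.
    def __init__(self, dtype=torch.float16, device="cpu"):
        self.pool = {}  # key: tensor size; value: list of cached buffers
        self.dtype = dtype
        self.device = device

    def get(self, key):
        bucket = self.pool.setdefault(key, [])
        # reuse the first idle buffer of this size, if any
        for item in bucket:
            if item.idle:
                item.idle = False
                return item
        # every buffer is in use: grow the pool and hand the new buffer out busy
        cur_len = len(bucket)
        bucket.append(torch.zeros(key, dtype=self.dtype, device=self.device).contiguous())
        bucket[cur_len].idle = False
        bucket[cur_len].index = cur_len
        return bucket[cur_len]

    def release(self, key, index):
        # mark the buffer reusable; its storage stays cached for later requests
        self.pool[key][index].idle = True

pool = BufferPool()
buf = pool.get((1024,))            # first request allocates
pool.release((1024,), buf.index)   # hand it back
assert pool.get((1024,)) is buf    # the next request of the same size reuses it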
@@ -242,6 +229,18 @@ class FSTPOverlapHandler:
                 self.fstp_global_handle[module] = weight_handle
                 weight_handle.wait()

+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):  # pylint: disable=W0613
+            fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1]
+            if module in fstp_modules:
+                weight_handle = all_gather_raw_memory_pool(
+                    module.weight,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.fstp_global_handle[module] = weight_handle
+                weight_handle.wait()
+
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
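The added _pre_forward_hook_for_block relies on PyTorch's register_forward_pre_hook, which runs a callback just before the hooked module's forward; that is the point where the handler launches the weight all-gather for the last block's FSTP modules. A minimal sketch of the hook mechanics with a toy module standing in for the transformer block (nothing below is InternLM code):

import torch
from torch import nn

block = nn.Linear(4, 4)  # toy stand-in for the last transformer block
fired = []

def pre_forward(module, inputs):  # signature expected by register_forward_pre_hook
    fired.append(type(module).__name__)  # the real hook would launch all-gathers here

block.register_forward_pre_hook(pre_forward)
block(torch.randn(2, 4))
assert fired == ["Linear"]  # the hook ran before forward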
@@ -301,8 +300,11 @@ class FSTPOverlapHandler:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)

         if self.model_checkpoint:
-            for head in self.head:
-                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            if gpc.is_last_rank(parallel_mode=ParallelMode.PIPELINE):
+                for head in self.head:
+                    head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+            else:
+                self.last_block.register_forward_pre_hook(_pre_forward_hook_for_block)

         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
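Read together, the hunk above makes the hook registration rank-aware under activation checkpointing: the last pipeline stage keeps the backward pre-hook on the head, while earlier stages now trigger the overlap from the last block's forward pre-hook instead. A hedged, self-contained sketch of that decision with the relevant state passed in as parameters (names are illustrative, not InternLM APIs):

from torch import nn

def register_checkpoint_hooks(model_checkpoint, is_last_pipeline_rank,
                              head_modules, last_block,
                              head_backward_hook, block_forward_hook):
    # Mirrors the branch added above, with the globals made explicit.
    if not model_checkpoint:
        return
    if is_last_pipeline_rank:
        # the last pipeline stage drives the overlap from the head's backward
        for head in head_modules:
            head.register_full_backward_pre_hook(head_backward_hook)
    else:
        # earlier stages drive it from the last transformer block's forward
        last_block.register_forward_pre_hook(block_forward_hook)

# toy usage: a stage that is not last in the pipeline hooks its last block
register_checkpoint_hooks(
    model_checkpoint=True,
    is_last_pipeline_rank=False,
    head_modules=[nn.Linear(4, 4)],
    last_block=nn.Linear(4, 4),
    head_backward_hook=lambda module, grad_output: None,
    block_forward_hook=lambda module, inputs: None,
)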