# mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Union

import torch
from torch import nn

from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import (
    all_gather_raw_bias_memory_pool,
    all_gather_raw_memory_pool,
)
from internlm.utils.common import get_current_device


class FSTPOverlapHandler:
    """
    FSTP overlap handler for managing the overlap of all-gather and reduce_scatter
    communication with computation.
    """

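    # How the overlap works: weights of the FSTP linear modules are all-gathered into two
    # alternating per-block buffers (selected by block_index % 2), so the prefetch for block
    # i + 1 can run while block i is still computing. In the forward pass, a hook on the
    # embedding prefetches block 0 and a hook on each out_proj prefetches the next block; in
    # the backward pass, each module waits on its own gather and launches the gather for the
    # module that runs next. Reduce-scatter outputs reuse buffers from a shape-keyed pool
    # guarded by an ``idle`` flag.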
    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
        self.process_group = process_group
        self.fstp_outs = []
        self.fstp_modules = []
        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        self.fstp_global_handle = dict()  # key: fstp module; value: weight all-gather op handle
        self.bias_global_handle = dict()  # key: fstp module; value: bias all-gather op handle
        self.module_to_index = dict()  # key: fstp module; value: transformer block index
        self.index_to_fstp_modules = dict()  # key: transformer block index; value: fstp modules
        self.head = []
        self.embedding = []

        self.reduce_scatter_handlers = {}
        self.zero_const_pool = {}

        # wrap a bare nn.Module in a list so the same loop handles both Module and ModuleList
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _chunk_name, children in _chunk.named_children():
                if isinstance(children, ScaleColumnParallelLinear):
                    self.head.append(children)
                elif isinstance(children, Embedding1D):
                    self.embedding.append(children)
                elif isinstance(children, nn.ModuleList):
                    for idx, block in enumerate(children):
                        self.index_to_fstp_modules[idx] = []
                        for _sub_name, sub in block.named_children():
                            sub_modules = list(sub.children())
                            if len(sub_modules) > 0:
                                for name, child in sub.named_children():
                                    if name == "out_proj":
                                        self.fstp_outs.append(child)
                                        self.module_to_index[child] = idx
                                    if isinstance(child, FSTPLinear):
                                        self.module_to_index[child] = idx
                                        self.fstp_modules.append(child)
                                        self.index_to_fstp_modules[idx].append(child)

                                        setattr(child, "_fstp_name", name)

                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
                                        if child.bias is not None:
                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

        self._initialize_memory_pool()
        self._register_sync_parameters_hook()

    def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
        if size not in self.zero_const_pool:
            self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

        return self.zero_const_pool[size]

    def _initialize_module_shape(self):
        hidden_size = gpc.config.HIDDEN_SIZE
        mlp_ratio = gpc.config.MLP_RATIO
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)  # round up to a multiple of 256

        self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
        self.module_shape["out_proj"] = (hidden_size, hidden_size)
        self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
        self.module_shape["w2"] = (mlp_hidden_size, hidden_size)
        self.module_shape["w3"] = (hidden_size, mlp_hidden_size)

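    # For example, assuming HIDDEN_SIZE = 4096 and MLP_RATIO = 8 / 3 (illustrative values,
    # not taken from this file): mlp_hidden_size = int(4096 * 8 / 3) = 10922, which rounded
    # up to the next multiple of 256 gives 11008, so the shapes set above would be
    # Wqkv (12288, 4096), out_proj (4096, 4096), w1/w2 (11008, 4096) and w3 (4096, 11008).
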
    def _initialize_memory_pool(self) -> None:
        # allocate the all-gather and reduce-scatter memory pools
        self.all_gather_memory_pool = []
        self.all_gather_bias_memory_pool = []
        self.reduce_scatter_memory_pool = {}
        self.module_shape = {}

        self._initialize_module_shape()
        dtype = gpc.config.model.get("dtype", torch.half)
        device = get_current_device()

        for _ in range(2):
            weight = {}
            for name in self.module_name:
                weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
            self.all_gather_memory_pool.append(weight)  # two alternating groups of block weights

    def clear_memory_pool(self) -> None:
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

    def get_all_gather_memory(self, module):
        block_index = self.module_to_index[module]
        # the two buffers alternate between even and odd blocks
        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

    def get_bias_memory(self, module: nn.Module):
        block_index = self.module_to_index[module]
        # lazily create the bias pool if it is empty, or if this module has not been
        # allocated a bias buffer yet
        if len(self.all_gather_bias_memory_pool) == 0:
            for _ in range(2):
                weight = {}
                weight[module._fstp_name] = torch.zeros(
                    self.module_shape[module._fstp_name][0],
                    dtype=gpc.config.model.get("dtype", torch.half),
                    device=get_current_device(),
                ).contiguous()
                self.all_gather_bias_memory_pool.append(weight)
        elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
            for i in range(2):
                self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
                    self.module_shape[module._fstp_name][0],
                    dtype=gpc.config.model.get("dtype", torch.half),
                    device=get_current_device(),
                ).contiguous()

        return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

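    # Unlike the weight pool, the bias pool above is filled lazily and only for modules that
    # actually carry a bias; it likewise keeps two copies per module name, selected by
    # block_index % 2.
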
    def get_reduce_scatter_memory(self, key):
        return_idx = 0

        # create the entry for this key if it does not exist yet
        if key not in self.reduce_scatter_memory_pool:
            self.reduce_scatter_memory_pool[key] = []

        # if the pool for this key is empty, allocate the first buffer
        if len(self.reduce_scatter_memory_pool[key]) == 0:
            self.reduce_scatter_memory_pool[key].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
            return self.reduce_scatter_memory_pool[key][return_idx]
        else:  # otherwise try to reuse an idle buffer
            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
                if mem_item.idle is True:
                    self.reduce_scatter_memory_pool[key][index].idle = False
                    return_idx = index
                    return self.reduce_scatter_memory_pool[key][return_idx]
            # all buffers are in use, so grow the pool
            cur_len = len(self.reduce_scatter_memory_pool[key])
            self.reduce_scatter_memory_pool[key].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
            return_idx = cur_len
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
            return self.reduce_scatter_memory_pool[key][return_idx]

    def release_reduce_scatter_memory(self, key, index):
        self.reduce_scatter_memory_pool[key][index].idle = True

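    # Intended buffer lifecycle (a sketch; the communication helpers that consume this pool
    # live outside this file): a caller obtains a buffer of a given output shape with
    # get_reduce_scatter_memory(shape), uses it as the destination of an asynchronous
    # reduce-scatter, and hands it back with release_reduce_scatter_memory(shape, buf.index)
    # once the op has completed, so later calls can reuse it.
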
    def _all_gather_block_weight_memory_pool(self, block_index: int):
        fstp_modules = self.index_to_fstp_modules[block_index]
        for module in fstp_modules:
            if module.bias is not None:
                bias_handle = all_gather_raw_bias_memory_pool(
                    module.bias,
                    self.process_group,
                    async_op=True,
                    module=module,
                )
                self.bias_global_handle[module] = bias_handle

            weight_handle = all_gather_raw_memory_pool(
                module.weight,
                self.process_group,
                async_op=True,
                module=module,
            )
            self.fstp_global_handle[module] = weight_handle

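    # The handles stored above are awaited in the pre-forward hooks registered below, so each
    # FSTP module only blocks on its own gather while the gathers for the rest of the block
    # continue in the background.
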
    def _register_sync_parameters_hook(self) -> None:
        """
        Register forward and backward hooks for FSTP modules.
        """

        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            self._all_gather_block_weight_memory_pool(0)

        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
            block_index = self.module_to_index[module]
            # start the all-gather for the next block
            if block_index + 1 < gpc.config.NUM_LAYER:
                self._all_gather_block_weight_memory_pool(block_index + 1)

        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
            handle = self.fstp_global_handle[module]
            handle.wait()
            if module.bias is not None:
                bias_handle = self.bias_global_handle[module]
                bias_handle.wait()

        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
            first_backward_module = self.fstp_modules[-1]
            weight_handle = all_gather_raw_memory_pool(
                first_backward_module.weight,
                self.process_group,
                async_op=True,
                module=first_backward_module,
            )
            self.fstp_global_handle[first_backward_module] = weight_handle

        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
            # wait for the handle of the current module
            weight_handle = self.fstp_global_handle[module]
            weight_handle.wait()

            # start the all-gather for the next module to run backward
            module_index = self.fstp_modules.index(module)
            if module_index - 1 >= 0:
                next_module = self.fstp_modules[module_index - 1]
                weight_handle = all_gather_raw_memory_pool(
                    next_module.weight,
                    self.process_group,
                    async_op=True,
                    module=next_module,
                )
                self.fstp_global_handle[next_module] = weight_handle

        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        # register forward hooks
        # 1. register post_forward_hook @embedding module to prefetch for block 0
        # 2. register pre_forward_hook @out_proj module to prefetch for the next block,
        #    noting that the next block's all_gather op should start after the current block's all_to_all op
        # 3. register pre_forward_hook @fstp_module to wait on the handle for the current module
        # 4. register post_forward_hook @fstp_module to release resources
        for embedding in self.embedding:
            embedding.register_forward_hook(_post_forward_hook_for_embedding)

        for out_proj in self.fstp_outs:
            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

        for module in self.fstp_modules:
            module.register_forward_pre_hook(_pre_forward_hook_for_module)
            module.register_forward_hook(_post_forward_hook_for_module)

        # register backward hooks
        # 1. register post_backward_hook @head module to prefetch for the last block's last module
        # 2. register pre_backward_hook @fstp_module to wait on the handle for the current module
        #    and to prefetch for the next module
        # 3. register post_backward_hook @fstp_module to release resources
        for head in self.head:
            head.register_full_backward_hook(_post_backward_hook_for_head)

        for module in self.fstp_modules:
            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
            module.register_full_backward_hook(_post_backward_hook_for_module)
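
# Typical wiring (a sketch; the exact process group and call site are assumptions, not taken
# from this file):
#
#     handler = FSTPOverlapHandler(model, process_group=gpc.get_group(ParallelMode.TENSOR))
#
# Construction registers all of the forward/backward hooks, so the all-gather prefetching then
# runs automatically during training; clear_memory_pool() can be called between steps to drop
# the cached zero-constant and reduce-scatter buffers.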