mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): add memory_pool switch and refactor overlap handler
parent b5e4d04a9a
commit 74754397df
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=4, sp="intern", intern_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True, memory_pool=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
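The new `memory_pool` switch defaults to off: the handler below reads it with `gpc.config.parallel["tensor"].get("memory_pool", False)`, so configs that omit the key keep the old allocate-on-demand behavior. A minimal sketch of that lookup, where `parallel_cfg` is a stand-in for `gpc.config.parallel`:

    # Stand-in for gpc.config.parallel; the handler reads the same key with
    # the same default.
    parallel_cfg = {
        "tensor": dict(size=4, sp="intern", intern_overlap=True, memory_pool=True),
    }

    # A missing "memory_pool" key falls back to False, so old configs are unchanged.
    enable_memory_pool = parallel_cfg["tensor"].get("memory_pool", False)
    assert enable_memory_pool is True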
@@ -13,6 +13,7 @@ from internlm.core.scheduler import SchedulerHook
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
 from internlm.model.utils import (
     all_gather_raw,
+    all_gather_raw_bias_memory_pool,
     all_gather_raw_memory_pool,
 )
@@ -29,14 +30,17 @@ class FSTPOverlapHandler:
         self.fstp_outs = []
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
-        self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
+        self.weight_global_handle = dict()  # key: fstp module; value: module weight global all-gather op handle
+        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
+        self.weight_global_output = dict()  # key: fstp module; value: module global weight after all-gather op
+        self.bias_global_output = dict()  # key: fstp module; value: module global bias after all-gather op
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fstp modules
         self.last_block = None
         self.head = []
         self.embedding = []
         self.model_checkpoint = gpc.config.model.checkpoint
+        self.enable_memory_pool = gpc.config.parallel["tensor"].get("memory_pool", False)
         self.is_forward = True

         self.reduce_scatter_handlers = {}
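The single `fstp_global_handle` map is split four ways: async op handles and gathered outputs, each for weight and bias. The output maps matter only on the non-pool path, where the gathered tensor has no preallocated home and must be kept alive until the matmul consumes it. A rough sketch of that bookkeeping, with a fake work object standing in for the handle an async `torch.distributed` collective returns:

    class _FakeWork:
        def wait(self):  # stands in for the Work.wait() of an async collective
            pass

    weight_global_handle, weight_global_output = {}, {}

    def launch_gather(module_key, gathered):
        # remember both the in-flight op and where its result lives
        weight_global_handle[module_key] = _FakeWork()
        weight_global_output[module_key] = gathered

    def consume(module_key):
        weight_global_handle[module_key].wait()  # block until the gather lands
        return weight_global_output[module_key]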
@@ -60,34 +64,36 @@ class FSTPOverlapHandler:
         for idx, block in enumerate(children):
             self.index_to_fstp_modules[idx] = []
             for _sub_name, sub in block.named_children():
                 sub_modules = list(sub.children())
                 if len(sub_modules) > 0:
                     for name, child in sub.named_children():
                         if name == "out_proj":
                             self.fstp_outs.append(child)
                             self.module_to_index[child] = idx
                         if isinstance(child, FSTPLinear):
                             self.module_to_index[child] = idx
                             self.fstp_modules.append(child)
                             self.index_to_fstp_modules[idx].append(child)

                             setattr(child, "_fstp_name", name)

                             _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
                             setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
                             if child.bias is not None:
                                 setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

         self.num_blocks = len(self.index_to_fstp_modules)

-        self._initialize_memory_pool()
+        if self.enable_memory_pool:
+            self._initialize_memory_pool()
         self._register_sync_parameters_hook()

     def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
-        if size not in self.zero_const_pool:
-            self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
-
-        return self.zero_const_pool[size]
+        if self.enable_memory_pool:
+            if size not in self.zero_const_pool:
+                self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
+
+            return self.zero_const_pool[size]
+        else:
+            return torch.zeros(*size, dtype=dtype, device=device).contiguous()

     def set_forward_mode(self, flag):
         self.is_forward = flag
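`get_zero_by_shape` now caches one zero tensor per shape only when the pool is enabled; with the pool off it allocates a fresh tensor on every call. A self-contained sketch of that trade-off (CPU tensors, free function instead of the method):

    import torch

    zero_const_pool = {}

    def get_zero_by_shape(size, dtype=torch.float32, device="cpu", pooled=True):
        if pooled:
            # one cached zero tensor per shape; callers must treat it as read-only
            if size not in zero_const_pool:
                zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
            return zero_const_pool[size]
        # pool disabled: fresh allocation, safe to mutate but churns the allocator
        return torch.zeros(*size, dtype=dtype, device=device).contiguous()

    a = get_zero_by_shape((2, 3))
    b = get_zero_by_shape((2, 3))
    assert a is b  # the pooled path returns the same cached tensor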
@@ -122,14 +128,20 @@ class FSTPOverlapHandler:
         self.all_gather_memory_pool.append(weight)  # contains two groups of block weights

     def clear_memory_pool(self) -> None:
+        assert self.enable_memory_pool
+
         self.zero_const_pool = {}
         self.reduce_scatter_memory_pool = {}

-    def get_all_gather_memory(self, module):
+    def _get_weight_from_memory_pool(self, module):
+        assert self.enable_memory_pool
+
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

-    def get_bias_memory(self, module: nn.Module):
+    def _get_bias_from_memory_pool(self, module: nn.Module):
+        assert self.enable_memory_pool
+
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or the module has not been allocated memory
         if len(self.all_gather_bias_memory_pool) == 0:
@@ -151,7 +163,21 @@ class FSTPOverlapHandler:

         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

+    def get_weight_all_gather(self, module):
+        if self.enable_memory_pool:
+            return self._get_weight_from_memory_pool(module)
+        else:
+            return self.weight_global_output[module]
+
+    def get_bias_all_gather(self, module):
+        if self.enable_memory_pool:
+            return self._get_bias_from_memory_pool(module)
+        else:
+            return self.bias_global_output[module]
+
     def get_reduce_scatter_memory(self, key):
+        assert self.enable_memory_pool
+
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
             self.reduce_scatter_memory_pool[key] = []
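`get_weight_all_gather` and `get_bias_all_gather` are the seam that keeps call sites (the raw all-gather helpers and the fused dense functions further down) agnostic about where a gathered tensor lives. A stripped-down sketch of the dispatch, with plain dicts standing in for the pool and the output map:

    class OverlapHandlerSketch:
        def __init__(self, enable_memory_pool):
            self.enable_memory_pool = enable_memory_pool
            self.weight_global_output = {}  # filled by the non-pool all-gather path
            self._weight_pool = {}          # stands in for all_gather_memory_pool

        def get_weight_all_gather(self, module):
            # callers never branch on the pool flag themselves
            if self.enable_memory_pool:
                return self._weight_pool[module]
            return self.weight_global_output[module]

    handler = OverlapHandlerSketch(enable_memory_pool=False)
    handler.weight_global_output["wqkv"] = "gathered-weight"
    assert handler.get_weight_all_gather("wqkv") == "gathered-weight"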
@@ -171,11 +197,11 @@ class FSTPOverlapHandler:
             return self.reduce_scatter_memory_pool[key][cur_len]

     def release_reduce_scatter_memory(self, key, index):
+        assert self.enable_memory_pool
         self.reduce_scatter_memory_pool[key][index].idle = True

-    def _all_gather_block_weight_memory_pool(self, block_index: int):
-        fstp_modules = self.index_to_fstp_modules[block_index]
-        for module in fstp_modules:
+    def _all_gather_module_weight(self, module):
+        if self.enable_memory_pool:
             if module.bias is not None:
                 bias_handle = all_gather_raw_bias_memory_pool(
                     module.bias,
@@ -191,103 +217,102 @@ class FSTPOverlapHandler:
                 async_op=True,
                 module=module,
             )
-            self.fstp_global_handle[module] = weight_handle
+            self.weight_global_handle[module] = weight_handle
+        else:
+            if module.bias is not None:
+                bias_output, bias_handle = all_gather_raw(
+                    module.bias,
+                    self.process_group,
+                    async_op=True,
+                )
+                self.bias_global_handle[module] = bias_handle
+                self.bias_global_output[module] = bias_output
+
+            weight_output, weight_handle = all_gather_raw(
+                module.weight,
+                self.process_group,
+                async_op=True,
+            )
+            self.weight_global_handle[module] = weight_handle
+            self.weight_global_output[module] = weight_output
+
+    def _all_gather_block_weight(self, block_index: int):
+        fstp_modules = self.index_to_fstp_modules[block_index]
+        for module in fstp_modules:
+            self._all_gather_module_weight(module)

     def _register_sync_parameters_hook(self) -> None:
         """
         register forward hooks and backward hooks for fstp modules.
         """

+        def _wait_handle(module):
+            handle = self.weight_global_handle[module]
+            handle.wait()
+            if module.bias is not None:
+                bias_handle = self.bias_global_handle[module]
+                bias_handle.wait()
+
+        def _clear_handle(module):
+            if module in self.weight_global_handle:
+                del self.weight_global_handle[module]
+            if module in self.bias_global_handle:
+                del self.bias_global_handle[module]
+            # if module in self.weight_global_output:
+            #     del self.weight_global_output[module]
+            # if module in self.bias_global_output:
+            #     del self.bias_global_output[module]
+
         def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
-            self._all_gather_block_weight_memory_pool(0)
+            self._all_gather_block_weight(0)

         def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
             if self.model_checkpoint and self.is_forward is False:
                 if block_index - 1 >= 0:
-                    self._all_gather_block_weight_memory_pool(block_index - 1)
+                    self._all_gather_block_weight(block_index - 1)
             else:
                 # start the all-gather for next block
                 if block_index + 1 < self.num_blocks:
-                    self._all_gather_block_weight_memory_pool(block_index + 1)
+                    self._all_gather_block_weight(block_index + 1)

         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
-            if module in self.fstp_global_handle:
-                handle = self.fstp_global_handle[module]
-                handle.wait()
-                if module.bias is not None:
-                    bias_handle = self.bias_global_handle[module]
-                    bias_handle.wait()
-            else:
-                weight_handle = all_gather_raw_memory_pool(
-                    module.weight,
-                    self.process_group,
-                    async_op=True,
-                    module=module,
-                )
-                self.fstp_global_handle[module] = weight_handle
-                weight_handle.wait()
+            if module not in self.weight_global_handle:
+                self._all_gather_module_weight(module)
+
+            _wait_handle(module)

         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):  # pylint: disable=W0613
             fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1]
             for module in fstp_modules:
-                weight_handle = all_gather_raw_memory_pool(
-                    module.weight,
-                    self.process_group,
-                    async_op=True,
-                    module=module,
-                )
-                self.fstp_global_handle[module] = weight_handle
-                weight_handle.wait()
+                self._all_gather_module_weight(module)
+                _wait_handle(module)

         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
-            if module in self.fstp_global_handle:
-                del self.fstp_global_handle[module]
+            _clear_handle(module)

         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
-            first_backward_module = self.fstp_modules[-1]
-            weight_handle = all_gather_raw_memory_pool(
-                first_backward_module.weight,
-                self.process_group,
-                async_op=True,
-                module=first_backward_module,
-            )
-            self.fstp_global_handle[first_backward_module] = weight_handle
+            self._all_gather_module_weight(self.fstp_modules[-1])

         def _pre_backward_hook_for_head(module: nn.Module, grad_output):
             if self.is_forward is False:
-                self._all_gather_block_weight_memory_pool(self.num_blocks - 1)
+                self._all_gather_block_weight(self.num_blocks - 1)

         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
-            if module in self.fstp_global_handle:
-                weight_handle = self.fstp_global_handle[module]
-                weight_handle.wait()
-            else:
-                weight_handle = all_gather_raw_memory_pool(
-                    module.weight,
-                    self.process_group,
-                    async_op=True,
-                    module=module,
-                )
-                self.fstp_global_handle[module] = weight_handle
-                weight_handle.wait()
+            if module not in self.weight_global_handle:
+                self._all_gather_module_weight(module)
+
+            _wait_handle(module)

             # start the all-gather for next module
             module_index = self.fstp_modules.index(module)
             if module_index - 1 >= 0:
                 next_module = self.fstp_modules[module_index - 1]
-                weight_handle = all_gather_raw_memory_pool(
-                    next_module.weight,
-                    self.process_group,
-                    async_op=True,
-                    module=next_module,
-                )
-                self.fstp_global_handle[next_module] = weight_handle
+                self._all_gather_module_weight(next_module)

         def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
-            if module in self.fstp_global_handle:
-                del self.fstp_global_handle[module]
+            _clear_handle(module)

         # register forward hooks
         # 1. register post_forward_hook @embedding module to prefetch for block 0
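The refactor collapses every hand-rolled all-gather-then-wait sequence into `_all_gather_module_weight` plus the two local helpers `_wait_handle` and `_clear_handle`, so each hook now reads as pure scheduling: the embedding's post-forward prefetches block 0, each `out_proj` pre-forward prefetches the neighboring block, and per-module hooks launch on demand, wait, and clean up. A toy model of that one-ahead prefetch discipline (no distributed ops, hypothetical class):

    # Toy model of the prefetch schedule: launching is async, waiting is the
    # only sync point, and each wait also kicks off the next block's gather.
    class PrefetcherSketch:
        def __init__(self, num_blocks):
            self.num_blocks = num_blocks
            self.in_flight = set()

        def launch(self, idx):
            if 0 <= idx < self.num_blocks and idx not in self.in_flight:
                self.in_flight.add(idx)  # stands in for async all-gather handles

        def pre_forward(self, idx):
            self.launch(idx)             # ensure our own gather was started
            self.in_flight.discard(idx)  # "wait" on it, then drop the handle
            self.launch(idx + 1)         # overlap the next block's gather

    p = PrefetcherSketch(num_blocks=4)
    p.launch(0)                          # the embedding hook's job
    for idx in range(4):
        p.pre_forward(idx)

The remaining hunks adjust the helpers in `internlm/model/utils.py` (the module imported at the top of the handler), the optimizer, and the training loop to route through the renamed getters and the new switch.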
@@ -132,7 +132,7 @@ def all_gather_raw_memory_pool(
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.fstp_handler.get_all_gather_memory(module=module),
+        gpc.fstp_handler.get_weight_all_gather(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -147,7 +147,7 @@ def all_gather_raw_bias_memory_pool(
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.fstp_handler.get_bias_memory(module=module),
+        gpc.fstp_handler.get_bias_all_gather(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -177,8 +177,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
 def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
-    size = (input_.shape[0] // world_size, *input_.shape[1:])
-    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    if gpc.fstp_handler.enable_memory_pool:
+        size = (input_.shape[0] // world_size, *input_.shape[1:])
+        output = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    else:
+        output = torch.empty(
+            input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+        ).contiguous()
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )
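`get_reduce_scatter_memory` keys pooled output buffers by shape and marks them busy until the optimizer hands them back through `release_reduce_scatter_memory(key, index)`; the `idle` and `index` markers ride on the tensors themselves, as `_grad.index` in the optimizer hunk below shows. A self-contained sketch of that pool, assuming the same attribute convention:

    import torch

    reduce_scatter_memory_pool = {}

    def get_reduce_scatter_memory(key):
        pool = reduce_scatter_memory_pool.setdefault(key, [])
        for tensor in pool:               # reuse the first idle buffer of this shape
            if tensor.idle:
                tensor.idle = False
                return tensor
        tensor = torch.empty(*key).contiguous()
        tensor.idle = False               # plain attributes attached to the tensor,
        tensor.index = len(pool)          # mirroring the diff's _grad.idle/_grad.index
        pool.append(tensor)
        return tensor

    def release_reduce_scatter_memory(key, index):
        reduce_scatter_memory_pool[key][index].idle = True

    buf = get_reduce_scatter_memory((4, 8))
    release_reduce_scatter_memory((4, 8), buf.index)
    assert get_reduce_scatter_memory((4, 8)) is buf  # the released buffer is reused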
@@ -493,14 +498,14 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
             if overlap_handler is not None:
-                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
+                total_weight = gpc.fstp_handler.get_weight_all_gather(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
-            # TODO memory pool for bias
+
             if bias is not None:
                 if overlap_handler is not None:
-                    total_bias = gpc.fstp_handler.get_bias_memory(module=module)
+                    total_bias = gpc.fstp_handler.get_bias_all_gather(module=module)
                 else:
                     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
                     handle_bias.wait()
@@ -554,7 +559,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
+                total_weight = gpc.fstp_handler.get_weight_all_gather(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -655,7 +660,7 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
+                total_weight = gpc.fstp_handler.get_weight_all_gather(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -389,7 +389,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _param.grad.add_(_grad)

                     # release cuda memory.
-                    self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
+                    if self._fstp_handler.enable_memory_pool:
+                        self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
+                    _grad = None
                     self._fstp_handler.reduce_scatter_handlers[_key] = None

                 bucket.reset_by_rank(reduce_rank)
train.py
@@ -324,7 +324,7 @@ def main(args):
             if batch_count % 2 == 0:
                 prof.step()

-        if gpc.fstp_handler is not None:
+        if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
             gpc.fstp_handler.clear_memory_pool()
         # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
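With the added guard, `clear_memory_pool` (which now asserts `enable_memory_pool`) is only invoked when the pool actually exists; the caches it drops are rebuilt lazily on the next iteration. A runnable sketch of the per-step cleanup, with a stub standing in for `gpc.fstp_handler`:

    class _HandlerStub:
        enable_memory_pool = True

        def clear_memory_pool(self):
            # in the real handler: zero_const_pool = {}; reduce_scatter_memory_pool = {}
            print("pool caches dropped")

    fstp_handler = _HandlerStub()  # stands in for gpc.fstp_handler
    for batch_count in range(2):
        # ... forward / backward / optimizer step would run here ...
        if fstp_handler is not None and fstp_handler.enable_memory_pool:
            fstp_handler.clear_memory_pool()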