mirror of https://github.com/InternLM/InternLM
merge reduce-scatter
commit 1804d01bb3
@@ -152,19 +152,19 @@ zero1 parallel (dict):
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
the sequence_parallel should be True.
2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
defaults to 'none', which means sequence parallel is disabled.
3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the 'intern' sp mode,
defaults to False.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using the interleaved pipeline scheduler,
defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, sp="intern", intern_overlap=True),
tensor=dict(size=8, sp="megatron", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)

cudnn_deterministic = False
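For context, a minimal sketch (not part of the diff) of how the new sp and intern_overlap fields are meant to be consumed: the overlap handler is only installed for the 'intern' sequence parallel mode with intern_overlap switched on, mirroring the checks added later in this change. The standalone config dict below is hypothetical.

parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)

def overlap_enabled(parallel_cfg: dict) -> bool:
    # all-gather/reduce-scatter overlap only applies to the 'intern' sp mode
    tensor_cfg = parallel_cfg["tensor"]
    return tensor_cfg["sp"] == "intern" and tensor_cfg["intern_overlap"] is True

assert overlap_enabled(parallel)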
@@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler):
_output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
engine.optimizer.reset_reduce_bucket()
if return_loss:
loss += _loss
@@ -306,15 +306,20 @@ def args_sanity_check():
), "sequence parallel does not support use_flash_attn=False"
if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True
), "when the tp_mode is fstp, the sequence_parallel should be True."
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
if gpc.config.parallel["tensor"].get("sp", None) is None:
gpc.config.parallel["tensor"]["sp"] = "none"
if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
gpc.config.parallel["tensor"]["intern_overlap"] = False
assert gpc.config.parallel["tensor"].get("sp", None) in [
"none",
"megatron",
"flash-attn",
"intern",
], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] are supported"
# adapt to the old version's sequence parallel config
if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
gpc.config.parallel.sequence_parallel = True
# currently only the interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
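As a standalone sketch of the normalization above (assuming plain dicts rather than gpc.config), a legacy config that passes tensor as an int, or omits sp or intern_overlap, is upgraded to the new schema, and any real sequence parallel mode forces sequence_parallel on:

def normalize_tensor_cfg(parallel: dict) -> dict:
    tensor = parallel["tensor"]
    if isinstance(tensor, int):
        tensor = dict(size=tensor, sp="none", intern_overlap=False)
    tensor.setdefault("sp", "none")
    tensor.setdefault("intern_overlap", False)
    assert tensor["sp"] in ("none", "megatron", "flash-attn", "intern"), "invalid sp mode"
    if tensor["sp"] != "none":
        # adapt to the old-style sequence parallel switch
        parallel["sequence_parallel"] = True
    parallel["tensor"] = tensor
    return parallel

print(normalize_tensor_cfg({"tensor": 8, "sequence_parallel": False}))
# {'tensor': {'size': 8, 'sp': 'none', 'intern_overlap': False}, 'sequence_parallel': False}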
@@ -451,49 +451,33 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.index_to_fsdp_modules[idx].append(child)
self.module_name_index[child] = index
index = index + 1
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
# _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
# setattr(child.weight, "_fstp_all_reduce_str", f"{_full_name}.weight")
# if child.bias is not None:
# setattr(child.bias, "_fstp_all_reduce_str", f"{_full_name}.bias")
else:
continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)
def get_zero_by_shape(self, size:tuple, dtype, device) -> torch.Tensor:
if size not in self.zero_const_pool:
def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
if size not in self.zero_const_pool:
self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
return self.zero_const_pool[size]
def _all_gather_block_weight(self, block_index: int):
#block = self.index_to_block[block_index]
fsdp_modules = self.index_to_fsdp_modules[block_index]
# self.block_handles[block] = []
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[module] = total_weight
self.FSTP_global_handle[module] = weight_handle
# self.block_handles[block].append(weight_handle)
def _all_gather_block_weight_memory_pool(self, block_index: int):
fsdp_modules = self.index_to_fsdp_modules[block_index]
# self.block_handles[block] = []
for module in fsdp_modules:
module_index = self.module_name_index[module]
name = self.module_name[module_index]
weight_handle = all_gather_raw_memory_pool(module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name)
# self.FSTP_global_weights[module] = total_weight
weight_handle = all_gather_raw_memory_pool(
module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
)
self.FSTP_global_handle[module] = weight_handle
# self.block_handles[block].append(weight_handle)
def _register_sync_parameters_hook(self) -> None:
"""
@@ -510,41 +494,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
# self._all_gather_block_weight(block_index + 1)
self._all_gather_block_weight_memory_pool(block_index + 1)
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block]
if block_index == 0:
# all gather weight for block 0
fsdp_modules = self.index_to_fsdp_modules[block_index]
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handle.wait()
self.FSTP_global_weights[module] = total_weight
else:
# wait handle for current block
handles = self.block_handles[block]
for handle in handles:
handle.wait()
def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
# self._all_gather_block_weight(0)
def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
self._all_gather_block_weight_memory_pool(0)
def _post_forward_hook_for_block(block: nn.Module, input, output):
block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index]
if block in self.block_handles:
del self.block_handles[block]
for module in fsdp_modules:
del self.FSTP_global_weights[module]
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
block_index = self.module_to_index[module]
handler = self.FSTP_global_handle[module]
handler.wait()
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
handle = self.FSTP_global_handle[module]
handle.wait()
def _post_forward_hook_for_module(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
@@ -552,67 +509,44 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
def _pre_backward_hook_for_block(block: nn.Module, grad_output):
# import pdb; pdb.set_trace()
block_index = self.block_to_index[block]
# if block_index == gpc.config.NUM_LAYER - 1:
# # all gather weight for the last block
# fsdp_modules = self.index_to_fsdp_modules[block_index]
# for module in fsdp_modules:
# total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
# weight_handle.wait()
# self.FSTP_global_weights[module] = total_weight
# else:
# # wait handle for current block
# handles = self.block_handles[block]
# for handle in handles:
# handle.wait()
# if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index)
# start the all-gather for next block
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
self.FSTP_global_handle[first_module] = weight_handler
self.FSTP_global_weights[first_module] = total_weight
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index]
if block in self.block_handles:
del self.block_handles[block]
for module in fsdp_modules:
del self.FSTP_global_weights[module]
def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
next_name = self.module_name[name_index - 1]
weights_handler = all_gather_raw_memory_pool(
next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=next_name
next_module.weight,
self.process_group,
async_op=True,
block_index=block_index,
module_name=next_name,
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4]
name = self.module_name[4]
weights_handler = all_gather_raw_memory_pool(
next_module.weight, self.process_group, async_op=True, block_index=block_index - 1, module_name=name,
next_module.weight,
self.process_group,
async_op=True,
block_index=block_index - 1,
module_name=name,
)
self.FSTP_global_handle[next_module] = weights_handler
else:
@@ -625,76 +559,24 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
def _post_backward_hook_for_module(module, grad_input, grad_output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
for embedding in self.embedding:
embedding.register_forward_hook(_pre_forward_hook_for_embedding)
embedding.register_forward_hook(_post_forward_hook_for_embedding)
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
# for block in self.FSTP_blocks:
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
# block.register_forward_hook(_post_forward_hook_for_block)
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
# block.register_full_backward_hook(_post_backward_hook_for_block)
for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
# for wqkv in self.FSTP_wqkvs:
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
for module in self.FSTP_modules:
module.register_forward_pre_hook(_pre_forward_hook_for_module)
module.register_forward_hook(_post_forward_hook_for_module)
# module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
module.register_full_backward_hook(_post_backward_hook_for_module)
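The hook wiring above implements a prefetch pipeline: a forward pre-hook on each linear waits for the all-gather that was launched earlier, while hooks on the preceding component kick off the gather for the next block so communication overlaps with compute. A self-contained sketch of that pattern with ordinary PyTorch hooks (single process, placeholder handles instead of collectives) is:

import torch
import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
pending = {}  # module -> placeholder for an async all-gather handle

def launch_gather(idx: int):
    if idx < len(layers):
        pending[layers[idx]] = f"async all-gather for layer {idx}"

def pre_hook(module, inputs):
    # wait for this layer's prefetched weights before computing
    handle = pending.pop(module, None)
    if handle is not None:
        print("waiting on:", handle)

def post_hook(module, inputs, output):
    # overlap: start fetching the next layer's weights while this output is consumed
    launch_gather(list(layers).index(module) + 1)

for layer in layers:
    layer.register_forward_pre_hook(pre_hook)
    layer.register_forward_hook(post_hook)

launch_gather(0)  # the embedding's post-forward hook plays this role above
x = torch.randn(2, 8)
for layer in layers:
    x = layer(x)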
@@ -396,7 +396,7 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
# if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
@@ -1,20 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Optional, Union
from typing import Optional
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw #, reduce_scatter_raw
from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
@@ -125,9 +125,20 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
)
return output, handle
def all_gather_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0, block_index: int = None, module_name: str = None):
def all_gather_raw_memory_pool(
input_: Tensor,
process_group: ProcessGroup,
async_op: bool = False,
gather_dim: int = 0,
block_index: int = None,
module_name: str = None,
):
handle = torch.distributed.all_gather_into_tensor(
gpc.config.block_memory[block_index % 2][module_name], input_.contiguous(), group=process_group, async_op=async_op
gpc.config.block_memory[block_index % 2][module_name],
input_.contiguous(),
group=process_group,
async_op=async_op,
)
return handle
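The handle-only return works because the gather target is a persistent buffer: gpc.config.block_memory holds two groups of per-module buffers and consecutive blocks alternate between them, so block i+1 can be gathered while block i is still in use. A single-process sketch (no process group; a tensor copy stands in for torch.distributed.all_gather_into_tensor) of that double buffering:

import torch

block_memory = {
    0: {"Wqkv": torch.empty(12, 4), "out_proj": torch.empty(4, 4)},
    1: {"Wqkv": torch.empty(12, 4), "out_proj": torch.empty(4, 4)},
}

def all_gather_into_pool(local_shard: torch.Tensor, block_index: int, module_name: str) -> torch.Tensor:
    buf = block_memory[block_index % 2][module_name]
    # stand-in for the collective: tile the local shard to fill the pooled buffer
    world_size = buf.shape[0] // local_shard.shape[0]
    buf.copy_(local_shard.repeat(world_size, 1))
    return buf

shard = torch.randn(3, 4)  # this rank's slice of a Wqkv weight
full = all_gather_into_pool(shard, block_index=5, module_name="Wqkv")
assert full.shape == (12, 4)  # block 5 writes into buffer 5 % 2 == 1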
@@ -142,23 +153,25 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
dtype=input_.dtype, device=input_.device).contiguous()
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
group=process_group,
async_op=async_op)
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
).contiguous()
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle

def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
size = (input_.shape[0] // world_size, *input_.shape[1:])
index = check_reduce_scatter_memory_pool(size)
output = gpc.config.reduce_scatter_memory[size]['data'][index]
output = gpc.config.reduce_scatter_memory[size]["data"][index]
setattr(output, "index", index)
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
group=process_group,
async_op=async_op)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
@@ -444,7 +457,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
def forward(
ctx,
x,
weight,
bias,
return_residual=False,
process_group=None,
module=None,
overlap_handler=None,
block_index=None,
module_name=None,
):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
@@ -506,7 +530,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
overlap_handler = ctx.overlap_handler
block_index = ctx.block_index
module_name = ctx.module_name
if ctx.compute_weight_gradient:
x, weight, bias = ctx.saved_tensors
total_x = x
@@ -540,7 +564,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
if grad_bias is not None:
grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
grad_bias, process_group, async_op=True
)
assert hasattr(bias, "_fstp_reduce_scatter_str")
overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
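What the two get_zero_by_shape calls above are doing: the real weight/bias gradients are reduce-scattered asynchronously into the memory pool and parked on overlap_handler.reduce_scatter_handlers, while autograd is handed a cached zero tensor of the sharded shape so its own accumulation is a no-op; the optimizer later waits on the handle and adds the reduce-scattered gradient (see the _accum_grads_store_in_bucket hunk below). A single-process sketch of that deferred accumulation, with hypothetical names and a chunk-sum standing in for reduce-scatter:

import torch

zero_pool = {}

def get_zero_by_shape(size, dtype, device):
    if size not in zero_pool:
        zero_pool[size] = torch.zeros(*size, dtype=dtype, device=device)
    return zero_pool[size]

pending = {}  # reduce_scatter_str -> reduce-scattered gradient (handle omitted)

def backward_weight(full_grad: torch.Tensor, world_size: int, key: str) -> torch.Tensor:
    shard = full_grad.reshape(world_size, -1, full_grad.shape[-1]).sum(0)  # stand-in for reduce-scatter
    pending[key] = shard
    # hand autograd a zero placeholder of the sharded shape
    return get_zero_by_shape(tuple(shard.shape), full_grad.dtype, full_grad.device)

grad = backward_weight(torch.ones(8, 4), world_size=4, key="blocks.0.mixer.Wqkv.weight")
assert torch.all(grad == 0)  # nothing accumulated yet
grad = grad + pending.pop("blocks.0.mixer.Wqkv.weight")  # the optimizer adds the real gradient later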
@@ -619,7 +645,9 @@ def fstp_fused_dense_func(
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler, block_index, module_name)
return FSTPFusedDenseFunc.apply(
x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
)
else:
assert process_group is None
out = F.linear(x, weight, bias)
@@ -666,36 +694,37 @@ def Silu(w1_o, w2_o):
Silu = torch.jit.script(Silu)

def check_reduce_scatter_memory_pool(key):
return_idx = 0
# if key not in dict
if key not in gpc.config.reduce_scatter_memory:
gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []}
# if the data is empty
if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous())
gpc.config.reduce_scatter_memory[key]['used'].append(True)
if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0:
gpc.config.reduce_scatter_memory[key]["data"].append(
torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
)
gpc.config.reduce_scatter_memory[key]["used"].append(True)
return_idx = 0
return return_idx
else: # if not empty
for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
if used == False:
gpc.config.reduce_scatter_memory[key]['used'][index] = True
else: # if not empty
for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]):
if used is False:
gpc.config.reduce_scatter_memory[key]["used"][index] = True
return_idx = index
return return_idx
# if the memory pool is all used
length = len(gpc.config.reduce_scatter_memory[key]['data'])
gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous())
gpc.config.reduce_scatter_memory[key]['used'].append(True)
length = len(gpc.config.reduce_scatter_memory[key]["data"])
gpc.config.reduce_scatter_memory[key]["data"].append(
torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
)
gpc.config.reduce_scatter_memory[key]["used"].append(True)
return_idx = length
return return_idx

def release_reduce_scatter_memory_pool(size, index):
gpc.config.reduce_scatter_memory[size]['used'][index] = False
gpc.config.reduce_scatter_memory[size]["used"][index] = False
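A self-contained sketch (no gpc, hypothetical helper names) of the allocate/release cycle these two functions implement: buffers are keyed by the reduce-scatter output shape, flagged as used while a communication is in flight, and recycled once the optimizer has consumed the gradient, so the pool only grows when every buffer of a given shape is busy.

import torch

pool = {}

def acquire(size: tuple) -> int:
    entry = pool.setdefault(size, {"data": [], "used": []})
    for idx, used in enumerate(entry["used"]):
        if not used:
            entry["used"][idx] = True
            return idx
    entry["data"].append(torch.zeros(size))
    entry["used"].append(True)
    return len(entry["data"]) - 1

def release(size: tuple, index: int) -> None:
    pool[size]["used"][index] = False

size = (1024, 4096)
i = acquire(size)  # first call allocates buffer 0
j = acquire(size)  # buffer 0 is busy, so a second buffer is allocated
release(size, i)
k = acquire(size)  # buffer 0 is recycled instead of growing the pool
assert (i, j, k) == (0, 1, 0)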
@@ -3,6 +3,7 @@
import math
from functools import partial
from typing import List, Optional
import torch
import torch.distributed as dist
@@ -10,7 +11,7 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -65,7 +66,8 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
self._fstp_handler = None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler
# Zero related args
@@ -85,8 +87,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
self._bucket_store = []
self._bucket_store_2 = []
self._bucket_store: List[BucketStore] = []
self._accum_grad_buckets: List[BucketStore] = []
self._bucket_in_progress = []
# fp16 and fp32 params for mixed precision training
@@ -155,7 +157,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
self._broadcast_parallel_mode.append(zero_mode)
self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"]))
self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"]))
# assign parameters to ranks the params in the list are sorted
params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
@@ -301,9 +303,9 @@ class HybridZeroOptimizer(BaseOptimizer):
param=param,
reduce_rank=reduce_rank,
)
reduce_scatter_checker = partial(
self._wait_reduce_scatter_and_accumulate_grad,
self._wait_reduce_scatter_and_accumulate_grads,
param=param,
reduce_rank=reduce_rank,
)
@@ -312,7 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here are not grads, but allow_unreachable and accumulate_grad
def reduce_grad_hook(*args): # pylint: disable=W0613
if gpc.config.fstp_handler is not None:
if self._fstp_handler is not None:
reduce_scatter_checker()
if self.skip_grad_reduce is False:
@@ -336,56 +338,36 @@ class HybridZeroOptimizer(BaseOptimizer):
group_id = getattr(param, "group_id")
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
def reset_reduce_bucket(self) -> None:
for bucket in self._bucket_store_2:
for rank, params in bucket._params.items():
for _param in params:
if not hasattr(_param, "_fstp_reduce_scatter_str"):
continue
def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None:
for _param in bucket.get_param(reduce_rank):
if not hasattr(_param, "_fstp_reduce_scatter_str"):
continue
key = getattr(_param, "_fstp_reduce_scatter_str")
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
# wait and accumulate gradient.
_key = getattr(_param, "_fstp_reduce_scatter_str")
_comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key]
_comm_handle.wait()
_param.grad.add_(_grad)
# release cuda memory.
release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
self._fstp_handler.reduce_scatter_handlers[_key] = None
bucket.reset_by_rank(rank)
def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
bucket.reset_by_rank(reduce_rank)
def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None):
param_size = param.numel()
group_id = getattr(param, "group_id")
current_bucket = self._accum_grad_buckets[group_id]
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
group_id = getattr(param, "group_id")
current_bucket = self._bucket_store_2[group_id]
if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
self._accum_grads_store_in_bucket(current_bucket, reduce_rank)
if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024:
# wait reduce scatter communication
params = current_bucket.get_param(reduce_rank)
for _param in params:
if not hasattr(_param, "_fstp_reduce_scatter_str"):
continue
key = getattr(_param, "_fstp_reduce_scatter_str")
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
current_bucket.reset_by_rank(reduce_rank)
# otherwise, add the parameter into bucket.
current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
current_bucket.add_param(param, reduce_rank)
@@ -612,6 +594,10 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(self.num_param_groups):
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
# we need to accumulate gradients left in the gradient accumulation bucket
for group_id in range(self.num_param_groups):
self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
# compute norm for gradients in the before bucket
groups_norms = []
for group_id in range(self.num_param_groups):
@@ -773,7 +759,7 @@ class HybridZeroOptimizer(BaseOptimizer):
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()
timer("step").stop()
# update gradients may not be needed here, because the sync_params function is used in initialization,
@@ -45,7 +45,7 @@ class BucketStore(BaseStore):
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]

def num_params_in_bucket(self, reduce_rank: int = None):
return len(self._params[reduce_rank])
@@ -107,48 +107,60 @@ def initialize_model():
# if fsdp enabled, wrap the model
model = wrap_FSDP_model(model)
gpc.config.fstp_handler = None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler
# allocate memory pool
block_memory = {} # containing two groups of block weight
block_memory = {}  # containing two groups of block weight
hidden_size = gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.MLP_RATIO
mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
world_size = gpc.get_world_size(ParallelMode.TENSOR)
size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
size_key = [
(3 * hidden_size // world_size, hidden_size),
(mlp_hidden_size // world_size, hidden_size),
(hidden_size // world_size, mlp_hidden_size),
(hidden_size // world_size, hidden_size),
]
module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
for i in range(2):
weight = {}
for name in module_name:
if name == 'Wqkv':
weight[name] = torch.zeros((3 * hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()
elif name == 'out_proj':
weight[name] = torch.zeros((hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()
elif name == 'w1' or name == 'w2':
weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()
if name == "Wqkv":
weight[name] = torch.zeros(
(3 * hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
elif name == "out_proj":
weight[name] = torch.zeros(
(hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
elif name == "w1" or name == "w2":
weight[name] = torch.zeros(
(mlp_hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
else:
weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()
weight[name] = torch.zeros(
(hidden_size, mlp_hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
block_memory[i] = weight
reduce_scatter_memory = {}
for key in size_key:
reduce_scatter_memory[key] = {'data': [], 'used': []}
reduce_scatter_memory[key] = {"data": [], "used": []}
gpc.config.block_memory = block_memory
gpc.config.reduce_scatter_memory = reduce_scatter_memory
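A worked example of the sizing above, using hypothetical 7B-style values that are not taken from the diff: the MLP hidden size is padded up to a multiple of 256, and each size_key entry is the per-rank shard shape that the reduce-scatter of a linear layer's gradient produces.

hidden_size = 4096
mlp_ratio = 8 / 3
world_size = 8  # tensor parallel size

mlp_hidden_size = int(hidden_size * mlp_ratio)                 # 10922
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)   # padded to 11008

size_key = [
    (3 * hidden_size // world_size, hidden_size),   # (1536, 4096), matches the Wqkv buffer
    (mlp_hidden_size // world_size, hidden_size),   # (1376, 4096), matches the w1/w2 buffers
    (hidden_size // world_size, mlp_hidden_size),   # (512, 11008), matches the remaining MLP buffer
    (hidden_size // world_size, hidden_size),       # (512, 4096), matches the out_proj buffer
]
print(size_key)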
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[
@@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50):
# # import time
# # time.sleep(10)
# print(e, "rank = ", gpc.get_global_rank(), flush=True)
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
# do empty_cache after the bench
torch.cuda.empty_cache()