mirror of https://github.com/InternLM/InternLM
refactor linear
parent ed7232777a
commit dcd89ed304
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )

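A minimal sketch (not part of the commit) of how the renamed tensor-parallel keys are consumed by the code changed below: `sp` selects the sequence-parallel flavor and `intern_overlap` enables the communication-overlap handler used by the "intern" mode.

    # assumes the global `gpc` context used throughout InternLM
    sp_mode = gpc.config.parallel["tensor"]["sp"]              # "none" | "flash-attn" | "megatron" | "intern"
    overlap = gpc.config.parallel["tensor"]["intern_overlap"]  # only consulted when sp == "intern"
    if sp_mode == "intern" and overlap:
        pass  # initialize_model() registers CoarseGrainedFSTPAllGatherSyncHandler in this case
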
@@ -19,25 +19,26 @@ from internlm.model.utils import (
     all_gather_raw_memory_pool,
     fstp_fused_dense_func,
     fused_dense_func_torch,
+    megatron_fused_dense_func_torch,
 )


-class ScaleColumnParallelLinear(nn.Linear):
+class BaseScaleColumnParallelLinear(nn.Linear):
     """
-    ScaleColumnParallelLinear.
+    Base class for ScaleColumnParallelLinear.

     Args:
         in_features (int): size of each input sample
         out_features (int): size of each output sample
         process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
         bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                      in the config.
         sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
             we do an all_gather of x before doing the matmul.
             If not, then the input is already gathered.
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
         weight_scale (int): For training stability. 1 by default.
     """

     def __init__(
@@ -57,6 +58,10 @@ class ScaleColumnParallelLinear(nn.Linear):
         self.process_group = process_group
         self.weight_scale = weight_scale

+class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
+    """
+    ScaleColumnParallelLinear in flash implementation.
+    """
     def forward(self, input, gather_dim=0):  # pylint: disable=W0622
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
@@ -74,6 +79,27 @@ class ScaleColumnParallelLinear(nn.Linear):
             gather_dim=gather_dim,
         )

+class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
+    """
+    ScaleColumnParallelLinear in megatron implementation.
+    """
+
+    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        if self.weight_scale != 1:
+            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+        else:
+            weight = self.weight
+        return megatron_fused_dense_func_torch(
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            gather_dim=gather_dim,
+        )

 class RewardModelLinear(ScaleColumnParallelLinear):
     """
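A side note on the `weight_scale` branch above (a reading of the expression, not text from the commit): `w * s + (1 - s) * w.detach()` evaluates to `w`, but only the first term carries gradient, so the backward signal into the weight is scaled by `s`.

    # minimal sketch of the gradient-scaling identity used in forward()
    import torch
    w = torch.ones(3, requires_grad=True)
    s = 0.5
    (w * s + (1 - s) * w.detach()).sum().backward()
    print(w.grad)  # tensor([0.5000, 0.5000, 0.5000]): forward value equals w, gradient scaled by s
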
@@ -129,7 +155,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
-
         return fused_dense_func_torch(
             x,
             self.weight,
@@ -139,6 +164,19 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
             gather_dim=gather_dim,
         )

+class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
+    def forward(self, x, gather_dim=0):
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return megatron_fused_dense_func_torch(
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+            gather_dim=gather_dim,
+        )

 class RowParallelLinearTorch(RowParallelLinear):
     def forward(self, x):
@@ -150,10 +188,20 @@ class RowParallelLinearTorch(RowParallelLinear):
         reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
         return reduce_fn(out, self.process_group)

+class MegatronRowParallelLinearTorch(RowParallelLinear):
+    def forward(self, x):
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
+        a reduce_scatter of the result.
+        """
+        out = megatron_fused_dense_func_torch(x, self.weight, self.bias)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
+
-class FeedForward(nn.Module):
+class BaseFeedForward(nn.Module):
     """
-    FeedForward.
+    Base FeedForward in flash implementation.

     Args:
         in_features (int): size of each input sample
@@ -177,13 +225,13 @@ class FeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
-        block_idx: int = 0,
+        colum_cls = None,
+        row_cls = None,
     ):
         super().__init__()

         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = ColumnParallelLinearTorch(
+        self.w1 = colum_cls(
             in_features,
             hidden_features,
             process_group,
@@ -192,7 +240,7 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = ColumnParallelLinearTorch(
+        self.w2 = colum_cls(
             in_features,
             hidden_features,
             process_group,
@@ -201,7 +249,7 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w3 = RowParallelLinearTorch(
+        self.w3 = row_cls(
             hidden_features,
             out_features,
             process_group,
@@ -217,21 +265,9 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out

-
-class FSTPLinear(ColumnParallelLinear):
-    def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
-        name_index = gpc.config.fstp_handler.module_name_index[self]
-        name = gpc.config.fstp_handler.module_name[name_index]
-        return fstp_fused_dense_func(
-            x, self.weight, self.bias, process_group=self.process_group,
-            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
-        )
-
-
-class FSTPFeedForward(nn.Module):
+class FeedForward(BaseFeedForward):
     """
-    FeedForward.
+    FeedForward in flash implementation.

     Args:
         in_features (int): size of each input sample
@@ -255,169 +291,106 @@ class FSTPFeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
-        block_idx: int = 0,
     ):
-        super().__init__()
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)


-        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
+class MegatronFeedForward(BaseFeedForward):
+    """
+    FeedForward in megatron implementation.

-        self.w1 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w2 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w3 = FSTPLinear(
-            hidden_features,
-            out_features,
-            process_group,
-            bias=bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of hidden state of FFN
+        out_features (int): size of each output sample
+        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
+        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        process_group: Optional[torch.distributed.ProcessGroup] = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ):
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)

+class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        w1_o = self.w1(x)
-        w2_o = self.w2(x)
-        out = self.w3(F.silu(w1_o) * w2_o)
-        return out
+        block_index = gpc.config.fstp_handler.module_to_index[self]
+        name_index = gpc.config.fstp_handler.module_name_index[self]
+        name = gpc.config.fstp_handler.module_name[name_index]
+        return fstp_fused_dense_func(
+            x, self.weight, self.bias, process_group=self.process_group,
+            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
+        )


-class FSTPAllGatherSyncHandler:
+class FSTPFeedForward(BaseFeedForward):
     """
-    All-gather handler for overlapping the all-gather in adjcent FSTP linear.
+    FeedForward in FSTP.

+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of hidden state of FFN
+        out_features (int): size of each output sample
+        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
+        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
     """

-    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
-        # import pdb; pdb.set_trace()
-        self.process_group = process_group
-        self.FSTP_modules = []
-        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
-        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
-        self.module_handler = dict()  # key: FSTP module; value: all-gather handler
-        self.module_block = dict()  # key: FSTP module; value: transformer block index
-        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
-        self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        process_group: Optional[torch.distributed.ProcessGroup] = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ):
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, FSTPLinear, FSTPLinear)

-        self.reduce_scatter_handlers = {}
-        self.all_reduce_handlers = {}
-
-        # just want to share same for loop for ModuleList and Module
-        if not isinstance(model, nn.ModuleList):
-            model = [model]
-
-        for _chunk in model:
-            if isinstance(_chunk, NaiveAMPModel):
-                _chunk = _chunk.model
-
-            for _chunk_name, children in _chunk.named_children():
-                if isinstance(children, nn.ModuleList):
-                    for idx, block in enumerate(children):
-                        index = 0
-                        self.block_module[idx] = {}
-                        for _sub_name, sub in block.named_children():
-                            sub_modules = list(sub.children())
-                            if len(sub_modules) > 0:
-                                for name, child in sub.named_children():
-                                    if isinstance(child, FSTPLinear):
-
-                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                                        if child.bias is not None:
-                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
-
-                                        self.FSTP_modules.append(child)
-                                        self.module_block[child] = idx
-                                        self.block_module[idx][index] = child
-                                        self.module_name_index[child] = index
-                                        index = index + 1
-                            else:
-                                continue
-
-    def _register_sync_parameters_hook(self) -> None:
-        """
-        register pre_forward_hook and pre_backward_hook for FSTPLinear.
-        """
-
-        def _pre_forward_hook(module: nn.Module, inputs: Any):
-            block_index = self.module_block[module]
-            name_index = self.module_name_index[module]
-            if name_index == 0:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.module_handler[next_module] = weights_handler
-            else:
-                handler = self.module_handler[module]
-                handler.wait()
-                if name_index != 4:
-                    next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.module_handler[next_module] = weights_handler
-
-        def _post_forward_hook(module: nn.Module, input, output):
-            if module in self.FSTP_global_weights:
-                del self.FSTP_global_weights[module]
-            if module in self.module_handler:
-                del self.module_handler[module]
-
-        def _pre_backward_hook(module: nn.Module, grad_output):
-            block_index = self.module_block[module]
-            name_index = self.module_name_index[module]
-            if name_index == 4:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.module_handler[next_module] = weights_handler
-            else:
-                handler = self.module_handler[module]
-                handler.wait()
-                if name_index != 0:
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.module_handler[next_module] = weights_handler
-
-        def _post_backward_hook(module, grad_input, grad_output):
-            del self.FSTP_global_weights[module]
-
-        for module in self.FSTP_modules:
-            # import pdb; pdb.set_trace()
-            module.register_forward_pre_hook(_pre_forward_hook)
-            module.register_forward_hook(_post_forward_hook)
-            # module.register_backward_pre_hook(_pre_backward_hook)
-            # module.register_backward_hook(_post_backward_hook)
-            module.register_full_backward_pre_hook(_pre_backward_hook)
-            module.register_full_backward_hook(_post_backward_hook)
+def get_mlp_cls(sp_mode: str):
+    if sp_mode in ["none", "flash-attn"]:
+        mlp_cls = FeedForward
+    elif sp_mode == "megatron":
+        mlp_cls = MegatronFeedForward
+    else:
+        mlp_cls = FSTPFeedForward
+    return mlp_cls
+
+def get_linear_cls(sp_mode: str, parallel_mode: str):
+    if parallel_mode == "column":
+        if sp_mode in ["none", "flash-attn"]:
+            cls = ColumnParallelLinearTorch
+        elif sp_mode == "megatron":
+            cls = MegatronColumnParallelLinearTorch
+        else:
+            cls = FSTPLinear
+    elif parallel_mode == 'row':
+        if sp_mode in ["none", "flash-attn"]:
+            cls = RowParallelLinearTorch
+        elif sp_mode == "megatron":
+            cls = MegatronRowParallelLinearTorch
+        else:
+            cls = FSTPLinear
+    return cls

 class CoarseGrainedFSTPAllGatherSyncHandler:
     """
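A short usage sketch for the two selector helpers added above (it mirrors how the model code later in this diff calls them; `sp_mode` is assumed to come from the parallel config):

    sp_mode = gpc.config.parallel["tensor"]["sp"]
    mlp_cls = get_mlp_cls(sp_mode)                 # FeedForward, MegatronFeedForward or FSTPFeedForward
    Wqkv_cls = get_linear_cls(sp_mode, "column")   # column-parallel linear for the attention projection
    out_proj_cls = get_linear_cls(sp_mode, "row")  # row-parallel linear for the output projection
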
@@ -468,7 +441,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                         sub_modules = list(sub.children())
                         if len(sub_modules) > 0:
                             for name, child in sub.named_children():
-                                # print(f"name: {name}", flush=True)
                                 if name == "out_proj":
                                     self.FSTP_outs.append(child)
                                     self.module_to_index[child] = idx

@@ -15,9 +15,12 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
+    MegatronFeedForward,
     FSTPFeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
+    MegatronScaleColumnParallelLinear,
+    get_mlp_cls,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import (
@@ -77,8 +80,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = "origin_tp",
-        block_idx: int = 0,
+        sp_mode: str = "none",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -103,8 +105,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
-            tp_mode=tp_mode,
-            block_idx=block_idx,
+            sp_mode=sp_mode,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -116,7 +117,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+            mlp_cls = get_mlp_cls(sp_mode)
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -299,12 +300,16 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()

         checkpoint_layer_num = int(num_layers * checkpoint)
-        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+        self.sp_mode = gpc.config.parallel["tensor"]["sp"]
+        if self.sp_mode == "none":
+            gpc.config.parallel.sequence_parallel = False
+        else:
+            gpc.config.parallel.sequence_parallel = True

         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = ScaleColumnParallelLinear
+            head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -345,8 +350,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode=self.tp_mode,
-                    block_idx=lid,
+                    sp_mode=self.sp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -393,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
         # The indexes are used to indicate the actual position IDs of each token in the packed input.
         indexes = indexes[0]
         # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-        if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+        if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
             indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

@@ -42,6 +42,9 @@ from internlm.model.linear import (
     ColumnParallelLinearTorch,
     FSTPLinear,
     RowParallelLinearTorch,
+    MegatronColumnParallelLinearTorch,
+    MegatronRowParallelLinearTorch,
+    get_linear_cls,
 )

@@ -175,8 +178,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = "origin_tp",
-        block_idx: int = 0,
+        sp_mode: str = "none",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -204,7 +206,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        Wqkv_cls = get_linear_cls(sp_mode, "column")
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -220,12 +222,12 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == "fstp":
+        if sp_mode == "intern":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        out_proj_cls = get_linear_cls(sp_mode, 'row')
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,

@@ -164,7 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,

 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
-    "tp fused dense function"
+    "FusedDenseFunc for tensor parallel in flash-attn implementation."

     @staticmethod
     @custom_fwd
@@ -255,9 +255,96 @@ class FusedDenseFunc(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None


+class MegatronFusedDenseFunc(torch.autograd.Function):
+    '''
+    FusedDenseFunc for tensor parallel in megatron implementation.
+    The difference between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
+    so that the all-gather in backward is omitted.
+    '''
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
+        """
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
+        """
+        ctx.compute_weight_gradient = weight.requires_grad
+        ctx.return_residual = return_residual
+        ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
+        x = x.contiguous()
+        if process_group is not None and sequence_parallel:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
+        weight = weight.contiguous()
+        if process_group is not None and sequence_parallel:
+            handle_x.wait()
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
+        batch_dim = batch_shape.numel()
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
+        output = F.linear(total_x, weight, bias)
+        if ctx.compute_weight_gradient:
+            ctx.save_for_backward(total_x, weight)
+        else:
+            ctx.save_for_backward(weight)
+        return output if not return_residual else (output, x)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+
+        if ctx.compute_weight_gradient:
+            total_x, weight = ctx.saved_tensors
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+
 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
-    """A custom PyTorch module extending FusedDenseFunc."""
+    '''FusedDenseFunc in flash implementation for supporting torch.float32'''

     @staticmethod
     @custom_bwd
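In plainer terms (an editorial gloss on the docstring above, not text from the commit): the flash-attn style FusedDenseFunc saves the un-gathered shard `x` and re-all-gathers it in backward to rebuild `total_x` for the weight gradient, while MegatronFusedDenseFunc saves the already gathered `total_x`, trading extra activation memory for skipping that backward all-gather.

    # contrast of what each autograd function stashes for backward (sketch)
    ctx.save_for_backward(x, weight)        # flash-attn style: small activation, all-gather again in backward
    ctx.save_for_backward(total_x, weight)  # megatron style (above): bigger activation, no backward all-gather
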
@@ -307,17 +394,61 @@ class FusedDenseFuncTorch(FusedDenseFunc):
             handle_grad_input.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None

+class MegatronFusedDenseFuncTorch(FusedDenseFunc):
+    '''FusedDenseFunc in megatron implementation for supporting torch.float32'''
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+        gather_dim = ctx.gather_dim
+        if ctx.compute_weight_gradient:
+            total_x, weight = ctx.saved_tensors
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            # we remove the cuda independence, which is different from flash_attn.
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+
 class FSTPFusedDenseFunc(torch.autograd.Function):
-    "FSTP fused dense function"
+    "FusedDenseFunc for FSTP, which is optimized based on flash implementation."

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
-        ctx.all_gather_handler = all_gather_handler
+        ctx.overlap_handler = overlap_handler
         ctx.module = module
         ctx.block_index = block_index
         ctx.module_name = module_name
@@ -329,13 +460,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
-            if all_gather_handler is not None:  # and module in all_gather_handler.FSTP_global_weights:
-                # total_weight = all_gather_handler.FSTP_global_weights[module]
-                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+            if overlap_handler is not None:
+                total_weight = gpc.config.block_memory[block_index % 2][module_name]
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()

             # TODO memory pool for bias
             if bias is not None:
                 total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
                 handle_bias.wait()
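A note on `gpc.config.block_memory[block_index % 2][module_name]` above (an interpretation of the indexing, not a statement from the commit): the overlap path appears to keep two pre-allocated sets of gathered-weight buffers and alternate between them for even- and odd-numbered transformer blocks, so one block can reuse memory that an earlier block has already released.

    # hypothetical shape of the pool implied by the lookup above; module names match the
    # handler's module_name list ("Wqkv", "out_proj", "w1", "w2", "w3")
    # gpc.config.block_memory = {
    #     0: {"Wqkv": buf, "out_proj": buf, "w1": buf, "w2": buf, "w3": buf},  # even blocks
    #     1: {"Wqkv": buf, "out_proj": buf, "w1": buf, "w2": buf, "w3": buf},  # odd blocks
    # }
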
@@ -356,6 +486,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
             raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
+        # release memory
         del total_weight
         del total_bias
         if ctx.compute_weight_gradient:
@@ -372,8 +503,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
-        all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
+        overlap_handler = ctx.overlap_handler
        block_index = ctx.block_index
        module_name = ctx.module_name

@@ -389,51 +519,35 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            total_weight = gpc.config.block_memory[block_index % 2][module_name]
-            # # do all-gather for weight before backward
-            # if module in all_gather_handler.FSTP_global_weights:
-            #     total_weight = all_gather_handler.FSTP_global_weights[module]
-            # else:
-            #     total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            #     handle_weight.wait()
+            if overlap_handler is not None:
+                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
         else:
             total_weight = weight

         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient

             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if gpc.config.fstp_handler is not None:
-                    # grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True)
-                    # assert hasattr(weight, "_fstp_all_reduce_str")
-                    # all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async)
-                    # grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
-                    # if grad_bias is not None:
-                    #     grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True)
-                    #     assert hasattr(bias, "_fstp_all_reduce_str")
-                    #     all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
-                    #     grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
                     assert hasattr(weight, "_fstp_reduce_scatter_str")
-                    all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                    grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
+                    overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
+                    grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
                     if grad_bias is not None:
                         grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                         assert hasattr(bias, "_fstp_reduce_scatter_str")
-                        all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                        grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
+                        overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
+                        grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
                 else:
                     grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                     if grad_bias is not None:
                         grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                # grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
-                # if grad_bias is not None:
-                #     grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
             else:
                 grad_weight = None
                 grad_bias = grad_output if ctx.needs_input_grad[2] else None
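How the deferred reduce-scatter above ties back to the optimizer changes later in this diff (my reading of the flow, hedged): when the overlap handler is active, the weight gradient is reduce-scattered asynchronously, a zero placeholder of the sharded shape is handed back to autograd via `get_zero_by_shape`, and the real (handle, gradient) pair is parked in `overlap_handler.reduce_scatter_handlers` keyed by `_fstp_reduce_scatter_str`; HybridZeroOptimizer later drains that dict. A plausible consumer-side sequence, assuming accumulation similar to the commented-out all-reduce path:

    key = getattr(param, "_fstp_reduce_scatter_str")
    comm_handle, sharded_grad = fstp_handler.reduce_scatter_handlers[key]
    comm_handle.wait()                     # communication finished
    param.grad.add_(sharded_grad)          # fold the real gradient into the zero placeholder
    del fstp_handler.reduce_scatter_handlers[key]
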
@@ -449,7 +563,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         del total_weight

         if ctx.needs_input_grad[1]:
-            if world_size > 1 and gpc.config.fstp_handler is None:
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
@@ -473,6 +587,22 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)

+def megatron_fused_dense_func_torch(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True,
+    gather_dim: int = 0,
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
+        return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
+    else:
+        return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)

 def fstp_fused_dense_func(
     x: Tensor,

@@ -40,11 +40,6 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)

-def print_memory(msg):
-    print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
-    print("===========================================")
-

 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
@@ -70,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

-        if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
             self._fstp_handler = gpc.config.fstp_handler

         # Zero related args
@@ -358,20 +353,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
-                    # if not hasattr(_param, "_fstp_all_reduce_str"):
-                    #     continue
-
-                    # key = getattr(_param, "_fstp_all_reduce_str")
-                    # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                    # comm_handle.wait()
-                    # with torch.no_grad():
-                    #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                    #     _param.grad.add_(_grad)
-                    # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    # del self._fstp_handler.all_reduce_handlers[key]
-                    # self._fstp_handler.all_reduce_handlers[key] = None
-                    # assert key in self._fstp_handler.all_reduce_handlers

                 bucket.reset_by_rank(rank)

@@ -401,21 +383,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
-
-                    # if not hasattr(_param, "_fstp_all_reduce_str"):
-                    #     continue
-
-                    # key = getattr(_param, "_fstp_all_reduce_str")
-                    # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                    # comm_handle.wait()
-                    # with torch.no_grad():
-                    #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                    #     _param.grad.add_(_grad)
-                    # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    # del self._fstp_handler.all_reduce_handlers[key]
-                    # self._fstp_handler.all_reduce_handlers[key] = None
-                    # assert key in self._fstp_handler.all_reduce_handlers

             current_bucket.reset_by_rank(reduce_rank)

@@ -634,7 +601,6 @@ class HybridZeroOptimizer(BaseOptimizer):

         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
-        print_memory("No 1")
         if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
@@ -659,7 +625,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket.empty()
         self._bucket_in_progress = []
         self._param_store.clear_grads_of_previous_reduced_params()
-        print_memory("No 2")
         # compute norm for gradients in the last bucket
         total_norms = {}
         for group_id in range(self.num_param_groups):
@@ -681,19 +646,11 @@ class HybridZeroOptimizer(BaseOptimizer):
                 scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
                 dist.all_reduce(scaled_norm_tensor, group=pg)
                 total_norms[group_name] = scaled_norm_tensor.item()
-        print_memory("No 3")
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

-        print_memory("No 4")
-
-        try:
-            res = self._step(closure=closure, norms=total_norms)
-        except torch.cuda.OutOfMemoryError as e:
-            print(e, flush=True)
-            print(torch.cuda.memory_summary(), flush=True)
-            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+        res = self._step(closure=closure, norms=total_norms)

         return res

@@ -740,7 +697,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, norms
-        print_memory("No 5")
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -781,7 +737,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-        print_memory("No 6")
         # unscale and clip grads
         # get the global norm
         global_norm_groups = {}
@@ -804,12 +759,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
-            print_memory("No 7")
             self.optim.step()
-            print_memory("No 8")
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            print_memory("No 9")
             # update fp16 partition updated by the current rank
             for group_id in range(len(self._fp16_param_groups)):
                 if self.param_group_has_params[group_id]:
@@ -818,7 +770,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
-        print_memory("No 10")
         torch.cuda.synchronize()
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()
@@ -829,7 +780,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
-        print_memory("No 11")
         return True, global_norm_groups

     def broadcast_params(self):

@@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     CoarseGrainedFSTPAllGatherSyncHandler,
     FeedForward,
-    FSTPAllGatherSyncHandler,
     RewardModelLinear,
     ScaleColumnParallelLinear,
 )
@@ -111,7 +110,7 @@ def initialize_model():

     gpc.config.fstp_handler = None

-    if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
         handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()

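Putting the pieces of this commit together (an overview sketch, not text from the diff): the overlap handler is created once in initialize_model() when sp is "intern" and intern_overlap is enabled, presumably stored on gpc.config.fstp_handler (the assignment itself is outside the visible hunk), and then consumed both by FSTPLinear.forward, which passes it into fstp_fused_dense_func, and by HybridZeroOptimizer, which drains reduce_scatter_handlers.

    # sketch of the wiring established above
    handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
    handler._register_sync_parameters_hook()
    gpc.config.fstp_handler = handler  # assumed: later read by FSTPLinear.forward and HybridZeroOptimizer
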
train.py
@@ -195,7 +195,7 @@ def main(args):
     # start iterating the train data and begin training
     for batch_count in range(train_state.batch_count, total_steps):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
-        torch.cuda.memory._record_memory_history()
+        # torch.cuda.memory._record_memory_history()
         start_time = time.time()
         timer("one-batch").start()

@@ -300,7 +300,7 @@ def main(args):
         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
             gpc.config.fstp_handler.reduce_scatter_memory = {}
-        torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()

         ckpt_manager.wait_async_upload_finish()
