mirror of https://github.com/InternLM/InternLM
refactor linear
parent ed7232777a
commit dcd89ed304
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
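For orientation: the refactor renames the tensor-parallel options, so `mode="fstp"`/`overlap` become `sp`/`intern_overlap`, where `sp` selects the sequence-parallel implementation. A minimal sketch of how the new fields are read (the valid `sp` values are the ones used later in this diff; the plain-dict access below is only illustrative):

    parallel = dict(
        zero1=dict(size=-1, fsdp=False),
        tensor=dict(size=8, sp="intern", intern_overlap=True),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
    )

    sp_mode = parallel["tensor"]["sp"]
    assert sp_mode in ("none", "flash-attn", "megatron", "intern")
    # Communication/computation overlap is only meaningful for the "intern" mode.
    overlap_enabled = sp_mode == "intern" and parallel["tensor"]["intern_overlap"]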
@@ -19,25 +19,26 @@ from internlm.model.utils import (
     all_gather_raw_memory_pool,
     fstp_fused_dense_func,
     fused_dense_func_torch,
+    megatron_fused_dense_func_torch,
 )


-class ScaleColumnParallelLinear(nn.Linear):
+class BaseScaleColumnParallelLinear(nn.Linear):
     """
-    ScaleColumnParallelLinear.
+    Base class for ScaleColumnParallelLinear.

     Args:
         in_features (int): size of each input sample
         out_features (int): size of each output sample
         process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
         bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                      in the config.
         sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
                                   we do an all_gather of x before doing the matmul.
                                   If not, then the input is already gathered.
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
         weight_scale (int): For training stability. 1 by default.
     """

     def __init__(
@@ -57,6 +58,10 @@ class ScaleColumnParallelLinear(nn.Linear):
         self.process_group = process_group
         self.weight_scale = weight_scale

+class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
+    """
+    ScaleColumnParallelLinear in flash implementation.
+    """
     def forward(self, input, gather_dim=0):  # pylint: disable=W0622
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
@@ -74,6 +79,27 @@ class ScaleColumnParallelLinear(nn.Linear):
             gather_dim=gather_dim,
         )

+class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
+    """
+    ScaleColumnParallelLinear in megatron implementation.
+    """
+
+    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        if self.weight_scale != 1:
+            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
+        else:
+            weight = self.weight
+        return megatron_fused_dense_func_torch(
+            input,
+            weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            gather_dim=gather_dim,
+        )

 class RewardModelLinear(ScaleColumnParallelLinear):
     """
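The `weight_scale` branch above leaves the forward value of the weight unchanged while scaling its gradient by `weight_scale`: only the non-detached term contributes to backpropagation. A standalone check of that identity (illustrative only, not part of the commit):

    import torch

    w = torch.randn(4, 4, requires_grad=True)
    s = 0.5
    w_scaled = w * s + (1 - s) * w.detach()   # numerically equal to w, gradient scaled by s
    w_scaled.sum().backward()
    assert torch.allclose(w.grad, torch.full_like(w, s))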
@@ -129,7 +155,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
         # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
         # we do an all_gather of x before doing the matmul.
         # If not, then the input is already gathered.
-
         return fused_dense_func_torch(
             x,
             self.weight,
@@ -139,6 +164,19 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
             gather_dim=gather_dim,
         )

+class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
+    def forward(self, x, gather_dim=0):
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return megatron_fused_dense_func_torch(
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+            gather_dim=gather_dim,
+        )

 class RowParallelLinearTorch(RowParallelLinear):
     def forward(self, x):
@@ -150,10 +188,20 @@ class RowParallelLinearTorch(RowParallelLinear):
         reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
         return reduce_fn(out, self.process_group)

+
+class MegatronRowParallelLinearTorch(RowParallelLinear):
+    def forward(self, x):
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
+        a reduce_scatter of the result.
+        """
+        out = megatron_fused_dense_func_torch(x, self.weight, self.bias)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)

-class FeedForward(nn.Module):
+class BaseFeedForward(nn.Module):
     """
-    FeedForward.
+    Base FeedForward in flash implementation.

     Args:
         in_features (int): size of each input sample
@@ -177,13 +225,13 @@ class FeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
-        block_idx: int = 0,
+        colum_cls = None,
+        row_cls = None,
     ):
         super().__init__()

         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = ColumnParallelLinearTorch(
+        self.w1 = colum_cls(
             in_features,
             hidden_features,
             process_group,
@@ -192,7 +240,7 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = ColumnParallelLinearTorch(
+        self.w2 = colum_cls(
             in_features,
             hidden_features,
             process_group,
@@ -201,7 +249,7 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w3 = RowParallelLinearTorch(
+        self.w3 = row_cls(
             hidden_features,
             out_features,
             process_group,
@@ -217,21 +265,9 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out

-class FSTPLinear(ColumnParallelLinear):
-    def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
-        name_index = gpc.config.fstp_handler.module_name_index[self]
-        name = gpc.config.fstp_handler.module_name[name_index]
-        return fstp_fused_dense_func(
-            x, self.weight, self.bias, process_group=self.process_group,
-            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
-        )
-
-
-class FSTPFeedForward(nn.Module):
+class FeedForward(BaseFeedForward):
     """
-    FeedForward.
+    FeedForward in flash implementation.

     Args:
         in_features (int): size of each input sample
@@ -255,169 +291,106 @@ class FSTPFeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
-        block_idx: int = 0,
     ):
-        super().__init__()
-
-        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
-
-        self.w1 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w2 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w3 = FSTPLinear(
-            hidden_features,
-            out_features,
-            process_group,
-            bias=bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)
+
+
+class MegatronFeedForward(BaseFeedForward):
+    """
+    FeedForward in megatron implementation.
+
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of hidden state of FFN
+        out_features (int): size of each output sample
+        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
+        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        process_group: Optional[torch.distributed.ProcessGroup] = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ):
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)
+
+
+class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        w1_o = self.w1(x)
-        w2_o = self.w2(x)
-        out = self.w3(F.silu(w1_o) * w2_o)
-        return out
-
-
-class FSTPAllGatherSyncHandler:
+        block_index = gpc.config.fstp_handler.module_to_index[self]
+        name_index = gpc.config.fstp_handler.module_name_index[self]
+        name = gpc.config.fstp_handler.module_name[name_index]
+        return fstp_fused_dense_func(
+            x, self.weight, self.bias, process_group=self.process_group,
+            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
+        )
+
+
+class FSTPFeedForward(BaseFeedForward):
     """
-    All-gather handler for overlapping the all-gather in adjcent FSTP linear.
+    FeedForward in FSTP.
+
+    Args:
+        in_features (int): size of each input sample
+        hidden_features (int): size of hidden state of FFN
+        out_features (int): size of each output sample
+        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
+        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
+                     in the config.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
     """

-    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
-        # import pdb; pdb.set_trace()
-        self.process_group = process_group
-        self.FSTP_modules = []
-        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
-        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
-        self.module_handler = dict()  # key: FSTP module; value: all-gather handler
-        self.module_block = dict()  # key: FSTP module; value: transformer block index
-        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
-        self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
-
-        self.reduce_scatter_handlers = {}
-        self.all_reduce_handlers = {}
-
-        # just want to share same for loop for ModuleList and Module
-        if not isinstance(model, nn.ModuleList):
-            model = [model]
-
-        for _chunk in model:
-            if isinstance(_chunk, NaiveAMPModel):
-                _chunk = _chunk.model
-
-            for _chunk_name, children in _chunk.named_children():
-                if isinstance(children, nn.ModuleList):
-                    for idx, block in enumerate(children):
-                        index = 0
-                        self.block_module[idx] = {}
-                        for _sub_name, sub in block.named_children():
-                            sub_modules = list(sub.children())
-                            if len(sub_modules) > 0:
-                                for name, child in sub.named_children():
-                                    if isinstance(child, FSTPLinear):
-
-                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
-                                        if child.bias is not None:
-                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
-
-                                        self.FSTP_modules.append(child)
-                                        self.module_block[child] = idx
-                                        self.block_module[idx][index] = child
-                                        self.module_name_index[child] = index
-                                        index = index + 1
-                            else:
-                                continue
-
-    def _register_sync_parameters_hook(self) -> None:
-        """
-        register pre_forward_hook and pre_backward_hook for FSTPLinear.
-        """
-
-        def _pre_forward_hook(module: nn.Module, inputs: Any):
-            block_index = self.module_block[module]
-            name_index = self.module_name_index[module]
-            if name_index == 0:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.module_handler[next_module] = weights_handler
-            else:
-                handler = self.module_handler[module]
-                handler.wait()
-                if name_index != 4:
-                    next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.module_handler[next_module] = weights_handler
-
-        def _post_forward_hook(module: nn.Module, input, output):
-            if module in self.FSTP_global_weights:
-                del self.FSTP_global_weights[module]
-            if module in self.module_handler:
-                del self.module_handler[module]
-
-        def _pre_backward_hook(module: nn.Module, grad_output):
-            block_index = self.module_block[module]
-            name_index = self.module_name_index[module]
-            if name_index == 4:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.module_handler[next_module] = weights_handler
-            else:
-                handler = self.module_handler[module]
-                handler.wait()
-                if name_index != 0:
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.module_handler[next_module] = weights_handler
-
-        def _post_backward_hook(module, grad_input, grad_output):
-            del self.FSTP_global_weights[module]
-
-        for module in self.FSTP_modules:
-            # import pdb; pdb.set_trace()
-            module.register_forward_pre_hook(_pre_forward_hook)
-            module.register_forward_hook(_post_forward_hook)
-            # module.register_backward_pre_hook(_pre_backward_hook)
-            # module.register_backward_hook(_post_backward_hook)
-            module.register_full_backward_pre_hook(_pre_backward_hook)
-            module.register_full_backward_hook(_post_backward_hook)
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int,
+        out_features: int = None,
+        process_group: Optional[torch.distributed.ProcessGroup] = None,
+        bias: bool = True,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ):
+        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
+                         dtype, multiple_of, FSTPLinear, FSTPLinear)
+
+
+def get_mlp_cls(sp_mode: str):
+    if sp_mode in ["none", "flash-attn"]:
+        mlp_cls = FeedForward
+    elif sp_mode == "megatron":
+        mlp_cls = MegatronFeedForward
+    else:
+        mlp_cls = FSTPFeedForward
+    return mlp_cls
+
+
+def get_linear_cls(sp_mode: str, parallel_mode: str):
+    if parallel_mode == "column":
+        if sp_mode in ["none", "flash-attn"]:
+            cls = ColumnParallelLinearTorch
+        elif sp_mode == "megatron":
+            cls = MegatronColumnParallelLinearTorch
+        else:
+            cls = FSTPLinear
+    elif parallel_mode == 'row':
+        if sp_mode in ["none", "flash-attn"]:
+            cls = RowParallelLinearTorch
+        elif sp_mode == "megatron":
+            cls = MegatronRowParallelLinearTorch
+        else:
+            cls = FSTPLinear
+    return cls


 class CoarseGrainedFSTPAllGatherSyncHandler:
     """
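The new `get_mlp_cls` and `get_linear_cls` helpers centralize the mapping from the `sp` mode string to the concrete parallel-linear implementations, so call sites no longer hard-code `tp_mode == "origin_tp"` ternaries. A usage sketch matching how `MHA` and `PackedFlashBaseLayer1D` use them later in this commit:

    sp_mode = gpc.config.parallel["tensor"]["sp"]   # "none" | "flash-attn" | "megatron" | "intern"
    mlp_cls = get_mlp_cls(sp_mode)                  # FeedForward / MegatronFeedForward / FSTPFeedForward
    Wqkv_cls = get_linear_cls(sp_mode, "column")    # column-parallel linear for the QKV projection
    out_proj_cls = get_linear_cls(sp_mode, "row")   # row-parallel linear for the output projection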
@@ -468,7 +441,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                                 sub_modules = list(sub.children())
                                 if len(sub_modules) > 0:
                                     for name, child in sub.named_children():
-                                        # print(f"name: {name}", flush=True)
                                         if name == "out_proj":
                                             self.FSTP_outs.append(child)
                                             self.module_to_index[child] = idx
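The deleted `FSTPAllGatherSyncHandler` (and the `CoarseGrainedFSTPAllGatherSyncHandler` this commit keeps) follow the same overlap idea: while linear i of a block is computing, the all-gather for the sharded weight of linear i+1 is already in flight. A stripped-down, standalone sketch of that scheduling pattern, using plain `torch.distributed` calls (the real handler also tracks blocks, memory pools, and backward-direction prefetch):

    import torch
    import torch.distributed as dist

    def register_prefetch_hooks(linears, group):
        """Prefetch the full weight of linears[i + 1] while linears[i] runs forward."""
        gathered, handles = {}, {}

        def _all_gather(w):
            world = dist.get_world_size(group)
            out = torch.empty(world * w.shape[0], *w.shape[1:], dtype=w.dtype, device=w.device)
            handle = dist.all_gather_into_tensor(out, w.contiguous(), group=group, async_op=True)
            return out, handle

        def _pre_forward(module, _inputs):
            if module in handles:                 # wait for the prefetch issued by the previous layer
                handles.pop(module).wait()
            else:                                 # first linear of the chain: gather synchronously
                gathered[module], h = _all_gather(module.weight)
                h.wait()
            i = linears.index(module)
            if i + 1 < len(linears):              # kick off the next layer's all-gather early
                nxt = linears[i + 1]
                gathered[nxt], handles[nxt] = _all_gather(nxt.weight)

        for m in linears:
            m.register_forward_pre_hook(_pre_forward)

The gathered weight itself is consumed inside the FSTP linear's custom autograd function; this sketch only shows how the communication is scheduled.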
@@ -15,9 +15,12 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
+    MegatronFeedForward,
     FSTPFeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
+    MegatronScaleColumnParallelLinear,
+    get_mlp_cls,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import (
@@ -77,8 +80,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = "origin_tp",
-        block_idx: int = 0,
+        sp_mode: str = "none",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -103,8 +105,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
-            tp_mode=tp_mode,
-            block_idx=block_idx,
+            sp_mode=sp_mode,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -116,7 +117,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+            mlp_cls = get_mlp_cls(sp_mode)
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -299,12 +300,16 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()

         checkpoint_layer_num = int(num_layers * checkpoint)
-        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+        self.sp_mode = gpc.config.parallel["tensor"]["sp"]
+        if self.sp_mode == "none":
+            gpc.config.parallel.sequence_parallel = False
+        else:
+            gpc.config.parallel.sequence_parallel = True

         if is_reward:
             head_cls = RewardModelLinear
         else:
-            head_cls = ScaleColumnParallelLinear
+            head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -345,8 +350,7 @@ class PackedFlashInternLm1D(nn.Module):
                 use_scaled_init=use_scaled_init,
                 use_swiglu=use_swiglu,
                 use_flash_attn=use_flash_attn,
-                tp_mode=self.tp_mode,
-                block_idx=lid,
+                sp_mode=self.sp_mode,
             )
             for lid in range(num_layers)
         ]
@@ -393,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
         # The indexes are used to indicate the actual position IDs of each token in the packed input.
         indexes = indexes[0]
         # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-        if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+        if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
             indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
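When the "intern" mode runs with sequence parallelism, the packed-position `indexes` have to be split along the sequence dimension so they line up with the already-split hidden states. Conceptually, `split_forward_gather_backward` behaves like the autograd function below (an illustration of the semantics, not InternLM's actual implementation):

    import torch
    import torch.distributed as dist

    class SplitForwardGatherBackward(torch.autograd.Function):
        """Forward: keep only this rank's chunk along `dim`; backward: all-gather the grads."""

        @staticmethod
        def forward(ctx, tensor, group, dim=0):
            ctx.group, ctx.dim = group, dim
            rank, world = dist.get_rank(group), dist.get_world_size(group)
            return tensor.chunk(world, dim=dim)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad_chunk):
            world = dist.get_world_size(ctx.group)
            chunks = [torch.empty_like(grad_chunk) for _ in range(world)]
            dist.all_gather(chunks, grad_chunk.contiguous(), group=ctx.group)
            return torch.cat(chunks, dim=ctx.dim), None, None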
@@ -42,6 +42,9 @@ from internlm.model.linear import (
     ColumnParallelLinearTorch,
     FSTPLinear,
     RowParallelLinearTorch,
+    MegatronColumnParallelLinearTorch,
+    MegatronRowParallelLinearTorch,
+    get_linear_cls,
 )

@@ -175,8 +178,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = "origin_tp",
-        block_idx: int = 0,
+        sp_mode: str = "none",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -204,7 +206,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        Wqkv_cls = get_linear_cls(sp_mode, "column")
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -220,12 +222,12 @@ class MHA(nn.Module):
             self.inner_cross_attn = inner_cross_attn_cls(
                 causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
             )
-        if tp_mode == "fstp":
+        if sp_mode == "intern":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        out_proj_cls = get_linear_cls(sp_mode, 'row')
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
@@ -164,7 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,

 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
-    "tp fused dense function"
+    "FusedDenseFunc for tensor parallel in flash-attn implementation."

     @staticmethod
     @custom_fwd
@@ -255,9 +255,96 @@ class FusedDenseFunc(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None


+class MegatronFusedDenseFunc(torch.autograd.Function):
+    '''
+    FusedDenseFunc for tensor parallel in megatron implementation.
+    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
+    so that the all-gather in backward is ommited.
+    '''
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
+        """
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
+        """
+        ctx.compute_weight_gradient = weight.requires_grad
+        ctx.return_residual = return_residual
+        ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
+        x = x.contiguous()
+        if process_group is not None and sequence_parallel:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
+        weight = weight.contiguous()
+        if process_group is not None and sequence_parallel:
+            handle_x.wait()
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
+        batch_dim = batch_shape.numel()
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
+        output = F.linear(total_x, weight, bias)
+        if ctx.compute_weight_gradient:
+            ctx.save_for_backward(total_x, weight)
+        else:
+            ctx.save_for_backward(weight)
+        return output if not return_residual else (output, x)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+
+        if ctx.compute_weight_gradient:
+            total_x, weight = ctx.saved_tensors
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+
+
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
-    """A custom PyTorch module extending FusedDenseFunc."""
+    '''FusedDenseFunc in flash implementation for supporting torch.float32'''

     @staticmethod
     @custom_bwd
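The practical difference between `FusedDenseFunc` (flash-attn style) and the new `MegatronFusedDenseFunc` is what gets saved for backward. The flash-attn variant saves the sharded input and re-runs the all-gather during backward, trading communication for activation memory; the Megatron variant saves the already-gathered `total_x`, so backward needs no all-gather but keeps the full-sequence activation alive. In terms of the code above:

    # flash-attn style (existing FusedDenseFunc): save the shard, gather again in backward
    ctx.save_for_backward(x, weight)        # backward: total_x, handle = all_gather_raw(x, process_group, ...)

    # megatron style (MegatronFusedDenseFunc above): save the gathered activation
    ctx.save_for_backward(total_x, weight)  # backward: reuse total_x directly, no all-gather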
@@ -307,17 +394,61 @@ class FusedDenseFuncTorch(FusedDenseFunc):
             handle_grad_input.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None

+
+class MegatronFusedDenseFuncTorch(FusedDenseFunc):
+    '''FusedDenseFunc in megatron implementation for supporting torch.float32'''
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+        gather_dim = ctx.gather_dim
+        if ctx.compute_weight_gradient:
+            total_x, weight = ctx.saved_tensors
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            # we remove the cuda independence, which is different from flash_attn.
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None, None
+

 class FSTPFusedDenseFunc(torch.autograd.Function):
-    "FSTP fused dense function"
+    "FusedDenseFunc for FSTP, which is optimized based on flash implementation."

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
-        ctx.all_gather_handler = all_gather_handler
+        ctx.overlap_handler = overlap_handler
         ctx.module = module
         ctx.block_index = block_index
         ctx.module_name = module_name
@@ -329,13 +460,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
-            if all_gather_handler is not None:  # and module in all_gather_handler.FSTP_global_weights:
-                # total_weight = all_gather_handler.FSTP_global_weights[module]
-                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+            if overlap_handler is not None:
+                total_weight = gpc.config.block_memory[block_index % 2][module_name]
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
+            # TODO memory pool for bias
             if bias is not None:
                 total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
                 handle_bias.wait()
@@ -356,6 +486,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
             raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
+        # release memory
         del total_weight
         del total_bias
         if ctx.compute_weight_gradient:
@@ -372,8 +503,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
-        all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
+        overlap_handler = ctx.overlap_handler
         block_index = ctx.block_index
         module_name = ctx.module_name

@@ -389,51 +519,35 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            total_weight = gpc.config.block_memory[block_index % 2][module_name]
-            # # do all-gather for weight before backward
-            # if module in all_gather_handler.FSTP_global_weights:
-            #     total_weight = all_gather_handler.FSTP_global_weights[module]
-            # else:
-            #     total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            #     handle_weight.wait()
+            if overlap_handler is not None:
+                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
         else:
             total_weight = weight

         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient

             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                if gpc.config.fstp_handler is not None:
-                    # grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True)
-                    # assert hasattr(weight, "_fstp_all_reduce_str")
-                    # all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async)
-                    # grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
-                    # if grad_bias is not None:
-                    #     grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True)
-                    #     assert hasattr(bias, "_fstp_all_reduce_str")
-                    #     all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
-                    #     grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-
+                if overlap_handler is not None:
                     grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
                     assert hasattr(weight, "_fstp_reduce_scatter_str")
-                    all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                    grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
+                    overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
+                    grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
                     if grad_bias is not None:
                         grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                         assert hasattr(bias, "_fstp_reduce_scatter_str")
-                        all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                        grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
+                        overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
+                        grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
                 else:
                     grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                     if grad_bias is not None:
                         grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                    # grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
-                    # if grad_bias is not None:
-                    #     grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
             else:
                 grad_weight = None
                 grad_bias = grad_output if ctx.needs_input_grad[2] else None
@@ -449,7 +563,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         del total_weight

         if ctx.needs_input_grad[1]:
-            if world_size > 1 and gpc.config.fstp_handler is None:
+            if world_size > 1 and overlap_handler is None:
                 handle_grad_weight.wait()
                 if grad_bias is not None:
                     handle_grad_bias.wait()
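In the overlap path above, the weight gradient handed back to autograd is only a zero placeholder of the reduce-scattered shape: the real gradient is produced by an asynchronous reduce-scatter whose handle is parked on `overlap_handler.reduce_scatter_handlers`, and `HybridZeroOptimizer` waits on it later (see the optimizer changes below). A condensed sketch of that hand-off, using a hypothetical helper and plain `torch.distributed` instead of the memory-pool variant:

    import torch
    import torch.distributed as dist

    def async_reduce_scatter_grad(full_grad, process_group, handlers, key):
        """Launch reduce-scatter of a gathered grad; return a zero placeholder of the sharded shape."""
        world = dist.get_world_size(process_group)
        sharded = torch.empty(full_grad.shape[0] // world, *full_grad.shape[1:],
                              dtype=full_grad.dtype, device=full_grad.device)
        handle = dist.reduce_scatter_tensor(sharded, full_grad.contiguous(),
                                            group=process_group, async_op=True)
        handlers[key] = (handle, sharded)     # the optimizer waits on this and uses `sharded`
        return torch.zeros_like(sharded)      # placeholder grad returned to autograd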
@@ -473,6 +587,22 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)

+
+def megatron_fused_dense_func_torch(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True,
+    gather_dim: int = 0,
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
+        return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
+    else:
+        return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
+

 def fstp_fused_dense_func(
     x: Tensor,
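`megatron_fused_dense_func_torch` dispatches exactly like `fused_dense_func_torch`: the fused CUDA autograd function is used only when the tensors are on the GPU and the dtype is fused-kernel friendly; otherwise the pure-torch fallback class handles the backward. The shared predicate, pulled out for clarity (illustrative helper, not part of the diff):

    import torch

    def fused_dense_eligible(x, weight, bias=None):
        """Mirror of the dispatch condition used above."""
        dtype_ok = x.dtype in (torch.float16, torch.bfloat16) or (
            x.dtype == torch.float32 and torch.is_autocast_enabled()
        )
        return x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_ok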
@@ -40,11 +40,6 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)

-def print_memory(msg):
-    print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
-    print("===========================================")
-
-
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
@@ -70,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

-        if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
             self._fstp_handler = gpc.config.fstp_handler

         # Zero related args
@@ -358,20 +353,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
-                    # if not hasattr(_param, "_fstp_all_reduce_str"):
-                    #     continue
-
-                    # key = getattr(_param, "_fstp_all_reduce_str")
-                    # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                    # comm_handle.wait()
-                    # with torch.no_grad():
-                    #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                    #     _param.grad.add_(_grad)
-                    # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    # del self._fstp_handler.all_reduce_handlers[key]
-                    # self._fstp_handler.all_reduce_handlers[key] = None
-                    # assert key in self._fstp_handler.all_reduce_handlers

                 bucket.reset_by_rank(rank)

@@ -402,21 +384,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 self._fstp_handler.reduce_scatter_handlers[key] = None
                 assert key in self._fstp_handler.reduce_scatter_handlers

-                # if not hasattr(_param, "_fstp_all_reduce_str"):
-                #     continue
-
-                # key = getattr(_param, "_fstp_all_reduce_str")
-                # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                # comm_handle.wait()
-                # with torch.no_grad():
-                #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                #     _param.grad.add_(_grad)
-                # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                # del self._fstp_handler.all_reduce_handlers[key]
-                # self._fstp_handler.all_reduce_handlers[key] = None
-                # assert key in self._fstp_handler.all_reduce_handlers
-
             current_bucket.reset_by_rank(reduce_rank)

             current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
@@ -634,7 +601,6 @@ class HybridZeroOptimizer(BaseOptimizer):

         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
-        print_memory("No 1")
         if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
@@ -659,7 +625,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket.empty()
         self._bucket_in_progress = []
         self._param_store.clear_grads_of_previous_reduced_params()
-        print_memory("No 2")
         # compute norm for gradients in the last bucket
         total_norms = {}
         for group_id in range(self.num_param_groups):
@@ -681,19 +646,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
             dist.all_reduce(scaled_norm_tensor, group=pg)
             total_norms[group_name] = scaled_norm_tensor.item()
-        print_memory("No 3")
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

-        print_memory("No 4")
-
-        try:
-            res = self._step(closure=closure, norms=total_norms)
-        except torch.cuda.OutOfMemoryError as e:
-            print(e, flush=True)
-            print(torch.cuda.memory_summary(), flush=True)
-            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
-
+        res = self._step(closure=closure, norms=total_norms)
         return res

@@ -740,7 +697,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, norms
-        print_memory("No 5")
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -781,7 +737,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-        print_memory("No 6")
         # unscale and clip grads
         # get the global norm
         global_norm_groups = {}
@@ -804,12 +759,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
-            print_memory("No 7")
             self.optim.step()
-            print_memory("No 8")
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            print_memory("No 9")
             # update fp16 partition updated by the current rank
             for group_id in range(len(self._fp16_param_groups)):
                 if self.param_group_has_params[group_id]:
@@ -818,7 +770,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                     )
                     fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                     fp16_param.data.copy_(fp32_param)
-        print_memory("No 10")
         torch.cuda.synchronize()
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()
@@ -829,7 +780,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
-        print_memory("No 11")
         return True, global_norm_groups

     def broadcast_params(self):
@@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     CoarseGrainedFSTPAllGatherSyncHandler,
     FeedForward,
-    FSTPAllGatherSyncHandler,
     RewardModelLinear,
     ScaleColumnParallelLinear,
 )
@@ -111,7 +110,7 @@ def initialize_model():

     gpc.config.fstp_handler = None

-    if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
         handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
train.py (4 changed lines)

@@ -195,7 +195,7 @@ def main(args):
     # start iterating the train data and begin training
     for batch_count in range(train_state.batch_count, total_steps):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
-        torch.cuda.memory._record_memory_history()
+        # torch.cuda.memory._record_memory_history()
         start_time = time.time()
         timer("one-batch").start()

@@ -300,7 +300,7 @@ def main(args):
         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
             gpc.config.fstp_handler.reduce_scatter_memory = {}
-        torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()

         ckpt_manager.wait_async_upload_finish()
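The commit leaves the CUDA memory debugging hooks in place but commented out. If allocator traces are needed again, the same private PyTorch helpers can be re-enabled around the training loop (available in recent PyTorch releases; snapshots are viewable at https://pytorch.org/memory_viz):

    import torch

    torch.cuda.memory._record_memory_history()                 # before the training loop
    # ... run a few training steps ...
    torch.cuda.memory._dump_snapshot("my_snapshot.pickle")     # then inspect the pickle offline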