mirror of https://github.com/hpcaitech/ColossalAI
[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardconfig (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig
* support overlap for chatglm and bloom

Branch: pull/4526/head
parent 376533a564
commit c554b7f559
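The whole feature is driven by a single new ShardConfig flag. A minimal sketch of the configuration this PR enables (the colossalai.shardformer import path is assumed from the repo layout, and torch.distributed must already be initialized because tensor parallelism queries the world size):

    from colossalai.shardformer import ShardConfig  # import path assumed

    # enable_sequence_overlap only takes effect together with sequence parallelism,
    # which in turn requires tensor parallelism (validated in ShardConfig.__post_init__, see below)
    shard_config = ShardConfig(enable_tensor_parallelism=True,
                               enable_sequence_parallelism=True,
                               enable_sequence_overlap=True)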
@@ -211,43 +211,36 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             handle.wait()

         else:
-            # create new stream for calculate the gradient
-            calculate_stream = torch.cuda.Stream()
-
-            # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

+            # do all gather in is async way
             gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-
-            # calculate gradient in calculate_stream
-            with torch.cuda.stream(calculate_stream):
-                # calculate
-                grad_input = grad_output.matmul(weight)
-                grad_output = grad_output.contiguous()
-                # Convert the tensor shapes to 2D for execution compatibility
-                if len(grad_output.shape) > 2:
-                    grad_output = grad_output.view(-1, grad_output.shape[-1])
-                grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-                # prepare data
-                input_list = [
-                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-                ]
-                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
             gather_handle.wait()

+            # do reduce-scatter in async way
             reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            with torch.cuda.stream(calculate_stream):
-                input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-                if len(input_parallel.shape) > 2:
-                    input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-                print(grad_output.shape, input_parallel.shape)
-                grad_weight = grad_output.t().matmul(input_parallel)
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = grad_output.t().matmul(input_parallel)
+            # wait until reduce-scatter finished
             reducescatter_handle.wait()

             return output, grad_weight, grad_bias, None, None, None, None
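The gist of the fix above: the hand-rolled calculate_stream is dropped and the overlap now comes from the async collective handles themselves. The all-gather (and later the reduce-scatter) is launched with async_op=True, the gradient matmuls run on the default stream while the collective proceeds in the background, and handle.wait() is only called once the communication result is actually needed; the stray debug print is removed along the way. A standalone sketch of that pattern, independent of the ColossalAI classes (tensor names and shapes here are illustrative only):

    import torch
    import torch.distributed as dist

    def overlapped_allgather_matmul(x: torch.Tensor, weight: torch.Tensor, group=None):
        """Overlap an async all-gather of x with a matmul that does not depend on it."""
        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(x) for _ in range(world_size)]

        # launch the collective asynchronously; it runs in the background
        handle = dist.all_gather(gathered, x.contiguous(), group=group, async_op=True)

        # independent compute proceeds on the default stream while the all-gather is in flight
        partial = x.matmul(weight)

        # block only when the gathered result is needed
        handle.wait()
        full_input = torch.cat(gathered, dim=0)
        return partial, full_input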
@@ -75,7 +75,7 @@ class Linear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  seq_parallel: bool = False,
                  seq_parallel_dim: int = 1,
-                 overlap: bool = False,
+                 overlap: torch.cuda.Stream = None,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
                  bias_: Optional[Parameter] = None,
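With the new annotation, the policies below still pass a plain boolean for overlap; presumably only its truthiness is consulted when choosing the overlapped backward. A hedged construction sketch — the leading constructor arguments and the import path are assumed from the class's usual API, not from this hunk:

    import torch.distributed as dist
    from colossalai.shardformer.layer import Linear1D_Col  # import path assumed

    col_linear = Linear1D_Col(in_features=512,            # leading arguments assumed
                              out_features=512,
                              process_group=dist.group.WORLD,
                              seq_parallel=True,
                              overlap=True)                # the boolean from ShardConfig.enable_sequence_overlap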
@@ -56,6 +56,7 @@ class BertPolicy(Policy):

         policy = {}
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                 "attention.self.all_head_size":
@@ -71,17 +72,26 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.self.query",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.key",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
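The kwargs dicts above are not interpreted by the policy itself; the shardformer machinery forwards them to the target module when the submodule is swapped out, which is how "seq_parallel" and "overlap" end up as Linear1D_Col arguments. Roughly, under the assumption that replacement goes through a from_native_module-style converter (names and signature assumed, not shown in this diff):

    # illustrative only: how one SubModuleReplacementDescription might be consumed
    def apply_replacement(root_module, description, process_group):
        # resolve a dotted suffix such as "attention.self.query"
        *parents, leaf = description.suffix.split(".")
        holder = root_module.get_submodule(".".join(parents)) if parents else root_module
        original = getattr(holder, leaf)
        replaced = description.target_module.from_native_module(
            original, process_group, **description.kwargs)  # "seq_parallel" and "overlap" land here
        setattr(holder, leaf, replaced)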
@@ -99,7 +109,10 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel": use_sequence_parallel},
+                    kwargs={
+                        "seq_parallel": use_sequence_parallel,
+                        "overlap": overlap
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
@@ -45,6 +45,7 @@ class BloomPolicy(Policy):
         policy = {}

         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
                 "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -55,7 +56,10 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attention.query_key_value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={'seq_parallel': use_sequence_parallel}),
+                    kwargs={
+                        'seq_parallel': use_sequence_parallel,
+                        'overlap': overlap
+                    }),
                 SubModuleReplacementDescription(
                     suffix="self_attention.dense",
                     target_module=col_nn.Linear1D_Row,
@@ -67,7 +71,10 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_h_to_4h",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={'seq_parallel': use_sequence_parallel}),
+                    kwargs={
+                        'seq_parallel': use_sequence_parallel,
+                        'overlap': overlap
+                    }),
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_4h_to_h",
                     target_module=col_nn.Linear1D_Row,
@@ -50,6 +50,7 @@ class ChatGLMPolicy(Policy):
         policy = {}

         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
                                                            sub_module_replacement=[
@@ -81,7 +82,8 @@ class ChatGLMPolicy(Policy):
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
                         'seq_parallel': use_sequence_parallel,
-                        'seq_parallel_dim': 0
+                        'seq_parallel_dim': 0,
+                        'overlap': overlap
                     }),
                 SubModuleReplacementDescription(suffix="self_attention.dense",
                                                 target_module=col_nn.Linear1D_Row,
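ChatGLM is the only policy here that overrides seq_parallel_dim, presumably because its hidden states are laid out sequence-first ([seq, batch, hidden]) rather than batch-first, so the sequence-parallel split and gather have to run along dim 0 instead of the default dim 1. A toy illustration of what that dimension controls (shapes are made up):

    import torch

    world_size = 4
    batch_first = torch.randn(2, 32, 64)   # [batch, seq, hidden] -> seq_parallel_dim = 1
    seq_first = torch.randn(32, 2, 64)     # [seq, batch, hidden] -> seq_parallel_dim = 0 (ChatGLM layout)

    # each rank holds one chunk of the sequence; an all-gather along the same dim restores it
    chunks_dim1 = torch.chunk(batch_first, world_size, dim=1)  # 4 x [2, 8, 64]
    chunks_dim0 = torch.chunk(seq_first, world_size, dim=0)    # 4 x [8, 2, 64]
    assert torch.cat(chunks_dim0, dim=0).shape == seq_first.shape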
@@ -20,6 +20,8 @@ class ShardConfig:
         enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
         enable_all_optimization (bool): Whether to turn on all optimization, default is False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -29,6 +31,7 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_sequence_parallelism: bool = False
+    enable_sequence_overlap: bool = False

     # pipeline_parallel_size: int
     # data_parallel_size: int
@@ -41,6 +44,11 @@ class ShardConfig:
         return self._tensor_parallel_size

     def __post_init__(self):
+        if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
+            raise ValueError(
+                "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
+        if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
+            raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
         if not self.enable_tensor_parallelism:
             self._tensor_parallel_size = 1
         else:
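The two new checks make the dependency chain explicit: overlap needs sequence parallelism, which in turn needs tensor parallelism. A quick illustration of the failure mode the validation now catches (the checks run before any process-group lookup, so no distributed setup is needed):

    from colossalai.shardformer import ShardConfig  # import path assumed

    try:
        ShardConfig(enable_tensor_parallelism=True, enable_sequence_overlap=True)  # overlap without sequence parallelism
    except ValueError as err:
        print(err)  # enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True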
@@ -59,3 +67,4 @@ class ShardConfig:
         self.enable_flash_attention = True
         self.enable_jit_fused = True
         self.enable_sequence_parallelism = True
+        self.enable_sequence_overlap = True
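Since enable_all_optimization now also flips enable_sequence_overlap on, a config that only sets the umbrella flag picks up the overlapped path as well; sketched (assumes a process group is initialized, because enable_tensor_parallelism defaults to True and queries the world size):

    from colossalai.shardformer import ShardConfig  # import path assumed

    cfg = ShardConfig(enable_all_optimization=True)
    assert cfg.enable_sequence_parallelism and cfg.enable_sequence_overlap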
@@ -168,7 +168,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool

 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-@parameterize('overlap', [False, True])
+@parameterize('overlap', [True])
 def run_dist_linear_test(lazy_init, seq_parallel, overlap):
     check_linear_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_1d_row(lazy_init, seq_parallel)