mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Add overlap optional for HybridParallelPlugin (#4615)
* add optional overlap for plugin
* remove fixed todo

branch: pull/4620/head
parent: a39a5c66fe
commit: 86d22581e4

@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        enable_sequence_overlap: bool = False,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
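
For context, a minimal usage sketch of the option added above, assuming the usual ColossalAI Booster workflow launched via torchrun; the parallel sizes and the model/optimizer setup are placeholders, and the flag only takes effect together with enable_sequence_parallelism:

    # Sketch: constructing the plugin with the new keyword argument.
    # tp_size/pp_size values and the model/optimizer are placeholders.
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})  # requires a torchrun-style launch

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        enable_sequence_parallelism=True,
        enable_sequence_overlap=True,  # option added by this commit
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    #     model, optimizer, criterion, dataloader, lr_scheduler)
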
@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
-            enable_sequence_parallelism=enable_sequence_parallelism)
+            enable_sequence_parallelism=enable_sequence_parallelism,
+            enable_sequence_overlap=enable_sequence_overlap)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
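
The keyword arguments above are forwarded into the shard configuration. As an illustration only (this is not ColossalAI's actual ShardConfig), the two flags could be carried and validated together like this:

    # Illustrative sketch of a shard-config carrier for the forwarded flags,
    # with a guard that overlap is only meaningful when sequence parallelism is on.
    from dataclasses import dataclass

    @dataclass
    class ShardConfigSketch:
        enable_fused_normalization: bool = False
        enable_flash_attention: bool = False
        enable_jit_fused: bool = False
        enable_sequence_parallelism: bool = False
        enable_sequence_overlap: bool = False

        def __post_init__(self):
            if self.enable_sequence_overlap and not self.enable_sequence_parallelism:
                raise ValueError("enable_sequence_overlap requires enable_sequence_parallelism=True")

    config = ShardConfigSketch(enable_sequence_parallelism=True, enable_sequence_overlap=True)
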
@@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         overlap = ctx.overlap

         if not overlap:
-            # TODO: overlap SP input with gradient computation
             input_parallel = _gather(input_, dim, process_group)

             total_input = input_parallel
@@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])

-            # TODO: overlap SP input with gradient computation
            if ctx.async_grad_reduce_scatter:
                # Asynchronous reduce-scatter
                input_list = [
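
The deleted TODOs referred to overlapping the sequence-parallel gather of the input with the backward computation, which the overlap flag now controls. A standalone sketch of that idea using plain torch.distributed collectives, purely illustrative and not the ColossalAI implementation (the sequence dimension is assumed to be dim 0, and the reduce-scatter of the input gradient back to the local shard is omitted):

    # Illustrative sketch of gather/compute overlap in a linear layer's backward.
    import torch
    import torch.distributed as dist

    def backward_with_overlap(grad_output, input_, weight, process_group):
        world_size = dist.get_world_size(process_group)

        # Kick off the all-gather of the sequence-sharded input asynchronously.
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        handle = dist.all_gather(gather_list, input_, group=process_group, async_op=True)

        # Overlap the communication with work that does not need the full input:
        # the gradient w.r.t. the gathered input only needs grad_output and weight.
        grad_input = grad_output.matmul(weight)

        # Wait for the gather before computing the weight gradient,
        # which does need the full (gathered) input.
        handle.wait()
        total_input = torch.cat(gather_list, dim=0)

        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        total_input_2d = total_input.reshape(-1, total_input.shape[-1])
        grad_weight = grad_output_2d.t().matmul(total_input_2d)
        return grad_input, grad_weight
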