mirror of https://github.com/hpcaitech/ColossalAI
[plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* [zerobubble]Support ZeroBubble Pipeline (#6034)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
* [feat] add dw test;
* [fix] fix weight not close;
* [update] update text;
* [feat] add test run_fwd_bwd automatic scheduling;
* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
* [feat] add test for p & p grad;
* [feat] add comments for ZBV func;
* [fix] rm useless assign and comments;
* [fix] fix ci test; add pytest;
* [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
* [feat] add apply v_schedule graph; p & p.grad assert err exist;
* [fix] update
* [feat] fix ci; add assert;
* [feat] fix poc format
* [feat] fix func name & ci; add comments;
* [fix] fix poc test; add comments in poc;
* [feat] add optim backward_b_by_grad
* [feat] fix optimizer bwd b & w; support return accum loss & output
* [feat] add fwd_bwd_step, run_fwd_only;
* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;
* [fix] fix communication_map;
* [feat] update test; rm comments;
* [fix] rm zbv in hybridplugin
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix detach output & release output;
* [fix] rm requir_grad for output;
* [fix] fix requir grad position and detach position and input&output local buffer append position;
* [feat] add memory assertation;
* [fix] fix mem check;
* [fix] mem assertation'
* [fix] fix mem assertation
* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;
* [fix] fix model zoo import;
* [fix] fix redundant detach & clone; add buffer assertation in the end;
* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;
* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
* [fix] add testcase with microbatch 4;
* hybrid support zbv
* fix fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* fix
* fix
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

pull/6075/head
parent b804fdc297
commit af6aa9ed06
@@ -140,7 +140,7 @@ jobs:

       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v .
           pip install --no-cache-dir -r requirements/requirements-test.txt

       - name: Store Colossal-AI Cache

@@ -55,7 +55,7 @@ jobs:
         if: steps.check-avai.outputs.avai == 'true'
         run: |
           [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
-          BUILD_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v .
           cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
           pip install --no-cache-dir -r requirements/requirements-test.txt

@@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
     dtype: torch.dtype

     @abstractmethod
-    def pre_backward(self, loss: Tensor) -> Tensor:
+    def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
         """Called before backward.

         Args:
@@ -85,13 +85,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 master_params.append(master_p)
             group["params"] = master_params

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
-        tensor.backward(grad)
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )

     def zero_grad(self, *args, **kwargs):
         for p in self.working_to_master_map.keys():
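The switch from `tensor.backward(grad)` to `torch.autograd.backward(..., inputs=..., retain_graph=...)` is what lets a zero bubble schedule split the backward pass into a B step (gradients w.r.t. the stage input) and a deferred W step (gradients w.r.t. the weights). A minimal sketch of that property with placeholder tensors, in plain PyTorch rather than ColossalAI code:

import torch

# Stand-ins: x plays the role of the activation received from the previous stage,
# w a parameter of this stage, grad_y the gradient arriving from the next stage.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
y = x @ w
grad_y = torch.ones_like(y)

# B step: only x.grad is populated; retain_graph keeps the graph alive for the W step.
torch.autograd.backward(tensors=y, grad_tensors=grad_y, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# W step: now accumulate the parameter gradient and let the graph be freed.
torch.autograd.backward(tensors=y, grad_tensors=grad_y, inputs=[w], retain_graph=False)
assert w.grad is not None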
@@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
             growth_interval=growth_interval,
         )

-    def backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
         scaled_loss = self.scale_loss(loss)
-        scaled_loss.backward(*args, **kwargs)
+        scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

     def step(self, *args, **kwargs) -> Optional[float]:
         out = self.scaler.step(self.optim, *args, **kwargs)
@@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -288,7 +288,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         super().__init__(optim)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -306,7 +306,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -332,7 +332,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -512,7 +512,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             max_norm=max_norm,
         )

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -529,7 +529,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
            None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -538,7 +538,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -554,7 +554,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
            None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -768,7 +768,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         else:
             return

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         """
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.

@@ -793,7 +793,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -809,7 +809,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            None
         """
         # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
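All of these optimizer wrappers forward the new `inputs`/`retain_graph` arguments down to `torch.autograd.backward`, so a pipeline schedule can walk the same graph twice through `backward_by_grad`. The helper below is only a schematic of that call pattern under the new signature, not the actual ZeroBubbleVPipeScheduler logic (real schedules also handle loss scaling and buffering); the function name and arguments are illustrative.

from typing import Iterable

import torch

from colossalai.interface import OptimizerWrapper


def split_backward(
    optim: OptimizerWrapper,
    output: torch.Tensor,       # this stage's forward output
    grad_output: torch.Tensor,  # gradient received from the next stage
    stage_input: torch.Tensor,  # detached, requires_grad input of this stage
    params: Iterable[torch.Tensor],
) -> torch.Tensor:
    """Schematic B/W split; returns the gradient to send to the previous stage."""
    # B step: only the stage input receives a gradient; keep the graph alive.
    optim.backward_by_grad(output, grad_output, inputs=stage_input, retain_graph=True)
    input_grad = stage_input.grad
    # W step (can be deferred): accumulate parameter gradients, then free the graph.
    optim.backward_by_grad(output, grad_output, inputs=list(params), retain_graph=False)
    return input_grad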
@@ -1013,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,

@@ -1029,6 +1030,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

+        assert (
+            not pp_style == "zbv" or scheduler_nodes is not None
+        ), f"scheduler_nodes must not be None when using zero bubble pipeline."
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1088,29 +1092,39 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

         self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert (
                 self.zero_stage <= 1
             ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+            if pp_style == "zbv":
+                self.logger.warning(
+                    """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+                )
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=(pp_style == "interleaved"),
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+                use_zbv=(pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )

             if pp_style == "interleaved":
                 assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                     stage_manager=self.stage_manager,
                     num_model_chunks=num_model_chunks,
                     num_microbatch=num_microbatches,
@@ -1119,12 +1133,20 @@ class HybridParallelPlugin(PipelinePluginBase):
                     overlap_p2p=overlap_p2p,
                 )
             elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                     stage_manager=self.stage_manager,
                     num_microbatches=num_microbatches,
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
+            elif pp_style == "zbv":
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    stage_manager=self.stage_manager,
+                    schedule=scheduler_nodes,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
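Putting the new pieces together, the zbv style is configured by building a V-schedule with PipelineGraph and handing the resulting nodes to the plugin. The sketch below mirrors the test configuration added at the end of this PR; the cost and memory figures are the placeholder values used there, not tuned numbers, and an already-initialized distributed environment (e.g. via colossalai.launch) is assumed.

from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size, num_microbatches = 2, 4
mem_f = 34 * 32 + 5 * 4 * 16  # per-microbatch memory coefficients taken from the test
mem_w = -32 * 32
mem_b = -mem_w - mem_f

# Build the V-schedule nodes consumed by ZeroBubbleVPipeScheduler.
scheduler_nodes = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1000, b_cost=1000, w_cost=1000, c_cost=1,
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
).get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=pp_size,
    pp_style="zbv",                   # the new pipeline style
    num_model_chunks=2,               # zbv requires exactly two chunks per stage
    num_microbatches=num_microbatches,
    scheduler_nodes=scheduler_nodes,  # must not be None when pp_style == "zbv"
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
)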
@@ -1236,7 +1258,6 @@ class HybridParallelPlugin(PipelinePluginBase):

         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)
-
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             self.logger.warning(
                 "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",

@@ -1352,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

         with ctx, model._wait_all_gather():
-            outputs = self.schedule.forward_backward_step(
+            outputs = self.scheduler.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )

@@ -280,7 +280,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

         self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:

@@ -304,7 +304,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

             if pp_style == "interleaved":
                 assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                     stage_manager=self.stage_manager,
                     num_model_chunks=num_model_chunks,
                     num_microbatch=num_microbatches,

@@ -313,7 +313,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     overlap_p2p=overlap_p2p,
                 )
             elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                     stage_manager=self.stage_manager,
                     num_microbatches=num_microbatches,
                     microbatch_size=microbatch_size,
@@ -49,11 +49,11 @@ class OptimizerWrapper:
         """
         self.optim.zero_grad(*args, **kwargs)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         """
         Performs a backward pass on the loss.
         """
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
@@ -136,7 +136,11 @@ class PipelineStageManager:
         if not self.is_interleave or ignore_chunk:
             return self.stage == self.num_stages - 1
         else:
-            return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+            # use zero bubble pipeline
+            if self.use_zbv:
+                return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
+            else:
+                return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1

     @property
     def num_stages(self) -> int:
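The `use_zbv` branch encodes the V-shaped placement of model chunks: chunk 0 walks down the stages and chunk 1 walks back up, so the tail of the model sits on pipeline stage 0 rather than on the last stage. A toy illustration of that layout (not ColossalAI code):

def v_layout(pp_size: int) -> list:
    """Return (chunk_id, stage) pairs in model order for a 2-chunk V schedule."""
    descending = [(0, stage) for stage in range(pp_size)]
    ascending = [(1, stage) for stage in reversed(range(pp_size))]
    return descending + ascending

print(v_layout(4))
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (1, 1), (1, 0)]
# The final chunk ends on stage 0, which is why is_last_stage() reports True at
# stage 0 / last model chunk when use_zbv is set, and why the Llama policies below
# hold norm / lm_head on the first stage under zbv.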
@@ -261,7 +261,9 @@ class LlamaPolicy(Policy):
             held_layers.append(module.embed_tokens)
         for start_idx, end_idx in stage_indices:
             held_layers.extend(module.layers[start_idx:end_idx])
-        if stage_manager.is_last_stage(ignore_chunk=True):
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(module.norm)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(module.norm)

         else:

@@ -351,7 +353,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.lm_head)
         return held_layers

@@ -404,7 +408,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage(ignore_chunk=True):
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.score)
         return held_layers

@@ -373,7 +373,7 @@ class GeminiDDP(ModelWrapper):
         loss.backward()
         self._post_backward()

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
         raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")

     @staticmethod

@@ -298,12 +298,14 @@ class GeminiOptimizer(OptimizerWrapper):
         loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)

-    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+    def backward_by_grad(
+        self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
+    ):
         # This function is called except the last stage of pipeline parallel
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
         self.module.backward_by_grad(tensor, grad)

     def _maybe_move_fp32_params(self):
@@ -408,7 +408,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and no_sync are not compatible"

@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)

-        loss.backward(retain_graph=retain_graph)
+        loss.backward(inputs=inputs, retain_graph=retain_graph)

         if not self.require_grad_sync:
             return

@@ -427,14 +427,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self._overlap_communication:
             get_accelerator().synchronize()

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
-        torch.autograd.backward(tensor, grad)
+        torch.autograd.backward(
+            tensor,
+            grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )

         if not self.require_grad_sync:
             return
@@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
         sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)

    criterion = loss_fn
-
    plugin = pluggin_cls(**test_config)
    booster = Booster(plugin=plugin)

@@ -311,8 +310,16 @@ def check_output_hidden_state(
 ):
    org_hidden_state = org_output.last_hidden_state

-    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
-        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+    if stage_manager:
+        if stage_manager.use_zbv:
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+            else:
+                sharded_hidden_state = sharded_output.last_hidden_state
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+        else:
+            sharded_hidden_state = sharded_output.last_hidden_state
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

@@ -390,7 +397,6 @@ def get_grad_tensors_for_check(
            pass
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
-
        grad_to_check[suffix] = {
            "org_grad": org_grad.float(),
            "shard_grad": shard_grad.float(),
@@ -7,6 +7,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter

@@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
     if enable_gradient_checkpointing:
         # org_model.gradient_checkpointing_enable()
-        sharded_model.unwrap().gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
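The gradient-checkpointing change above matches the warning added to HybridParallelPlugin: with pp_style="zbv", activation checkpointing has to run in non-reentrant mode. A short sketch of the same call on a plain Hugging Face model; the tiny checkpoint name is only an example and any model exposing the standard API works the same way.

from transformers import AutoModelForCausalLM

# Illustrative checkpoint; substitute your own model.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Non-reentrant checkpointing, as required by the zbv pipeline style.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})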
@@ -112,12 +113,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        sharded_optimizer.step()

    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
+    check_flag = False
+    if stage_manager is None:
+        check_flag = True
+    else:
+        if stage_manager.use_zbv:
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                check_flag = True
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            check_flag = True
+    if check_flag:
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "LlamaModel":
            check_output_hidden_state(
                org_output,
@@ -282,10 +291,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 0,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
    ],
 )
 def run_llama_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
+    if test_config.get("pp_style", None) == "zbv":
+        mem_f = 34 * 32 + 5 * 4 * 16
+        mem_w = -32 * 32
+        mem_b = -mem_w - mem_f
+        scheduler_nodes = PipelineGraph(
+            n_stage=test_config["pp_size"],
+            n_micro=test_config["num_microbatches"],
+            f_cost=1000,
+            b_cost=1000,
+            w_cost=1000,
+            c_cost=1,
+            f_mem=mem_f,
+            b_mem=mem_b,
+            w_mem=mem_w,
+        ).get_v_schedule()
+        test_config["scheduler_nodes"] = scheduler_nodes
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
            continue