[plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>
pull/6075/head
flybird11111 2024-09-27 14:48:55 +08:00 committed by GitHub
parent b804fdc297
commit af6aa9ed06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 140 additions and 53 deletions

View File

@ -140,7 +140,7 @@ jobs:
- name: Install Colossal-AI - name: Install Colossal-AI
run: | run: |
BUILD_EXT=1 pip install -v -e . BUILD_EXT=1 pip install -v .
pip install --no-cache-dir -r requirements/requirements-test.txt pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache - name: Store Colossal-AI Cache

View File

@ -55,7 +55,7 @@ jobs:
if: steps.check-avai.outputs.avai == 'true' if: steps.check-avai.outputs.avai == 'true'
run: | run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
BUILD_EXT=1 pip install -v -e . BUILD_EXT=1 pip install -v .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
pip install --no-cache-dir -r requirements/requirements-test.txt pip install --no-cache-dir -r requirements/requirements-test.txt

View File

@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
dtype: torch.dtype dtype: torch.dtype
@abstractmethod @abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor: def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
"""Called before backward. """Called before backward.
Args: Args:

View File

@ -85,13 +85,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
master_params.append(master_p) master_params.append(master_p)
group["params"] = master_params group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs): def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss) loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs) loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor): def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad) torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)
def zero_grad(self, *args, **kwargs): def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys(): for p in self.working_to_master_map.keys():

View File

@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
growth_interval=growth_interval, growth_interval=growth_interval,
) )
def backward(self, loss: Tensor, *args, **kwargs) -> None: def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
scaled_loss = self.scale_loss(loss) scaled_loss = self.scale_loss(loss)
scaled_loss.backward(*args, **kwargs) scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def step(self, *args, **kwargs) -> Optional[float]: def step(self, *args, **kwargs) -> Optional[float]:
out = self.scaler.step(self.optim, *args, **kwargs) out = self.scaler.step(self.optim, *args, **kwargs)

View File

@ -28,7 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@ -288,7 +288,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim) super().__init__(optim)
def backward(self, loss: Tensor, *args, **kwargs): def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r""" r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -306,7 +306,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs) super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
if self.model.require_grad_sync: if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -315,7 +315,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
# If gradient synchronization is is not required, return. # If gradient synchronization is is not required, return.
return return
def backward_by_grad(self, tensor: Tensor, grad: Tensor): def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
""" """
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -332,7 +332,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad) super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.model.require_grad_sync: if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -512,7 +512,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
max_norm=max_norm, max_norm=max_norm,
) )
def backward(self, loss: Tensor, *args, **kwargs): def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r""" r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -529,7 +529,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
None None
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs) super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
if self.model.require_grad_sync: if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -538,7 +538,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# If gradient synchronization is is not required, return. # If gradient synchronization is is not required, return.
return return
def backward_by_grad(self, tensor: Tensor, grad: Tensor): def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
""" """
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -554,7 +554,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
None None
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad) super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.model.require_grad_sync: if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -768,7 +768,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
else: else:
return return
def backward(self, loss, retain_graph=False): def backward(self, loss, inputs=None, retain_graph=False):
""" """
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
None None
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph) super().backward(loss, inputs=inputs, retain_graph=retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -793,7 +793,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# If gradient synchronization is is not required, return. # If gradient synchronization is is not required, return.
return return
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
""" """
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -809,7 +809,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
None None
""" """
# Call the superclass backward_by_grad method to compute gradients. # Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad) super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
@ -1013,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None, custom_policy: Policy = None,
pp_style: str = "1f1b", pp_style: str = "1f1b",
num_model_chunks: int = 1, num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None, num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True, enable_metadata_cache: bool = True,
@ -1029,6 +1030,9 @@ class HybridParallelPlugin(PipelinePluginBase):
dist.get_world_size() % (tp_size * pp_size) == 0 dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
if enable_sequence_parallelism: if enable_sequence_parallelism:
self.sequence_parallelism_mode = ( self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@ -1088,29 +1092,39 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None self.stage_manager = None
self.schedule = None self.scheduler = None
self.custom_policy = custom_policy self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2) assert zero_stage in (0, 1, 2)
if self.pp_size > 1: if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert ( assert (
num_microbatches is not None or microbatch_size is not None num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert ( assert (
self.zero_stage <= 1 self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
if pp_style == "zbv":
self.logger.warning(
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
)
self.stage_manager = PipelineStageManager( self.stage_manager = PipelineStageManager(
self.pg_mesh, self.pg_mesh,
pipeline_axis=self.pp_axis, pipeline_axis=self.pp_axis,
enable_interleave=(pp_style == "interleaved"), enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
use_zbv=(pp_style == "zbv"),
num_model_chunks=num_model_chunks, num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage, num_layers_per_stage=num_layers_per_stage,
) )
if pp_style == "interleaved": if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule( self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager, stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks, num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches, num_microbatch=num_microbatches,
@ -1119,12 +1133,20 @@ class HybridParallelPlugin(PipelinePluginBase):
overlap_p2p=overlap_p2p, overlap_p2p=overlap_p2p,
) )
elif pp_style == "1f1b": elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule( self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager, stage_manager=self.stage_manager,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
microbatch_size=microbatch_size, microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache, enable_metadata_cache=enable_metadata_cache,
) )
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else: else:
raise NotImplementedError() raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn": if sequence_parallelism_mode == "ring_attn":
@ -1236,7 +1258,6 @@ class HybridParallelPlugin(PipelinePluginBase):
# Replace with distributed implementation if exists # Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer) optimizer = cast_to_distributed(optimizer)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning( self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@ -1352,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx, model._wait_all_gather(): with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step( outputs = self.scheduler.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs model, data_iter, criterion, optimizer, return_loss, return_outputs
) )

View File

@ -280,7 +280,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
self.stage_manager = None self.stage_manager = None
self.schedule = None self.scheduler = None
self.custom_policy = custom_policy self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2) assert zero_stage in (0, 1, 2)
if self.pp_size > 1: if self.pp_size > 1:
@ -304,7 +304,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if pp_style == "interleaved": if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule( self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager, stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks, num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches, num_microbatch=num_microbatches,
@ -313,7 +313,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_p2p=overlap_p2p, overlap_p2p=overlap_p2p,
) )
elif pp_style == "1f1b": elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule( self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager, stage_manager=self.stage_manager,
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
microbatch_size=microbatch_size, microbatch_size=microbatch_size,

View File

@ -49,11 +49,11 @@ class OptimizerWrapper:
""" """
self.optim.zero_grad(*args, **kwargs) self.optim.zero_grad(*args, **kwargs)
def backward(self, loss: Tensor, *args, **kwargs): def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
""" """
Performs a backward pass on the loss. Performs a backward pass on the loss.
""" """
loss.backward(*args, **kwargs) loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
""" """

View File

@ -136,7 +136,11 @@ class PipelineStageManager:
if not self.is_interleave or ignore_chunk: if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1 return self.stage == self.num_stages - 1
else: else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 # use zero bubble pipeline
if self.use_zbv:
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@property @property
def num_stages(self) -> int: def num_stages(self) -> int:

View File

@ -261,7 +261,9 @@ class LlamaPolicy(Policy):
held_layers.append(module.embed_tokens) held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices: for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx]) held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.norm)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm) held_layers.append(module.norm)
else: else:
@ -351,7 +353,9 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head) held_layers.append(self.model.lm_head)
return held_layers return held_layers
@ -404,7 +408,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.score)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score) held_layers.append(self.model.score)
return held_layers return held_layers

View File

@ -373,7 +373,7 @@ class GeminiDDP(ModelWrapper):
loss.backward() loss.backward()
self._post_backward() self._post_backward()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
@staticmethod @staticmethod

View File

@ -298,12 +298,14 @@ class GeminiOptimizer(OptimizerWrapper):
loss = self.mix_precision_mixin.pre_backward(loss) loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss) self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): def backward_by_grad(
self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
):
# This function is called except the last stage of pipeline parallel # This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank # It receives the scaled grad from the previous rank
# No need to scale the grad again # No need to scale the grad again
# Need to unscale when optimizing # Need to unscale when optimizing
grad = self.mix_precision_mixin.pre_backward_by_grad(grad) grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
self.module.backward_by_grad(tensor, grad) self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self): def _maybe_move_fp32_params(self):

View File

@ -408,7 +408,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# torch.optim.Optimizer methods # torch.optim.Optimizer methods
################################ ################################
def backward(self, loss, retain_graph=False): def backward(self, loss, inputs=None, retain_graph=False):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible" ), "ZeRO2(partition_grads) and no_sync are not compatible"
@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.mixed_precision_mixin is not None: if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss) loss = self.mixed_precision_mixin.pre_backward(loss)
loss.backward(retain_graph=retain_graph) loss.backward(inputs=inputs, retain_graph=retain_graph)
if not self.require_grad_sync: if not self.require_grad_sync:
return return
@ -427,14 +427,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self._overlap_communication: if self._overlap_communication:
get_accelerator().synchronize() get_accelerator().synchronize()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None: if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad) torch.autograd.backward(
tensor,
grad,
inputs=inputs,
retain_graph=retain_graph,
)
if not self.require_grad_sync: if not self.require_grad_sync:
return return

View File

@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn criterion = loss_fn
plugin = pluggin_cls(**test_config) plugin = pluggin_cls(**test_config)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
@ -311,8 +310,16 @@ def check_output_hidden_state(
): ):
org_hidden_state = org_output.last_hidden_state org_hidden_state = org_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): if stage_manager:
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
elif stage_manager.is_last_stage(ignore_chunk=True):
sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
else:
sharded_hidden_state = sharded_output.last_hidden_state
else: else:
sharded_hidden_state = sharded_output.last_hidden_state sharded_hidden_state = sharded_output.last_hidden_state
@ -390,7 +397,6 @@ def get_grad_tensors_for_check(
pass pass
if verbose and dist.get_rank() == 0: if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}") print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = { grad_to_check[suffix] = {
"org_grad": org_grad.float(), "org_grad": org_grad.float(),
"shard_grad": shard_grad.float(), "shard_grad": shard_grad.float(),

View File

@ -7,6 +7,7 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.api import clear_layout_converter
@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
if enable_gradient_checkpointing: if enable_gradient_checkpointing:
# org_model.gradient_checkpointing_enable() # org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@ -112,12 +113,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
sharded_optimizer.step() sharded_optimizer.step()
# check last hidden state & loss # check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): check_flag = False
if stage_manager is None:
check_flag = True
else:
if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True):
check_flag = True
elif stage_manager.is_last_stage(ignore_chunk=True):
check_flag = True
if check_flag:
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3 atol, rtol = 1e-5, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "LlamaModel": if org_model.__class__.__name__ == "LlamaModel":
check_output_hidden_state( check_output_hidden_state(
org_output, org_output,
@ -282,10 +291,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "zbv",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 0,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"parallel_output": False,
},
], ],
) )
def run_llama_test(test_config): def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
if test_config.get("pp_style", None) == "zbv":
mem_f = 34 * 32 + 5 * 4 * 16
mem_w = -32 * 32
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=test_config["pp_size"],
n_micro=test_config["num_microbatches"],
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
test_config["scheduler_nodes"] = scheduler_nodes
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
continue continue