Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI into dev/zero_bubble

pull/6083/head
duanjunwen 2024-10-15 06:31:45 +00:00
commit 52dcc73313
4 changed files with 33 additions and 32 deletions

@@ -1166,22 +1166,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatch=num_microbatches,
                     microbatch_size=microbatch_size,
                 )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
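Note (not part of the diff): the merge had left duplicate elif pp_style == "zbv" arms in the scheduler dispatch. Because an if/elif chain stops at the first matching arm, a repeated arm is unreachable dead code, so this hunk can drop the duplicates without changing behaviour (a single "zbv" arm presumably survives elsewhere in the chain). A self-contained toy illustration of that reasoning, with hypothetical names rather than ColossalAI code:

# Toy dispatch: the second "zbv" arm can never run, just like the arms removed above.
def pick_scheduler(pp_style: str) -> str:
    if pp_style == "interleaved":
        return "InterleavedSchedule"
    elif pp_style == "zbv":
        return "ZeroBubbleVPipeScheduler"
    elif pp_style == "zbv":  # unreachable duplicate
        return "never reached"
    else:
        raise NotImplementedError(pp_style)

assert pick_scheduler("zbv") == "ZeroBubbleVPipeScheduler"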

@@ -289,9 +289,9 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
                 held_layers.append(module.norm)
         else:
@@ -383,13 +383,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.lm_head)
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
@@ -443,9 +445,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.score)
         return held_layers
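Note (not part of the diff): the three policy hunks above apply the same predicate. Under the zbv (ZeroBubble-V) schedule the second model chunk runs back down the pipeline, so the model's final layers end up on the first pipeline stage; without zbv they stay on the last stage as before. A hypothetical helper that captures the repeated condition, assuming only the stage_manager methods already used in the diff:

def holds_output_layers(stage_manager) -> bool:
    # Final layers (norm / lm_head / score) live on the first stage under zbv,
    # on the last stage otherwise; equivalent to the combined condition above.
    if stage_manager.use_zbv:
        return stage_manager.is_first_stage(ignore_chunk=True)
    return stage_manager.is_last_stage(ignore_chunk=True)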

@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial
 from typing import Tuple
@@ -72,6 +73,9 @@ class MlpModel(nn.Module):
         else:
             return {"hidden_states": held_layers(hidden_states)}

+    def no_sync(self):
+        return nullcontext()
+

 def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
     for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
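Note (not part of the diff): giving the test model a no_sync() that returns nullcontext() lets the driver write `with model.no_sync():` on the plain MlpModel exactly as it would on a DDP-wrapped model, where no_sync() suspends gradient synchronization. A self-contained sketch of the pattern with a hypothetical stand-in module (requires torch):

from contextlib import nullcontext

import torch
import torch.nn as nn


class TinyModel(nn.Module):  # hypothetical stand-in for the test's MlpModel
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor) -> dict:
        return {"hidden_states": self.linear(hidden_states)}

    def no_sync(self):
        # Same trick as the diff above: a context manager that does nothing.
        return nullcontext()


model = TinyModel()
with model.no_sync():  # uniform call site, a no-op for a plain nn.Module
    out = model(torch.randn(4, 8))["hidden_states"]
    out.mean().backward()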

@@ -114,13 +114,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     check_flag = False
-    if stage_manager is None:
-        check_flag = True
-    else:
-        if stage_manager.use_zbv:
-            if stage_manager.is_first_stage(ignore_chunk=True):
-                check_flag = True
-        elif stage_manager.is_last_stage(ignore_chunk=True):
-            check_flag = True
+    if (
+        (stage_manager is None)
+        or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
+        or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
+    ):
+        check_flag = True
     if check_flag:
         if test_config["precision"] == "fp32":
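Note (not part of the diff): the flattened condition is logically equivalent to the old nested if/elif whenever a stage manager is present, and both treat stage_manager is None as "check". A hypothetical exhaustive check of that equivalence:

from itertools import product

for use_zbv, first, last in product([False, True], repeat=3):
    old = first if use_zbv else last                     # old nested branches
    new = (use_zbv and first) or (not use_zbv and last)  # new flattened condition
    assert old == new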
@@ -292,6 +290,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):
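Note (not part of the diff): for orientation, a rough, hypothetical sketch of how the new "zbv" parametrization maps onto the plugin it exercises. The constructor call is inferred from the argument names visible in the first hunk and the config keys above, not taken from this PR; a launched distributed environment is required, and whatever zbv-specific schedule argument the plugin expects (the scheduler_nodes seen in the first hunk) is omitted here.

from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(  # assumed keyword names, mirroring the test config
    tp_size=2,
    pp_size=2,
    pp_style="zbv",             # selects ZeroBubbleVPipeScheduler in the dispatch
    num_model_chunks=2,
    num_microbatches=4,
    precision="fp16",
    zero_stage=1,
    initial_scale=1,
)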