mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI into dev/zero_bubble
commit 52dcc73313
@@ -1166,22 +1166,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatch=num_microbatches,
                     microbatch_size=microbatch_size,
                 )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
-            elif pp_style == "zbv":
-                self.scheduler = ZeroBubbleVPipeScheduler(
-                    stage_manager=self.stage_manager,
-                    schedule=scheduler_nodes,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
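The hunk above resolves a merge artifact: the merge had left two identical `elif pp_style == "zbv"` branches in `HybridParallelPlugin.__init__`, and the resolution removes the duplicated blocks, presumably because the merged code already constructs the ZeroBubbleVPipeScheduler once under the intended condition. A minimal sketch of the one-branch-per-style dispatch involved; the style names and stub classes here are illustrative assumptions, not the plugin's actual definitions:

# Hypothetical stand-ins for the real scheduler classes.
class InterleavedSchedule: ...
class ZeroBubbleVPipeScheduler: ...

def build_scheduler(pp_style: str):
    if pp_style == "interleaved":
        return InterleavedSchedule()
    elif pp_style == "zbv":
        # Exactly one "zbv" branch; the merge had duplicated this block.
        return ZeroBubbleVPipeScheduler()
    else:
        raise NotImplementedError(f"unsupported pp_style: {pp_style!r}")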
@@ -289,9 +289,9 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
                 held_layers.append(module.norm)
 
         else:
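The placement change above reflects how the zero-bubble V schedule lays out model chunks: each device holds two chunks arranged in a V, so the device running the first stage also runs the final chunk and therefore owns the output-side layers (`module.norm` here). A self-contained sketch of the rule, with a simplified stand-in for the real stage manager (the real object exposes use_zbv, is_first_stage, and is_last_stage; everything else below is illustrative):

from dataclasses import dataclass

@dataclass
class StageManager:
    use_zbv: bool
    stage: int
    num_stages: int

    def is_first_stage(self, ignore_chunk: bool = True) -> bool:
        return self.stage == 0

    def is_last_stage(self, ignore_chunk: bool = True) -> bool:
        return self.stage == self.num_stages - 1

def holds_output_layers(sm: StageManager) -> bool:
    # Under the V schedule the first device also runs the last model chunk,
    # so output layers live on the first stage; otherwise they stay on the
    # last stage as in classic pipeline parallelism.
    if sm.use_zbv:
        return sm.is_first_stage(ignore_chunk=True)
    return sm.is_last_stage(ignore_chunk=True)

# e.g. with 4 stages: zbv places outputs on stage 0, classic PP on stage 3
assert holds_output_layers(StageManager(True, 0, 4))
assert holds_output_layers(StageManager(False, 3, 4))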
@@ -383,13 +383,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
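The early return added to `get_shared_params` drops cross-stage weight sharing under zbv. With tied word embeddings, `embed_tokens` and `lm_head` normally live on different pipeline stages and must be registered as shared parameters so their gradients stay synchronized; under the V schedule both ends of the model sit on the same device, so there is nothing to register. A hedged sketch of that reasoning (the tie_word_embeddings check mirrors the usual HuggingFace Llama config; the exact control flow is illustrative, not the policy's code):

def get_shared_params_sketch(model, stage_manager):
    # Under zbv the first and last chunks are colocated, so tied weights
    # are the same local tensor and need no cross-stage gradient sync.
    if stage_manager is not None and stage_manager.use_zbv:
        return []
    # Classic pipeline case: register the tied embedding/head pair so the
    # runtime can synchronize their gradients across stages.
    if stage_manager is not None and stage_manager.num_stages > 1:
        if getattr(model.config, "tie_word_embeddings", False):
            return [{
                0: model.model.embed_tokens.weight,
                stage_manager.num_stages - 1: model.lm_head.weight,
            }]
    return []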
@@ -443,9 +445,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.score)
         return held_layers
 
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial
 from typing import Tuple
@@ -72,6 +73,9 @@ class MlpModel(nn.Module):
         else:
             return {"hidden_states": held_layers(hidden_states)}
 
+    def no_sync(self):
+        return nullcontext()
+
 
 def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
     for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
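The `no_sync` method added to the test's `MlpModel` stubs the DDP-style interface some code paths expect: callers wrap non-final microbatches in `model.no_sync()` to postpone gradient all-reduce, and a plain single-process test model can satisfy that contract with a `nullcontext()` no-op. A small usage sketch (names other than `no_sync`/`nullcontext` are illustrative):

from contextlib import nullcontext

class TinyModel:
    def no_sync(self):
        # Nothing to synchronize in a single-process test model.
        return nullcontext()

model = TinyModel()
microbatches = [1, 2, 3, 4]
for i, mb in enumerate(microbatches):
    is_last = i == len(microbatches) - 1
    # Skip gradient sync on all but the final microbatch.
    ctx = nullcontext() if is_last else model.no_sync()
    with ctx:
        pass  # forward/backward for this microbatch would go here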
@@ -114,13 +114,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check last hidden state & loss
     check_flag = False
-    if stage_manager is None:
-        check_flag = True
-    else:
-        if stage_manager.use_zbv:
-            if stage_manager.is_first_stage(ignore_chunk=True):
-                check_flag = True
-        elif stage_manager.is_last_stage(ignore_chunk=True):
-            check_flag = True
+    if (
+        (stage_manager is None)
+        or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
+        or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
+    ):
+        check_flag = True
     if check_flag:
         if test_config["precision"] == "fp32":
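The refactor above flattens a nested if/else into a single boolean without changing behavior. A quick standalone check (not part of the test file; `use_zbv`/`first`/`last` are simplified stand-ins for the stage-manager calls) confirms the two forms agree on every input:

from itertools import product
from types import SimpleNamespace

def old_form(sm):
    # Nested form, as removed by the hunk.
    if sm is None:
        return True
    if sm.use_zbv:
        if sm.first:
            return True
    elif sm.last:
        return True
    return False

def new_form(sm):
    # Flattened form, as added by the hunk.
    return (
        (sm is None)
        or (sm.use_zbv and sm.first)
        or (not sm.use_zbv and sm.last)
    )

cases = [None] + [
    SimpleNamespace(use_zbv=z, first=f, last=l)
    for z, f, l in product([False, True], repeat=3)
]
assert all(old_form(c) == new_form(c) for c in cases)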
@@ -292,6 +290,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):
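The new parameterized entry exercises zbv end to end: tensor parallelism of 2, pipeline parallelism of 2 with two model chunks per rank (the V shape), four microbatches, fp16, and ZeRO stage 1. As a sketch, a config dict like this would typically drive the plugin construction inside the distributed test body; only the keys and values below come from the hunk, and the commented constructor call is an assumption about the surrounding test harness:

test_config = {
    "tp_size": 2,
    "pp_size": 2,
    "pp_style": "zbv",
    "num_model_chunks": 2,
    "num_microbatches": 4,
    "enable_all_optimization": False,
    "precision": "fp16",
    "zero_stage": 1,
    "initial_scale": 1,
    "enable_gradient_checkpointing": True,
    "parallel_output": False,
}

# Inside a torch.distributed launch, the test would do roughly:
# plugin = HybridParallelPlugin(**test_config)
# booster = Booster(plugin=plugin)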