[feat] moehybrid support zerobubble;

pull/6065/head
duanjunwen 3 months ago
parent 6c2a120bed
commit 9bc3b6e220

@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style == "interleaved" or pp_style == "zbv"
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
@@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
self.schedule = ZeroBubbleVPipeScheduler(
schedule=scheduler_nodes,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
overlap_p2p=overlap_p2p,
)
else:
raise NotImplementedError()
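
For context on the "zbv" branch above: scheduler_nodes is expected to carry a precomputed zero-bubble v-schedule. Below is a minimal sketch of how a caller might build one with PipelineGraph and hand it to the plugin; the cost/memory constants, ep_size=1, and the other plugin arguments are illustrative assumptions rather than values fixed by this commit, and the plugin has to be created after colossalai.launch since it sets up process groups.

from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size, num_microbatches = 4, 4
# Rough per-chunk cost/memory estimates; real values depend on the model.
mem_f, mem_w = 34 * 1024, -32 * 1024
mem_b = -mem_w - mem_f
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
)
scheduler_nodes = graph.get_v_schedule()  # one list of ScheduledNode per pipeline stage

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    ep_size=1,  # assumption: expert parallelism disabled in this sketch
    pp_style="zbv",
    num_model_chunks=2,  # the v-schedule runs two model chunks per stage
    num_microbatches=num_microbatches,
    scheduler_nodes=scheduler_nodes,
    zero_stage=1,
    precision="bf16",
)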

@@ -14,6 +14,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -724,23 +725,83 @@ def run_with_hybridplugin(test_config):
"test_config",
[
{
"batch_size": 8,
"pp_style": "zbv",
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
"num_model_chunks": 2,
},
],
)
def run_with_moehybridplugin(test_config):
model_zoo.get_sub_registry("transformers_bert")
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["initial_scale"] = 2**16
model_list = [
"transformers_bert",
]
clear_layout_converter()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
# base param
model = model_fn()
data = data_gen_fn()
criterion = loss_fn
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
output = model(**data)
loss = criterion(output)
loss.backward()
optimizer.step()
print(f"output {output}")
# # pp param
# model_pp = deepcopy(model)
# data_pp = deepcopy(data)
# optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
# # init pipeline graph
# h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
# mem_f = 34 * h + 5 * a * s
# mem_w = -32 * h
# mem_b = -mem_w - mem_f
# graph = PipelineGraph(
# n_stage=test_config["pp_size"],
# n_micro=test_config["num_microbatches"],
# f_cost=1,
# b_cost=1,
# w_cost=1,
# c_cost=1,
# f_mem=mem_f,
# b_mem=mem_b,
# w_mem=mem_w,
# # max_mem=mem_f * (p * 2 + m_offset),
# )
# zbv_schedule = graph.get_v_schedule()
# test_config["scheduler_nodes"] = zbv_schedule
# plugin = MoeHybridParallelPlugin(**test_config)
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
# model=model_pp,
# optimizer=optimizer_pp,
# criterion=criterion,
# dataloader=data_pp,
# )
# output_pp = plugin.execute_pipeline(
# data_iter=iter([data_pp]),
# model=model_pp,
# criterion=criterion,
# optimizer=optimizer_pp,
# return_loss=True,
# return_outputs=True,
# )
# TODO: 6) support booster & Hybrid, based on 4)
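
A sketch of the flow the TODO above points toward, assuming the usual Booster integration; the Booster wrapper, dataloader, and baseline_loss names below are illustrative assumptions and not part of this commit.

import torch
from colossalai.booster import Booster

# Assumes `plugin` was built with pp_style="zbv" as in the earlier sketch, and that
# model_pp/optimizer_pp are deep copies of the single-device model and optimizer.
# criterion is assumed to be a callable returning a scalar loss from the model outputs.
booster = Booster(plugin=plugin)
model_pp, optimizer_pp, criterion, dataloader, _ = booster.boost(
    model_pp, optimizer_pp, criterion=criterion, dataloader=dataloader
)

result = booster.execute_pipeline(
    iter(dataloader),
    model_pp,
    criterion,
    optimizer_pp,
    return_loss=True,
    return_outputs=True,
)
optimizer_pp.step()

# Only the rank holding the final model chunk materializes the loss; other ranks get None.
if result["loss"] is not None:
    # baseline_loss: the single-device loss computed above (hypothetical name).
    torch.testing.assert_close(result["loss"], baseline_loss, rtol=2e-2, atol=2e-2)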
@@ -752,8 +813,9 @@ def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_iter_input()
run_fwd_bwd_vschedule_with_optim()
# run_fwd_bwd_vschedule_with_optim()
# run_with_moehybridplugin()
run_with_moehybridplugin()
@pytest.mark.dist
