mirror of https://github.com/hpcaitech/ColossalAI
[feat] moehybrid support zerobubble;
parent 6c2a120bed
commit 9bc3b6e220
@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style == "interleaved" or pp_style == "zbv"
+            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
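
In effect, zbv joins interleaved as a style that may hold more than one model chunk per rank, while 1f1b stays restricted to a single chunk. A standalone sketch of the new checks (check_pp_config is a hypothetical helper, not part of the commit):

    def check_pp_config(pp_style: str, num_model_chunks: int) -> None:
        assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
        assert (
            pp_style == "interleaved" or pp_style == "zbv"
        ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"

    check_pp_config("1f1b", 1)  # passes: 1f1b runs a single model chunk
    check_pp_config("zbv", 2)   # passes: zbv, like interleaved, holds multiple chunks per rank
    # check_pp_config("1f1b", 2)  # would raise: extra chunks are only valid for interleaved/zbv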
@@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 microbatch_size=microbatch_size,
                 enable_metadata_cache=enable_metadata_cache,
             )
+        elif pp_style == "zbv":
+            self.schedule = ZeroBubbleVPipeScheduler(
+                schedule=scheduler_nodes,
+                stage_manager=self.stage_manager,
+                num_model_chunks=num_model_chunks,
+                num_microbatch=num_microbatches,
+                overlap_p2p=overlap_p2p,
+            )
         else:
             raise NotImplementedError()
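
Taken together, these plugin changes expect the caller to build the zero-bubble V schedule up front and pass it in through the new scheduler_nodes argument. A minimal sketch of that wiring, assuming the usual colossalai.booster.plugin import path and an illustrative ep_size=1; the cost and memory figures are toy placeholders mirroring the commented-out test further down, not tuned values:

    from colossalai.booster.plugin import MoeHybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    # Build a V-shaped zero-bubble schedule for 4 stages x 4 microbatches.
    graph = PipelineGraph(
        n_stage=4,  # matches pp_size below
        n_micro=4,  # matches num_microbatches below
        f_cost=1, b_cost=1, w_cost=1, c_cost=1,  # relative F/B/W/comm costs
        f_mem=2, b_mem=-1, w_mem=-1,  # toy deltas; forward + backward + W nets to zero
    )
    scheduler_nodes = graph.get_v_schedule()

    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=4,
        ep_size=1,  # assumption: the MoE plugin also takes an expert-parallel size
        pp_style="zbv",
        num_model_chunks=2,  # zbv places two chunks (the two arms of the V) on every rank
        num_microbatches=4,
        scheduler_nodes=scheduler_nodes,
        zero_stage=1,
        precision="bf16",
    )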
@@ -14,6 +14,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -724,23 +725,83 @@ def run_with_hybridplugin(test_config):
     "test_config",
     [
         {
             "batch_size": 8,
+            "pp_style": "zbv",
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 2,
+            "num_model_chunks": 2,
         },
     ],
 )
 def run_with_moehybridplugin(test_config):
-    model_zoo.get_sub_registry("transformers_bert")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    test_config["use_lazy_init"] = False
+    test_config["initial_scale"] = 2**16
+    model_list = [
+        "transformers_bert",
+    ]
+    clear_layout_converter()
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            # base param
+            model = model_fn()
+            data = data_gen_fn()
+            criterion = loss_fn
+            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
+
+            output = model(**data)
+            loss = criterion(output)
+            loss.backward()
+            optimizer.step()
+            print(f"output {output}")
+
+            # # pp param
+            # model_pp = deepcopy(model)
+            # data_pp = deepcopy(data)
+            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
+
+            # # init pipeline graph
+            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
+            # mem_f = 34 * h + 5 * a * s
+            # mem_w = -32 * h
+            # mem_b = -mem_w - mem_f
+            # graph = PipelineGraph(
+            #     n_stage=test_config["pp_size"],
+            #     n_micro=test_config["num_microbatches"],
+            #     f_cost=1,
+            #     b_cost=1,
+            #     w_cost=1,
+            #     c_cost=1,
+            #     f_mem=mem_f,
+            #     b_mem=mem_b,
+            #     w_mem=mem_w,
+            #     # max_mem=mem_f * (p * 2 + m_offset),
+            # )
+
+            # zbv_schedule = graph.get_v_schedule()
+
+            # test_config["scheduler_nodes"] = zbv_schedule
+            # plugin = MoeHybridParallelPlugin(**test_config)
+            # model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
+            #     model=model_pp,
+            #     optimizer=optimizer_pp,
+            #     criterion=criterion,
+            #     dataloader=data_pp,
+            # )
+
+            # output_pp = plugin.execute_pipeline(
+            #     data_iter=iter(data),
+            #     model=model,
+            #     criterion=criterion,
+            #     optimizer=optimizer,
+            #     return_loss=True,
+            #     return_outputs=True,
+            # )
+
+
+# TODO:6) support booster & Hybrid base 4)
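
The commented-out setup above encodes a simple activation-memory model for the V schedule. A worked example with a bert-base-like shape, where h=768, a=12, s=1024 are illustrative values rather than numbers taken from the commit:

    h, a, s = 768, 12, 1024     # hidden size, attention heads, sequence length
    mem_f = 34 * h + 5 * a * s  # memory still held after a forward chunk: 87552
    mem_w = -32 * h             # memory freed by the weight-gradient (W) pass: -24576
    mem_b = -mem_w - mem_f      # defined so that F + B + W nets to zero: -62976
    assert mem_f + mem_b + mem_w == 0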
@@ -752,8 +813,9 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # run_fwd_bwd_iter_input()
-    run_fwd_bwd_vschedule_with_optim()
-    # run_with_moehybridplugin()
+    # run_fwd_bwd_vschedule_with_optim()
+    run_with_moehybridplugin()


 @pytest.mark.dist
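
The hunk ends at the @pytest.mark.dist marker, so the actual test entry point is outside the captured context. ColossalAI tests conventionally wrap run_dist with spawn, roughly as below; test_pp is a hypothetical name, and nprocs=4 matches pp_size=4 in the config above:

    import pytest

    from colossalai.testing import rerun_if_address_is_in_use, spawn

    @pytest.mark.dist
    @rerun_if_address_is_in_use()
    def test_pp():  # hypothetical name; the real entry point is not shown in this diff
        spawn(run_dist, nprocs=4)  # one process per pipeline stage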