[fix] rm zbv in hybridplugin

pull/6034/head
duanjunwen 2024-09-02 10:00:43 +00:00
parent 6d18d38d5c
commit 77fe44286c
2 changed files with 65 additions and 44 deletions

View File

@ -28,8 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@ -1093,10 +1092,8 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2) assert zero_stage in (0, 1, 2)
if self.pp_size > 1: if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert ( assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
pp_style == "interleaved" or pp_style == "zbv"
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert ( assert (
num_microbatches is not None or microbatch_size is not None num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@ -1106,7 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.stage_manager = PipelineStageManager( self.stage_manager = PipelineStageManager(
self.pg_mesh, self.pg_mesh,
pipeline_axis=self.pp_axis, pipeline_axis=self.pp_axis,
enable_interleave=(pp_style == "interleaved") or (pp_style == "zbv"), enable_interleave=(pp_style == "interleaved"),
num_model_chunks=num_model_chunks, num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage, num_layers_per_stage=num_layers_per_stage,
) )
@ -1128,31 +1125,6 @@ class HybridParallelPlugin(PipelinePluginBase):
microbatch_size=microbatch_size, microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache, enable_metadata_cache=enable_metadata_cache,
) )
elif pp_style == "zbv":
h, a, s = 4096, 32, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
zbv_schedule = PipelineGraph(
n_stage=self.pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
self.schedule = ZeroBubbleVPipeScheduler(
schedule=zbv_schedule,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
else: else:
raise NotImplementedError() raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn": if sequence_parallelism_mode == "ring_attn":

View File

@ -14,7 +14,16 @@ from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
class MlpModel(nn.Module): class MlpModel(nn.Module):
@ -679,6 +688,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# TODO:4) support Hybrid base 3) # TODO:4) support Hybrid base 3)
def run_with_hybridplugin(test_config):
pass
# TODO:5) support MoEHybrid base 3)
@parameterize( @parameterize(
"test_config", "test_config",
[ [
@ -693,20 +707,55 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
}, },
], ],
) )
def run_with_hybridplugin(test_config): def run_with_moehybridplugin(test_config):
pass sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**16 # avoid overflow
model_list = [
"transformers_bert",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
# TODO:5) support MoEHybrid base 3) stage_manager = booster.plugin.stage_manager
def run_with_moehybridplugin( tp_group = booster.plugin.tp_group
rank: int,
world_size: int, bert = unwrap_model(org_model, "BertModel", "bert")
port: int, sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
num_microbatch: int, weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
batch_size: int,
num_model_chunk: int, org_optimizer.step()
): sharded_optimizer.step()
pass
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 5e-4
else:
atol, rtol = 5e-4, 5e-4
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check optim states
# check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Bert Model Zoo Test Passed")
# TODO:6) support booster & Hybrid base 4) # TODO:6) support booster & Hybrid base 4)