mirror of https://github.com/hpcaitech/ColossalAI
[fix] rm zbv in hybridplugin
parent 6d18d38d5c
commit 77fe44286c
@@ -28,8 +28,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
-from colossalai.pipeline.schedule.v_schedule import PipelineGraph
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1093,10 +1092,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
-            assert (
-                pp_style == "interleaved" or pp_style == "zbv"
-            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
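(Not part of the commit.) For quick reference, a standalone sketch of the configuration checks the plugin enforces after this hunk; the assertions and messages mirror the diff above, while the helper name `validate_pp_config` is made up for illustration:

```python
# Hypothetical helper mirroring the asserts above; with "zbv" removed,
# only "1f1b" and "interleaved" pipeline styles pass validation.
def validate_pp_config(pp_style, num_model_chunks, num_microbatches=None, microbatch_size=None):
    assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
    assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
    assert (
        num_microbatches is not None or microbatch_size is not None
    ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"


validate_pp_config("interleaved", num_model_chunks=2, num_microbatches=4)  # passes
validate_pp_config("1f1b", num_model_chunks=1, microbatch_size=2)          # passes
# validate_pp_config("zbv", num_model_chunks=2, num_microbatches=4)        # would now raise AssertionError
```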
@@ -1106,7 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=(pp_style == "interleaved") or (pp_style == "zbv"),
+                enable_interleave=(pp_style == "interleaved"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -1128,31 +1125,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
-            elif pp_style == "zbv":
-                h, a, s = 4096, 32, 1024
-                mem_f = 34 * h + 5 * a * s
-                mem_w = -32 * h
-                mem_b = -mem_w - mem_f
-                zbv_schedule = PipelineGraph(
-                    n_stage=self.pp_size,
-                    n_micro=num_microbatches,
-                    f_cost=1,
-                    b_cost=1,
-                    w_cost=1,
-                    c_cost=1,
-                    f_mem=mem_f,
-                    b_mem=mem_b,
-                    w_mem=mem_w,
-                ).get_v_schedule()
-                self.schedule = ZeroBubbleVPipeScheduler(
-                    schedule=zbv_schedule,
-                    stage_manager=self.stage_manager,
-                    num_model_chunks=num_model_chunks,
-                    num_microbatch=num_microbatches,
-                    microbatch_size=microbatch_size,
-                    enable_metadata_cache=enable_metadata_cache,
-                    overlap_p2p=overlap_p2p,
-                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
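(Not part of the commit.) The branch removed above built the zero-bubble V schedule from a hard-coded memory model before handing it to `ZeroBubbleVPipeScheduler`; those scheduler classes remain in the codebase (see the test imports in the next hunk). A rough standalone sketch of that schedule construction, assuming `PipelineGraph` can be driven outside the plugin; the example `pp_size`/`num_microbatches` values and the reading of `h`, `a`, `s` as hidden size, attention heads, and sequence length are my assumptions, not stated in the source:

```python
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Constants carried over from the removed branch; presumably hidden size,
# attention heads, and sequence length of the profiled model.
h, a, s = 4096, 32, 1024

# Per-microbatch memory deltas for the forward (F), backward (B) and
# weight-gradient (W) passes; by construction the three sum to zero.
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f

# Example values standing in for the plugin state (assumptions, not from the diff).
pp_size = 4
num_microbatches = 8

zbv_schedule = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1,   # relative time costs; the removed code treated all passes as equal
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()
# The plugin then passed zbv_schedule to ZeroBubbleVPipeScheduler,
# as shown in the lines removed above.
```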
@@ -14,7 +14,16 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)


 class MlpModel(nn.Module):
@@ -679,6 +688,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):


 # TODO:4) support Hybrid base 3)
+def run_with_hybridplugin(test_config):
+    pass
+
+
+# TODO:5) support MoEHybrid base 3)
 @parameterize(
     "test_config",
     [
@@ -693,20 +707,55 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         },
     ],
 )
-def run_with_hybridplugin(test_config):
-    pass
-
-
-# TODO:5) support MoEHybrid base 3)
-def run_with_moehybridplugin(
-    rank: int,
-    world_size: int,
-    port: int,
-    num_microbatch: int,
-    batch_size: int,
-    num_model_chunk: int,
-):
-    pass
+def run_with_moehybridplugin(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    test_config["use_lazy_init"] = False
+    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
+    test_config["initial_scale"] = 2**16  # avoid overflow
+    model_list = [
+        "transformers_bert",
+    ]
+    clear_layout_converter()
+    torch.set_default_dtype(torch.bfloat16)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            (
+                org_model,
+                org_optimizer,
+                sharded_model,
+                sharded_optimizer,
+                criterion,
+                booster,
+            ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
+
+            org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+                org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+            )
+
+            stage_manager = booster.plugin.stage_manager
+            tp_group = booster.plugin.tp_group
+
+            bert = unwrap_model(org_model, "BertModel", "bert")
+            sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
+            weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
+
+            org_optimizer.step()
+            sharded_optimizer.step()
+
+            # check weights
+            if test_config["precision"] == "bf16":
+                atol, rtol = 5e-4, 5e-4
+            else:
+                atol, rtol = 5e-4, 5e-4
+            if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+                check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
+            # check optim states
+            # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+    print(f"Bert Model Zoo Test Passed")


 # TODO:6) support booster & Hybrid base 4)
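(Not part of the commit.) Net effect of the plugin-side hunks: "zbv" is no longer an accepted pp_style, so pipeline parallelism through HybridParallelPlugin is limited to "1f1b" and "interleaved". A minimal usage sketch after this change; everything beyond the parameter names visible in the diff (the launch call, the tp_size/pp_size/precision keywords, the Booster workflow) is assumed from typical ColossalAI usage rather than taken from this commit:

```python
# Hypothetical post-commit configuration; run under torchrun so that
# distributed state is available before the plugin builds its process groups.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumed entry point

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="1f1b",       # "interleaved" is the only other accepted style; "zbv" now fails the assert
    num_microbatches=4,    # or microbatch_size=..., per the assert in the diff
    num_model_chunks=1,    # must be 1 for "1f1b"
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)
```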