|
|
|
@@ -14,7 +14,6 @@ from colossalai.logging import disable_existing_loggers
|
|
|
|
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode |
|
|
|
|
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler |
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager |
|
|
|
|
from colossalai.tensor.d_tensor.api import clear_layout_converter |
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
|
|
|
from tests.kit.model_zoo import model_zoo |
|
|
|
|
|
|
|
|
@@ -701,56 +700,13 @@ def run_with_hybridplugin(test_config):
|
|
|
|
|
], |
|
|
|
|
) |
|
|
|
|
def run_with_moehybridplugin(test_config):
    """Smoke-test the MoE hybrid plugin config against the bert model zoo.

    Currently this only exercises data generation for every registered
    ``transformers_bert`` model variant; the full build / forward-backward /
    weight-comparison pipeline is scaffolded in the commented block below
    and tracked by the TODO at the end of this function.

    Args:
        test_config (dict): parameterized plugin configuration. Mutated in
            place: ``use_lazy_init``, ``pp_size`` and ``initial_scale`` are
            forced to values this test requires.
    """
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    # Force settings required by this test regardless of the incoming config.
    test_config["use_lazy_init"] = False
    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
    test_config["initial_scale"] = 2**16  # avoid overflow
    # Models for which the full (currently commented-out) check would run.
    model_list = [
        "transformers_bert",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
    try:
        for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
            # For now only verify that data generation runs without error.
            data_gen_fn()
            # print(f"data {data}")
            # if name in model_list:
            #     (
            #         org_model,
            #         org_optimizer,
            #         sharded_model,
            #         sharded_optimizer,
            #         criterion,
            #         booster,
            #     ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)

            #     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
            #         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
            #     )

            #     stage_manager = booster.plugin.stage_manager
            #     tp_group = booster.plugin.tp_group

            #     bert = unwrap_model(org_model, "BertModel", "bert")
            #     sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
            #     weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]

            #     org_optimizer.step()
            #     sharded_optimizer.step()

            #     # check weights
            #     if test_config["precision"] == "bf16":
            #         atol, rtol = 5e-4, 5e-4
            #     else:
            #         atol, rtol = 5e-4, 5e-4
            #     if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
            #         check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
            #     # check optim states
            #     # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)

            #     clear_layout_converter()
            #     Randomizer.reset_index()
            #     torch.cuda.empty_cache()
            #     print(f"Bert Model Zoo Test Passed")
    finally:
        # Restore the process-wide default dtype so later tests are not
        # silently run in bfloat16.
        torch.set_default_dtype(torch.float32)


# TODO: 6) support booster & Hybrid, based on 4)
|
|
|
|