|
|
|
@ -14,16 +14,9 @@ from colossalai.logging import disable_existing_loggers
|
|
|
|
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
|
|
|
|
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.shardformer.layer.utils import Randomizer
|
|
|
|
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|
|
|
|
from tests.kit.model_zoo import model_zoo
|
|
|
|
|
from tests.test_shardformer.test_model._utils import (
|
|
|
|
|
build_model_from_hybrid_plugin,
|
|
|
|
|
check_weight,
|
|
|
|
|
run_forward_backward_with_hybrid_plugin,
|
|
|
|
|
unwrap_model,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MlpModel(nn.Module):
|
|
|
|
@ -437,7 +430,7 @@ def run_fwd_bwd_iter_input(test_config):
|
|
|
|
|
local_chunk.append(sub_model)
|
|
|
|
|
else:
|
|
|
|
|
# layer 3 & 4 to chunk 3 on rank3
|
|
|
|
|
local_chunk = torch.nn.Sequential().to(rank)
|
|
|
|
|
local_chunk = torch.nn.ModuleList().to(rank)
|
|
|
|
|
for idx, sub_model in enumerate(model.layers):
|
|
|
|
|
if idx == 3 or idx == 4:
|
|
|
|
|
local_chunk.append(sub_model)
|
|
|
|
@ -594,7 +587,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|
|
|
|
local_chunk.append(sub_model)
|
|
|
|
|
else:
|
|
|
|
|
# layer 3 & 4 to chunk 3 on rank3
|
|
|
|
|
local_chunk = torch.nn.Sequential().to(rank)
|
|
|
|
|
local_chunk = torch.nn.ModuleList().to(rank)
|
|
|
|
|
for idx, sub_model in enumerate(model.layers):
|
|
|
|
|
if idx == 3 or idx == 4:
|
|
|
|
|
local_chunk.append(sub_model)
|
|
|
|
@ -718,44 +711,46 @@ def run_with_moehybridplugin(test_config):
|
|
|
|
|
clear_layout_converter()
|
|
|
|
|
torch.set_default_dtype(torch.bfloat16)
|
|
|
|
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
|
|
|
|
if name in model_list:
|
|
|
|
|
(
|
|
|
|
|
org_model,
|
|
|
|
|
org_optimizer,
|
|
|
|
|
sharded_model,
|
|
|
|
|
sharded_optimizer,
|
|
|
|
|
criterion,
|
|
|
|
|
booster,
|
|
|
|
|
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
|
|
|
|
|
|
|
|
|
|
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
|
|
|
|
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
stage_manager = booster.plugin.stage_manager
|
|
|
|
|
tp_group = booster.plugin.tp_group
|
|
|
|
|
|
|
|
|
|
bert = unwrap_model(org_model, "BertModel", "bert")
|
|
|
|
|
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
|
|
|
|
|
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
|
|
|
|
|
|
|
|
|
|
org_optimizer.step()
|
|
|
|
|
sharded_optimizer.step()
|
|
|
|
|
|
|
|
|
|
# check weights
|
|
|
|
|
if test_config["precision"] == "bf16":
|
|
|
|
|
atol, rtol = 5e-4, 5e-4
|
|
|
|
|
else:
|
|
|
|
|
atol, rtol = 5e-4, 5e-4
|
|
|
|
|
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
|
|
|
|
# check optim states
|
|
|
|
|
# check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
|
|
|
|
|
|
|
|
|
|
clear_layout_converter()
|
|
|
|
|
Randomizer.reset_index()
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
print(f"Bert Model Zoo Test Passed")
|
|
|
|
|
data_gen_fn()
|
|
|
|
|
# print(f"data {data}")
|
|
|
|
|
# if name in model_list:
|
|
|
|
|
# (
|
|
|
|
|
# org_model,
|
|
|
|
|
# org_optimizer,
|
|
|
|
|
# sharded_model,
|
|
|
|
|
# sharded_optimizer,
|
|
|
|
|
# criterion,
|
|
|
|
|
# booster,
|
|
|
|
|
# ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
|
|
|
|
|
|
|
|
|
|
# org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
|
|
|
|
# org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
# stage_manager = booster.plugin.stage_manager
|
|
|
|
|
# tp_group = booster.plugin.tp_group
|
|
|
|
|
|
|
|
|
|
# bert = unwrap_model(org_model, "BertModel", "bert")
|
|
|
|
|
# sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
|
|
|
|
|
# weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
|
|
|
|
|
|
|
|
|
|
# org_optimizer.step()
|
|
|
|
|
# sharded_optimizer.step()
|
|
|
|
|
|
|
|
|
|
# # check weights
|
|
|
|
|
# if test_config["precision"] == "bf16":
|
|
|
|
|
# atol, rtol = 5e-4, 5e-4
|
|
|
|
|
# else:
|
|
|
|
|
# atol, rtol = 5e-4, 5e-4
|
|
|
|
|
# if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
|
|
|
|
# # check optim states
|
|
|
|
|
# # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
|
|
|
|
|
|
|
|
|
|
# clear_layout_converter()
|
|
|
|
|
# Randomizer.reset_index()
|
|
|
|
|
# torch.cuda.empty_cache()
|
|
|
|
|
# print(f"Bert Model Zoo Test Passed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO:6) support booster & Hybrid base 4)
|
|
|
|
@ -766,8 +761,9 @@ def run_with_moehybridplugin(test_config):
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Initialise the distributed environment, then run the schedule test drivers.

    Args:
        rank: Rank passed through to ``colossalai.launch``.
        world_size: World size passed through to ``colossalai.launch``.
        port: Port passed through to ``colossalai.launch``; the host is
            always "localhost" and the backend is always "nccl".

    NOTE(review): presumably called once per spawned worker process --
    confirm at the call site.
    """
    disable_existing_loggers()

    # Collected in one place so the launch configuration reads as a unit.
    launch_options = {
        "rank": rank,
        "world_size": world_size,
        "host": "localhost",
        "port": port,
        "backend": "nccl",
    }
    colossalai.launch(**launch_options)

    # The test drivers are invoked without arguments, one after the other.
    run_fwd_bwd_iter_input()
    run_fwd_bwd_vschedule_with_optim()
    # run_with_moehybridplugin() is currently left disabled.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
|