@@ -3,22 +3,21 @@ from copy import deepcopy
 import pytest
 import torch
 import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralModel
 
 import colossalai
 from colossalai.booster.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import loose_close
 
-NUM_BATCH=4
+NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 4
 TOP_K = 2
 
 
@@ -35,7 +34,7 @@ def split_grad(grad, world_size):
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
 @parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
     dtype = torch.bfloat16
 
     rank = torch.distributed.get_rank()
@@ -56,19 +55,14 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
 
     zero_model = deepcopy(torch_model).to(dtype)
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
-    booster = Booster(
+    moe_booster = Booster(
         plugin=MoeHybridParallelPlugin(
-            tp_size=tp_size,
-            pp_size=1,
-            ep_size=ep_size,
-            zero_stage=stage,
-            overlap_communication=False,
-            initial_scale=1
+            tp_size=tp_size, pp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
         )
     )
-    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+    zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
 
-    booster = Booster(
+    hybrid_booster = Booster(
         plugin=HybridParallelPlugin(
             tp_size=tp_size,
             pp_size=1,
@@ -77,8 +71,9 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
             initial_scale=1,
         )
     )
-    hybrid_model, hybrid_optimizer, _, _, _ = booster.boost(torch_model, torch.optim.SGD(torch_model.parameters(), lr=1))
-
+    hybrid_model, hybrid_optimizer, _, _, _ = hybrid_booster.boost(
+        torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)
+    )
     # create different input
     seed_all(1453 + rank)
 
@@ -86,7 +81,9 @@ def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
     zero_model.train()
     for _ in range(2):
         # zero-dp forward
-        input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
         zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
         # zero-dp backward
         zero_optimizer.backward(zero_output)