mirror of https://github.com/hpcaitech/ColossalAI
[moe] full test for deepseek and mixtral (pp + sp to fix)
parent 162e2d935c
commit 783aafa327
@@ -1122,6 +1122,10 @@ class HybridParallelPlugin(PipelinePluginBase):
        else:
            self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

        self.logger.info(
            f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
            ranks=[0, 1, 2, 3, 4, 5, 6, 7],
        )
        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            sequence_parallel_process_group=self.sp_group,
@@ -147,9 +147,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        world_size = dist.get_world_size()

        if self.enable_sequence_parallelism:
            # if sequence parallelism is enabled, we reuse the same group for ep and sp
            if self.sequence_parallelism_mode == "all_to_all":
                # when sequence parallelism is enabled, ep_group reuses sp_group
                # if sequence parallelism is enabled, ep_group reuses sp_group
                if self.ep_size != self.sp_size:
                    raise ValueError(
                        f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
@@ -157,8 +156,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

                # since we are reusing sp_group, moe_dp_group will be derived as dp_group
                self.moe_dp_size = self.dp_size
                self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
                self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
                self.moe_dp_group = self.dp_group
                self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
                self.ep_group = self.sp_group
                self.moe_tp_group = self.tp_group
            else:
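Editor's note: the constraint above (ep_size must equal sp_size so that ep_group can simply reuse sp_group) can be illustrated with a small sketch built directly on torch.distributed rather than on the plugin's pg_mesh helper. The "consecutive ranks form one sp group" layout below is an assumption for the example, not something this commit prescribes.

# Illustrative sketch: when sequence parallelism uses "all_to_all", the expert-parallel
# group can reuse the sequence-parallel communicator, provided ep_size == sp_size.
import torch.distributed as dist

def build_reused_groups(ep_size: int, sp_size: int):
    if ep_size != sp_size:
        raise ValueError(f"ep_size={ep_size} should equal sp_size={sp_size} when reusing the group")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    sp_group = None
    # assumed layout: consecutive ranks form one sp group, e.g. [0..3] and [4..7] for sp_size=4
    for start in range(0, world_size, sp_size):
        ranks = list(range(start, start + sp_size))
        group = dist.new_group(ranks)  # every rank must take part in every new_group call
        if rank in ranks:
            sp_group = group
    ep_group = sp_group  # ep traffic runs over the same communicator as sp
    return sp_group, ep_group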
@@ -177,6 +176,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            self.moe_dp_group = None
            self.ep_group = None
            self.moe_tp_group = None
            self.dp_sp_group = self.dp_group

            # create submesh for ep, moe_dp, moe_tp
            ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
@@ -225,8 +225,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        )

        self.logger.info(
            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
            ranks=[0],
        )

@@ -254,7 +254,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                module=model,
                precision=self.precision,
                shard_config=self.shard_config,
                dp_group=self.dp_group,
                dp_group=self.dp_sp_group,
                tp_group=self.tp_group,
                sp_group=self.sp_group,
                use_ddp=self.use_ddp,
@@ -302,7 +302,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                use_pipeline=self.enable_pipeline_parallelism,
                force_overlap_comm=self.force_overlap_comm,
                param_info=param_info,
                dp_process_group=self.dp_group,
                dp_process_group=self.dp_sp_group,
                tp_process_group=self.tp_group,
                pp_process_group=self.pp_group,
                moe_dp_group=self.moe_dp_group,
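Editor's note: both call sites above now receive dp_sp_group, the group created along the combined [dp_axis, sp_axis] of the mesh, so gradient synchronization spans data-parallel and sequence-parallel ranks together. A minimal sketch of such a flattened group, assuming a hypothetical layout of dp_size * sp_size * tp_size ranks with tp fastest-varying:

# Minimal sketch of a flattened dp x sp group (rank layout is an assumption for the example).
import torch.distributed as dist

def build_dp_sp_group(dp_size: int, sp_size: int, tp_size: int):
    rank = dist.get_rank()
    dp_sp_group = None
    for tp_rank in range(tp_size):
        # all ranks that share this tp coordinate, across every (dp, sp) coordinate
        ranks = [tp_rank + tp_size * i for i in range(dp_size * sp_size)]
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_sp_group = group
    return dp_sp_group  # ZeRO / DDP then averages gradients over dp * sp ranks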
@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size

@@ -9,7 +9,7 @@ import torch.nn.functional as F

from colossalai.legacy.moe.load_balance import LoadBalancer
from colossalai.legacy.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.moe.operators import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size

@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_activation
from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut
from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size

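Editor's note: the three hunks above only repoint imports from colossalai.moe.operators to colossalai.moe._operation. For out-of-tree code that has to run against either layout, a hedged compatibility shim could look like the following (only the two module paths shown above are assumed to exist):

# Compatibility import sketch: prefer the new module path, fall back to the old one.
try:
    from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
except ImportError:  # older layout
    from colossalai.moe.operators import EPGradScalerIn, EPGradScalerOut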
@@ -10,7 +10,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.shard import ShardConfig
@@ -23,7 +23,13 @@ from transformers.models.mixtral.modeling_mixtral import (
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe.operators import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
@@ -245,6 +251,7 @@ class MixtralPipelineForwards:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            print("input_ids", input_ids.shape)
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
@@ -372,16 +379,29 @@ class MixtralPipelineForwards:
        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
        if stage_manager.is_last_stage():
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            if not return_dict:
                return tuple(
                    v
                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                    if v is not None
                )
            return MoeModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
                router_logits=all_router_logits,
            )
        # always return dict for intermediate stage
        return {
            "hidden_states": hidden_states,
            "past_router_logits": all_router_logits,
        }
        else:
            if output_router_logits:
                return {
                    "hidden_states": hidden_states,
                    "past_router_logits": all_router_logits,
                }
            else:
                return {
                    "hidden_states": hidden_states,
                }

    @staticmethod
    def mixtral_for_causal_lm_forward(
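Editor's note: the rewritten branch makes the last pipeline stage honor return_dict (tuple or MoeModelOutputWithPast) while intermediate stages always hand a plain dict to the next stage. A stripped-down sketch of that pattern, with hypothetical names:

# Minimal sketch of the stage-dependent return pattern (names are hypothetical).
def stage_forward(stage_is_last: bool, return_dict: bool, hidden_states, router_logits):
    if stage_is_last:
        if not return_dict:
            return (hidden_states, router_logits)
        return {"last_hidden_state": hidden_states, "router_logits": router_logits}
    # intermediate stages always pass a dict of tensors to the next stage
    return {"hidden_states": hidden_states, "past_router_logits": router_logits}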
@@ -34,7 +34,10 @@ class DeepseekPolicy(Policy):
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            if self.pipeline_stage_manager is not None:
                # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
                # if both are enabled, one of them will be ignored
                raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
            raise NotImplementedError(
                "Deepseek doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )
@@ -136,6 +139,10 @@ class DeepseekPolicy(Policy):
        """If under pipeline parallel setting, replace the original forward method of huggingface
        with a customized forward method, and add this change to the policy."""
        if self.pipeline_stage_manager:
            if self.shard_config.enable_sequence_parallelism:
                # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
                # if both are enabled, one of them will be ignored
                raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "DeepseekModel":
                module = self.model
@@ -62,6 +62,10 @@ class MixtralPolicy(Policy):
                attribute_replacement=decoder_attribute_replacement,
            )
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            if self.pipeline_stage_manager is not None:
                # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
                # if both are enabled, one of them will be ignored
                raise NotImplementedError("Sequence parallelism is not supported with pipeline parallelism.")
            self.append_or_create_method_replacement(
                description={
                    "forward": get_mixtral_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
@@ -69,19 +73,18 @@ class MixtralPolicy(Policy):
                policy=policy,
                target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_mixtral_flash_attention_model_forward(
                            self.shard_config,
                            sp_mode=sp_mode,
                            sp_size=sp_size,
                            sp_group=sp_group,
                        ),
                    },
                    policy=policy,
                    target_key=MixtralModel,
                )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_mixtral_flash_attention_model_forward(
                        self.shard_config,
                        sp_mode=sp_mode,
                        sp_size=sp_size,
                        sp_group=sp_group,
                    ),
                },
                policy=policy,
                target_key=MixtralModel,
            )

        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
@@ -202,6 +205,10 @@ class MixtralPolicy(Policy):
        """If under pipeline parallel setting, replace the original forward method of huggingface
        with a customized forward method, and add this change to the policy."""
        if self.pipeline_stage_manager:
            if self.shard_config.enable_sequence_parallelism:
                # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
                # if both are enabled, one of them will be ignored
                raise NotImplementedError("Pipeline parallelism is not supported with sequence parallelism.")
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "MixtralModel":
                module = self.model
@@ -100,7 +100,7 @@ class BucketStore(BaseStore):

        return self._grad_in_bucket

    def get_flatten_grad(self, dtype=None) -> Tensor:
    def get_flatten_grad(self) -> Tensor:
        """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
        [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]

@@ -303,7 +303,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        for bucket_store in self.pg_to_bucket_store.values():
            bucket_store.build_grad_in_bucket()

            flat_grads = bucket_store.get_flatten_grad(self._dtype)
            flat_grads = bucket_store.get_flatten_grad()
            flat_grads /= bucket_store.world_size

            # ready to add other tensors to bucket
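Editor's note: with the dtype argument gone, get_flatten_grad simply flattens whatever the bucket holds and the caller averages by the bucket's world size. A hedged sketch of that flatten-then-average step, independent of the BucketStore internals:

# Flatten a list of gradients into one contiguous tensor and average across ranks.
from typing import List

import torch
from torch._utils import _flatten_dense_tensors

def flatten_and_average(grads: List[torch.Tensor], world_size: int) -> torch.Tensor:
    flat = _flatten_dense_tensors(grads)  # keeps whatever dtype autograd produced
    flat /= world_size  # mirrors `flat_grads /= bucket_store.world_size` above
    return flat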
@@ -1,133 +0,0 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, tp_size = config
    dtype = torch.float16

    rank = torch.distributed.get_rank()
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=tp_size,
        moe_tp_size=tp_size,
        ep_size=ep_size,
        zero_stage=stage,
        overlap_communication=False,
        initial_scale=1,
        precision="fp32",
    )
    booster = Booster(plugin=plugin)

    seed_all(10086)

    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
    config.num_hidden_layers = 2
    config.num_attention_heads = NUM_HEADS
    config.num_key_value_heads = NUM_HEADS
    config.n_routed_experts = NUM_EXPERTS
    config.num_experts_per_tok = TOP_K
    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    zero_model = deepcopy(torch_model).to(dtype)
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

    # create different input
    seed_all(1453 + rank)

    torch_model.train()
    zero_model.train()
    for _ in range(2):
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input

        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        zero_optimizer.backward(zero_output)
        zero_optimizer.step()
        zero_optimizer.zero_grad()
        dist.all_reduce(zero_output)

        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
        dist.all_gather(all_inputs, input_data)

        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dist.get_world_size()
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(zero_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_deepseek"
    if dist.get_rank() == 0:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()

    booster.save_model(zero_model, model_dir, shard=True)

    dist.barrier()

    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)

    dist.barrier()
    if dist.get_rank() == 0:
        shutil.rmtree(model_dir)

    print(f"{dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mistral(world_size=4)
@@ -1,143 +0,0 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
    print(config)
    rank = torch.distributed.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    booster = Booster(plugin=plugin)

    seed_all(10086)

    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=2,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )

    torch_model = MixtralModel(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    zero_model = deepcopy(torch_model).to(dtype)
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

    # create different input
    seed_all(1453 + rank)

    torch_model.train()
    zero_model.train()
    for _ in range(2):
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()

        dist.all_reduce(input_data, group=plugin.tp_group)  # tp group requires duplicate input
        dist.all_reduce(input_data, group=plugin.sp_group)  # sp group requires duplicate input

        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        zero_optimizer.backward(zero_output)
        zero_optimizer.step()
        zero_optimizer.zero_grad()
        dist.all_reduce(zero_output)

        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
        dist.all_gather(all_inputs, input_data)

        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dist.get_world_size()
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(zero_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_mixtral"
    if dist.get_rank() == 0:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()

    booster.save_model(zero_model, model_dir, shard=True)

    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)

    dist.barrier()
    if dist.get_rank() == 0:
        shutil.rmtree(model_dir)

    print(f"{dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mistral(world_size=8)
@@ -4,7 +4,7 @@ import pytest
import torch

from colossalai.accelerator import get_accelerator
from colossalai.moe.operators import MoeCombine, MoeDispatch, moe_cumsum
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum

NUM_EXPERTS = 4
BATCH_SIZE = 4
@@ -0,0 +1,186 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


# TODO only need to keep one or two cases
@parameterize(
    "config",
    [
        (2, 1, 1, 4, 1),
        # (2, 1, 2, 1, 1),  # TODO debug deepseek pp
        # (2, 1, 2, 2, 1),  # TODO debug deepseek pp
        (2, 1, 1, 2, 1),
        # (2, 1, 1, 1, 2),  # TODO support deepseek sp
        # (2, 1, 4, 1, 1),  # TODO debug deepseek pp
        (4, 1, 1, 1, 1),
        (4, 1, 1, 2, 1),
        # (4, 1, 2, 1, 1),  # TODO debug deepseek pp
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
    ep_size, stage, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    print(config)
    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    dp_size = plugin.dp_size

    booster = Booster(plugin=plugin)

    # init model with the same seed
    seed_all(10086)

    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
    config.num_hidden_layers = 2
    config.num_attention_heads = NUM_HEADS
    config.num_key_value_heads = NUM_HEADS
    config.n_routed_experts = NUM_EXPERTS
    config.num_experts_per_tok = TOP_K

    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)

    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # gen random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x[0].mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )
            if booster.plugin.stage_manager.is_last_stage():
                parallel_output = sharded_output["loss"]
            else:
                parallel_output = torch.tensor(12345.0, device="cuda")

            # broadcast along pp axis
            dist.broadcast(
                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
            )
        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run normal model with all dp(different) inputs
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads follows zero optimizer
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(parallel_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_mixtral"
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()
    booster.save_model(parallel_model, model_dir, shard=True)
    dist.barrier()

    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if rank == world_size - 1:
        shutil.rmtree(model_dir)

    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mistral(world_size=8)
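Editor's note: the test compares the hybrid-parallel loss against the sum of per-rank single-GPU losses via loose_close from tests.test_moe.moe_utils. Its implementation is not shown in this commit; a plausible dtype-aware sketch (an assumption, not the repo's helper) would be:

# Plausible dtype-aware comparison (assumption; the real helper is
# tests.test_moe.moe_utils.loose_close and may differ).
import torch

def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    tol = {torch.float16: 1e-3, torch.bfloat16: 4e-3}.get(dtype, 1e-5)
    torch.testing.assert_close(a.float(), b.float(), atol=tol, rtol=tol)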
@@ -1,229 +1,188 @@
# modified from test_shard_mistral.py
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    # TODO: SGD failed for full dp
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
    )

    org_model = org_model.to(torch.float16)
    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )
    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
        check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)

    # unwrap model
    mixtral_model = unwrap_model(org_model, "MixtralModel", "model")
    shard_mixtral_model = unwrap_model(sharded_model, "MixtralModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]

    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        # booster.plugin.zero_stage in [1, 2]
        booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        rank = dist.get_rank()
        name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
        for n, p in shard_mixtral_model.named_parameters():
            zero_grad = sharded_optimizer.get_param_grad(p)
            if name_to_p[n].grad is None:
                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
                continue
            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 5e-5, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            mixtral_model,
            shard_mixtral_model,
            row_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=0,
            verbose=False,
        )
        col_layer_grads = get_grad_tensors_for_check(
            mixtral_model,
            shard_mixtral_model,
            col_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)

    # check grads
    check_all_grad_tensors(grads_to_check)

    for n, p in shard_mixtral_model.named_parameters():
        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    for n, p in shard_mixtral_model.named_parameters():
        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 2e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        try:
            check_weight(
                mixtral_model,
                shard_mixtral_model,
                col_layer_for_check,
                tp_group,
                atol=atol,
                rtol=rtol,
                dim=1,
                verbose=False,
            )
        except Exception as e:
            rank = dist.get_rank()
            print(f"{rank=}, Failed config: {test_config}")
            raise e

    torch.cuda.empty_cache()
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


# TODO only need to keep one or two cases
@parameterize(
    "test_config",
    "config",
    [
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 0,
        #     "overlap_communication": False,
        #     "precision": "fp16",
        # },  # [dp(4)] + [moe_dp(4)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(2) + pp(2)] + [moe_pp(2)]
        # {
        #     "tp_size": 2,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
        {  # Ulysses + Flash attention
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "ep_size": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "zero_stage": 1,
            "overlap_communication": False,
            "precision": "fp16",
            "initial_scale": 1,
            "find_unused_parameters": True,
        },
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 2,
        #     "zero_stage": 0,
        #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(4)] + [ep(2) + moe_tp(2)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 4,
        #     "overlap_communication": False,
        #     "zero_stage": 0,
        #     "precision": "fp32"
        # },  # full dp for non-moe and full ep for moe
        (2, 1, 1, 4, 1),
        (2, 1, 2, 1, 1),
        (2, 1, 2, 2, 1),
        (2, 1, 1, 2, 1),
        (2, 1, 1, 1, 2),
        (2, 1, 4, 1, 1),
        (4, 1, 1, 1, 1),
        (4, 1, 1, 2, 1),
        (4, 1, 2, 1, 1),
    ],
)
def run_mixtral_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_mixtral")
def run_zero_with_original_model(config: Tuple[int, ...]):
    ep_size, stage, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    dp_size = plugin.dp_size

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    booster = Booster(plugin=plugin)

    # init model with the same seed
    seed_all(10086)

    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )

    torch_model = MixtralModel(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)

    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # gen random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x.last_hidden_state.mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )
            if booster.plugin.stage_manager.is_last_stage():
                parallel_output = sharded_output["loss"]
            else:
                parallel_output = torch.tensor(12345.0, device="cuda")

            # broadcast along pp axis
            dist.broadcast(
                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
            )
        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run normal model with all dp(different) inputs
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads follows zero optimizer
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(parallel_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_mixtral"
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()
    booster.save_model(parallel_model, model_dir, shard=True)
    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if rank == world_size - 1:
        shutil.rmtree(model_dir)

    print(f"rank {dist.get_rank()} test passed")


def check_mixtral(rank, world_size, port):
    disable_existing_loggers()
def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_mixtral_test()
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mixtral():
    spawn(check_mixtral, 4)
def test_mistral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral()
    test_mistral(world_size=8)