mirror of https://github.com/hpcaitech/ColossalAI
[chore] solve moe ckpt test failure and some other arg pass failure
parent 9f9e268265
commit 05a78d2f41
@@ -446,7 +446,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                     warnings.warn(
-                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
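The one-character fix above is easy to miss: without the `f` prefix, Python emits the braces literally instead of interpolating the variables. A minimal sketch (the variable values are illustrative, not from the plugin):

import warnings

origin_key, name = "experts.0.w1", "module.experts.0.w1"

# Missing f-prefix: the placeholder text is emitted verbatim.
warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist.")
# UserWarning: Origin parameter {origin_key} related to {name} doesn't exist.

# With the f-prefix the variables are interpolated as intended.
warnings.warn(f"Origin parameter {origin_key} related to {name} doesn't exist.")
# UserWarning: Origin parameter experts.0.w1 related to module.experts.0.w1 doesn't exist.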
@@ -69,8 +69,6 @@ class EPDeepseekMoE(nn.Module):
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

         set_tensors_to_none(self.experts, exclude=set(held_experts))
-        for p in self.experts.parameters():
-            set_moe_tensor_ep_group(p, ep_group)

         # setup moe_dp group
         self.moe_dp_group = moe_dp_group
@@ -87,6 +85,9 @@ class EPDeepseekMoE(nn.Module):
                 expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
                 expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)

+        for p in self.experts.parameters():
+            set_moe_tensor_ep_group(p, ep_group)
+
     @staticmethod
     def from_native_module(
         module,
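These two hunks (and the matching Mixtral hunks below) move the EP-group tagging from before to after the tensor-parallel wrapping of each expert's linear layers. The order matters because wrapping replaces the parameter objects, so tags attached beforehand would be lost. A toy sketch of the failure mode in plain PyTorch, where `tag` and `tp_wrap` are stand-ins for `set_moe_tensor_ep_group` and `Linear1D_Col/Row.from_native_module`:

import torch
import torch.nn as nn

def tag(p: nn.Parameter) -> None:
    p.ep_group = "ep"  # stand-in for set_moe_tensor_ep_group(p, ep_group)

def tp_wrap(linear: nn.Linear) -> nn.Linear:
    # Stand-in for Linear1D_Col/Row.from_native_module: returns a module
    # holding brand-new parameter objects.
    return nn.Linear(linear.in_features, linear.out_features)

expert = nn.Linear(4, 4)
tag(expert.weight)             # tag before wrapping ...
expert = tp_wrap(expert)
print(hasattr(expert.weight, "ep_group"))  # False: the tag was lost

expert = tp_wrap(nn.Linear(4, 4))
tag(expert.weight)             # tag after wrapping, as the diff now does
print(hasattr(expert.weight, "ep_group"))  # True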
@@ -74,8 +74,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

         set_tensors_to_none(self.experts, exclude=set(held_experts))
-        for p in self.experts.parameters():
-            set_moe_tensor_ep_group(p, ep_group)

         # setup moe_dp group
         self.moe_dp_group = moe_dp_group
@@ -92,6 +90,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
                 expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
                 expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)

+        for p in self.experts.parameters():
+            set_moe_tensor_ep_group(p, ep_group)
+
     @staticmethod
     def from_native_module(
         module: MixtralSparseMoeBlock,
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.tensor.moe_tensor.api import is_moe_tensor

 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -66,7 +67,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optimizer: Optimizer,
-        pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
+        pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
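The annotation change makes the declared type match the `None` default: a parameter that defaults to `None` is not a plain `Dict[...]`, and strict type checkers flag the old form. A minimal sketch with simplified types (not the optimizer's real signature):

from typing import Dict, List, Optional

# Old form `pg: Dict[str, List[int]] = None` claims a Dict but holds None;
# Optional[...] admits both, and the None case is handled explicitly.
def build_mapping(pg_to_param_list: Optional[Dict[str, List[int]]] = None) -> Dict[str, List[int]]:
    if pg_to_param_list is None:
        pg_to_param_list = {}
    return pg_to_param_list

print(build_mapping())             # {}
print(build_mapping({"dp": [0]}))  # {'dp': [0]}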
@@ -92,7 +93,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._logger = get_dist_logger()
         self._verbose = verbose

-        if dp_process_group is not None and pg_to_param_list is not None:
+        if (dp_process_group is not None) and (pg_to_param_list is not None):
             raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")

         if pg_to_param_list is None:
@@ -301,6 +302,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

     def _run_reduction(self):
         for bucket_store in self.pg_to_bucket_store.values():
+            if bucket_store.num_elements_in_bucket() <= 0:
+                continue
+
             bucket_store.build_grad_in_bucket()

             flat_grads = bucket_store.get_flatten_grad()
@@ -350,8 +354,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
     ) -> None:
         for rank, grad_list in enumerate(origin_grad_list):
-            if len(grad_list) == 0:
-                continue
             sync_tensor(flat_grad_list[rank], grad_list)
             for grad in grad_list:
                 param_id = bucket_store.get_param_id_of_grad(grad)
@@ -648,11 +650,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             param_group = self._working_param_groups[group_id]
             for param in param_group:
-                if param.requires_grad:
-                    if param.grad is None:
-                        # for moe params, all experts should have gradient
-                        # TODO better way of doing this
-                        param.grad = torch.zeros_like(param)
+                if is_moe_tensor(param) and param.requires_grad and param.grad is None:
+                    # TODO better way of doing this
+                    # assign zero grad to unrouted expert to avoid hang during grad reduction
+                    param.grad = torch.zeros_like(param)
+
+                if param.requires_grad and param.grad is not None:
                     self._add_to_bucket(param, group_id)

         self._run_reduction()
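The new branch prevents a hang: with top-k routing, an expert that received no tokens in a step has `param.grad is None`, but gradient reduction is a collective to which every rank must contribute a tensor for every expert parameter, so missing grads are zero-filled before bucketing. A toy single-process sketch, with a stub standing in for the `is_moe_tensor` predicate:

import torch
import torch.nn as nn

def is_moe_param(p: nn.Parameter) -> bool:
    # Stub for colossalai's is_moe_tensor(): here, every param counts as MoE.
    return True

experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
# Only expert 0 is routed any tokens this step, so only it gets gradients.
experts[0](torch.randn(3, 4)).sum().backward()

for p in experts.parameters():
    # Unrouted experts contribute an all-zero gradient instead of None,
    # so the collective reduction over expert grads sees a tensor everywhere.
    if is_moe_param(p) and p.requires_grad and p.grad is None:
        p.grad = torch.zeros_like(p)

assert all(p.grad is not None for p in experts.parameters())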
@@ -1,7 +1,11 @@
 import torch


-def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
+def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
+    assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
     rtol = None
     atol = None
     if dtype is torch.float16:
@@ -12,10 +16,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
         atol = 4e-3
     else:
         assert dtype is torch.float32
-        rtol = 1e-5
-        atol = 1e-5
+        rtol = 1e-05
+        atol = 1e-08

     a = a.detach().to(dtype)
     b = b.detach().to(dtype).to(a.device)

-    assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"
+    return torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def check_model_equal(model1, model2):
+    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+    for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
+        assert_loose_close(p1, p2, p1.dtype)
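After this refactor `loose_close` is a plain boolean predicate and `assert_loose_close` is the failing wrapper, and `check_model_equal` now lives in moe_utils so the test files below import it from one place. A short usage sketch against the new helpers:

import torch
from tests.test_moe.moe_utils import assert_loose_close, loose_close

a = torch.randn(8, dtype=torch.bfloat16)
b = a + 1e-4

if loose_close(a, b, dtype=torch.bfloat16):   # returns a bool, never raises
    print("within bf16 tolerance")

assert_loose_close(a, b, dtype=torch.bfloat16, name="demo")  # raises on mismatch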
@@ -22,6 +22,7 @@ def check_deepseek_moe_layer():
         precision="bf16",
         tp_size=1,
         pp_size=1,
+        zero_stage=1,
         ep_size=dist.get_world_size(),
     )

@@ -42,7 +43,13 @@ def check_deepseek_moe_layer():
     x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
     orig_output = orig_model(x)
     model = deepcopy(orig_model)
-    model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
+    model = EPDeepseekMoE.from_native_module(
+        model,
+        ep_group=plugin.ep_group,
+        moe_dp_group=plugin.moe_dp_group,
+        moe_tp_group=plugin.moe_tp_group,
+        tp_group=plugin.tp_group,
+    )
     ep_output = model(x)
     assert_close(orig_output, ep_output)
     orig_loss = orig_output.mean()
@@ -62,7 +69,7 @@ def run_dist(rank: int, world_size: int, port: int):
     check_deepseek_moe_layer()


-# @pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding shardformer")
 @pytest.mark.parametrize("world_size", [2])
 def test_deepseek_moe_layer(world_size: int):
     spawn(run_dist, world_size)
@@ -23,6 +23,7 @@ def check_mixtral_moe_layer():
         precision="bf16",
         tp_size=1,
         pp_size=1,
+        zero_stage=1,
         ep_size=dist.get_world_size(),
     )
     config = MixtralConfig(
@@ -63,7 +64,8 @@ def run_dist(rank: int, world_size: int, port: int):
     check_mixtral_moe_layer()


-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.skip("tested in corresponding shardformer")
+@pytest.mark.parametrize("world_size", [2])
 def test_mixtral_moe_layer(world_size: int):
     spawn(run_dist, world_size)
@@ -6,7 +6,7 @@ from copy import deepcopy
 import pytest
 import torch
 import torch.distributed as dist
-from torch.optim import Adam
+from torch.optim import SGD, Adam
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

@@ -14,20 +14,15 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, spawn
+from colossalai.testing.random import seed_all
 from colossalai.testing.utils import spawn
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import check_model_equal

 tokens, n_experts = 7, 4
 hidden_size = 8
 top_k = 2


-def check_model_equal(model1, model2):
-    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
-    for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        loose_close(p1, p2, p1.dtype)
-
-
 def get_optimizer_snapshot(optim):
     state = {id(k): deepcopy(v) for k, v in optim.state.items()}
     param_groups = []
@@ -86,34 +81,33 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group
                 num_experts_per_tok=top_k,
                 num_attention_heads=2,
                 num_key_value_heads=2,
+                num_hidden_layers=2,
             ),
             MixtralForCausalLM,
         ],
     ],
 )
 def check_moe_checkpoint(test_config):
+    dtype, precision = torch.float16, "fp16"
+    config, model_cls = test_config
+    torch.cuda.set_device(dist.get_rank())
+
     context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
     with context as f:
-        torch.cuda.set_device(dist.get_rank())
         if dist.get_rank() == 0:
             broadcast_objects = [f]  # any picklable object
         else:
             broadcast_objects = [None]
         dist.broadcast_object_list(broadcast_objects, src=0)

-        config = test_config[0]
-        model_cls = test_config[1]
-        torch.manual_seed(0)
         input_ids = torch.randint(0, 100, (2, tokens)).cuda()
-        orig_model = model_cls(config).cuda()
+        orig_model = model_cls(config).cuda().to(dtype)

+        seed_all(10086)
         model = deepcopy(orig_model)
-        optimizer = Adam(model.parameters(), lr=1e-3)
+        optimizer = SGD(model.parameters(), lr=1e-3)
         plugin = MoeHybridParallelPlugin(
-            pp_size=2,
-            ep_size=2,
-            tp_size=1,
-            microbatch_size=1,
-            zero_stage=1,
+            pp_size=2, ep_size=2, tp_size=1, microbatch_size=1, zero_stage=1, precision=precision
         )
         booster = Booster(plugin=plugin)
         model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
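The restructured test keeps a pattern worth noting: only rank 0 creates the temporary checkpoint directory, and the path is shared via `dist.broadcast_object_list` so every rank writes under one location. A minimal sketch of the pattern, assuming an already-initialized process group (e.g. launched with torchrun):

import tempfile
from contextlib import nullcontext

import torch.distributed as dist

def shared_tmpdir_demo():
    # Rank 0 owns the TemporaryDirectory; other ranks enter a no-op context.
    context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
    with context as f:
        # Broadcast the directory path (any picklable object) from rank 0.
        broadcast_objects = [f] if dist.get_rank() == 0 else [None]
        dist.broadcast_object_list(broadcast_objects, src=0)
        tmpdir = broadcast_objects[0]
        print(f"rank {dist.get_rank()} uses {tmpdir}")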
@@ -135,12 +129,12 @@ def check_moe_checkpoint(test_config):
         booster.save_model(model, model_dir, shard=True)
         dist.barrier()
         if dist.get_rank() == 0:
-            saved_model = model_cls.from_pretrained(model_dir).cuda()
+            saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
             check_model_equal(orig_model, saved_model)
             saved_model.save_pretrained(hf_model_dir)
         dist.barrier()
         # check load model
-        new_model = model_cls(config).cuda()
+        new_model = model_cls(config).cuda().to(dtype)
         new_optimizer = Adam(new_model.parameters(), lr=1e-3)
         new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
         booster.load_model(new_model, hf_model_dir)
@@ -12,7 +12,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import assert_loose_close

 NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
|
@ -22,7 +22,7 @@ TOP_K = 2
|
||||||
|
|
||||||
|
|
||||||
@parameterize("stage", [1])
|
@parameterize("stage", [1])
|
||||||
@parameterize("ep_size", [1, 2, 4])
|
@parameterize("ep_size", [2])
|
||||||
def run_zero_with_original_model(stage: int, ep_size: int):
|
def run_zero_with_original_model(stage: int, ep_size: int):
|
||||||
tp_size = dist.get_world_size() // ep_size
|
tp_size = dist.get_world_size() // ep_size
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
@@ -85,7 +85,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     zero_optimizer.backward(zero_output)
     # torch-ddp forward
     hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
-    loose_close(zero_output, hybrid_output, dtype=dtype)
+    assert_loose_close(zero_output, hybrid_output, dtype=dtype)
     # torch-ddp backward
     hybrid_optimizer.backward(hybrid_output)

@@ -98,7 +98,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
             continue
         if zero_grad.shape != name_to_p[n].grad.shape:  # TODO check sharded and sliced moe
             continue
-        loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+        assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)

     # zero-dp step
     zero_optimizer.step()
@@ -110,7 +110,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     for n, p in zero_model.named_parameters():
         if p.data.shape != name_to_p[n].data.shape:  # TODO check sharded and sliced moe
             continue
-        loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+        assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)

     print(f"{dist.get_rank()} test passed")

@@ -120,6 +120,7 @@ def run_dist(rank, world_size, port):
     run_zero_with_original_model()


+@pytest.mark.skip("tested in corresponding shardformer")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
@@ -12,7 +12,7 @@ from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.moe_utils import assert_loose_close

 NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
@@ -22,7 +22,7 @@ TOP_K = 1


 @parameterize("stage", [1])
-@parameterize("ep_size", [1, 2, 4])
+@parameterize("ep_size", [2, 4])
 def run_zero_with_original_model(stage: int, ep_size: int):
     dtype = torch.bfloat16

@@ -76,7 +76,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):

     # torch-ddp forward
     ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
-    loose_close(zero_output, ddp_output, dtype=dtype)
+    assert_loose_close(zero_output, ddp_output, dtype=dtype)
     # torch-ddp backward
     ddp_output.backward()

@@ -87,7 +87,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
         if name_to_p[n].grad is None:
             name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
             continue
-        loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
+        assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)

     # zero-dp step
     zero_optimizer.step()
@@ -97,7 +97,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):

     # check updated param
     for n, p in zero_model.named_parameters():
-        loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
+        assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)

     print(f"{dist.get_rank()} test passed")

@@ -107,6 +107,7 @@ def run_dist(rank, world_size, port):
     run_zero_with_original_model()


+@pytest.mark.skip("tested in corresponding shardformer")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
@@ -14,8 +14,7 @@ from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,18 +24,21 @@ NUM_HEADS = 4
 TOP_K = 1


-# TODO only need to keep one or two cases
+CHECKED_CONFIG = [  # FOR_WORLD=8
+    (2, 1, 1, 4, 1),
+    (4, 1, 1, 2, 1),
+    (4, 1, 1, 1, 1),
+]
+
+
 @parameterize(
     "config",
     [
-        (2, 1, 1, 4, 1),
         # (2, 1, 2, 1, 1), # TODO debug deepseek pp
         # (2, 1, 2, 2, 1), # TODO debug deepseek pp
         (2, 1, 1, 2, 1),
         # (2, 1, 1, 1, 2), # TODO support deepseek sp
         # (2, 1, 4, 1, 1), # TODO debug deepseek pp
-        (4, 1, 1, 1, 1),
-        (4, 1, 1, 2, 1),
         # (4, 1, 2, 1, 1), # TODO debug deepseek pp
     ],
 )
@@ -66,9 +68,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

     booster = Booster(plugin=plugin)

-    # init model with the same seed
-    seed_all(10086)
-
     assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
     config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
     config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
@@ -79,6 +78,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     config.n_routed_experts = NUM_EXPERTS
     config.num_experts_per_tok = TOP_K

+    # init model with the same seed
+    seed_all(10086)
+
     torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

@@ -148,7 +150,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     torch_optimizer.step()
     torch_optimizer.zero_grad()

-    loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

     # use checkpoint to load sharded zero model
     model_dir = "./test_mixtral"
@@ -175,7 +177,7 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)
@@ -15,8 +15,7 @@ from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,20 +24,21 @@ HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
 TOP_K = 1

-# TODO only need to keep one or two cases
+CHECKED_CONFIG = [  # FOR WORLD=4
+    (2, 1, 2, 2, 1),
+    (2, 1, 1, 2, 1),
+    (2, 1, 4, 1, 1),
+    (4, 1, 1, 1, 1),
+    (4, 1, 1, 2, 1),
+    (4, 1, 2, 1, 1),
+    (2, 1, 2, 1, 1),
+]
+
+
 @parameterize(
     "config",
     [
-        (2, 1, 1, 4, 1),
-        (2, 1, 2, 1, 1),
-        (2, 1, 2, 2, 1),
         (2, 1, 1, 2, 1),
-        (2, 1, 1, 1, 2),
-        (2, 1, 4, 1, 1),
-        (4, 1, 1, 1, 1),
-        (4, 1, 1, 2, 1),
-        (4, 1, 2, 1, 1),
     ],
 )
 def run_zero_with_original_model(config: Tuple[int, ...]):
@@ -67,9 +67,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

     booster = Booster(plugin=plugin)

-    # init model with the same seed
-    seed_all(10086)
-
     assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
     config = MixtralConfig(
         hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
@@ -82,6 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         attn_implementation="flash_attention_2",
     )

+    # init model with the same seed
+    seed_all(10086)
+
     torch_model = MixtralModel(config).to(dtype).cuda()
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

@@ -151,7 +151,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     torch_optimizer.step()
     torch_optimizer.zero_grad()

-    loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

     # use checkpoint to load sharded zero model
     model_dir = "./test_mixtral"
@@ -178,7 +178,7 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)