[zero] solve hang

colossalchat
botbw 2024-07-09 08:14:00 +00:00 committed by Hongxin Liu
parent b5bfeb2efd
commit 13b48ac0aa
8 changed files with 218 additions and 335 deletions

View File

@ -30,6 +30,7 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
force_overlap_comm: bool, # force overlap comm
dp_process_group: ProcessGroup, # dp pg for comm
moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
@ -48,7 +49,16 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
):
):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
if force_overlap_comm:
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@ -88,7 +98,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
TODO: add docstring
"""
def __init__(self, ep_size: int, moe_tp_size: int = 1, *args, **kwargs) -> None:
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
@ -120,6 +130,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
self.force_overlap_comm = force_overlap_comm
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
@ -168,11 +180,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
if not(self.dp_size > 1 or self.moe_dp_size > 1):
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_group,
moe_dp_group=self.moe_dp_group,

View File

@ -110,12 +110,8 @@ class BucketStore(BaseStore):
flat_grad = []
for grad_list in self._grad_in_bucket.values():
if len(grad_list) > 0:
flat_grad.append(_flatten_dense_tensors(grad_list))
if len(flat_grad) > 0:
flat_grad = _flatten_dense_tensors(flat_grad)
else:
flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
flat_grad.append(_flatten_dense_tensors(grad_list))
flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int:

View File

@ -19,7 +19,6 @@ class GradientStore(BaseStore):
"""
self._grads_of_params = dict()
# stage 2
self._partition_grads = partition_grad
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()

View File

@ -648,7 +648,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad and param.grad is not None:
if param.requires_grad:
if param.grad is None:
# for moe params, all experts should have gradient
# TODO better way of doing this
param.grad = torch.zeros_like(param)
self._add_to_bucket(param, group_id)
self._run_reduction()

View File

@ -137,7 +137,7 @@ def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) ->
local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
rtol = None
atol = None
if dtype is torch.float16:
@ -150,4 +150,4 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} not close {a.mean()} {b.mean()}"

View File

@ -1,238 +1,134 @@
import os
import warnings
from typing import Dict
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
# from colossalai.shardformer.layer import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
NUM_BATCH=4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
TOP_K = 2
def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from local model
Args:
tp_model (MoeModule)
local_model (MoeModule)
"""
for (tp_name, tp_param), (local_name, local_param) in zip(
tp_model.named_parameters(), local_model.named_parameters()
):
assert tp_name == local_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, local_param)
assert torch.allclose(tp_param.grad, local_param.grad)
else:
tp_param.data.copy_(local_param.data)
continue
tp_rank = get_ep_rank(tp_param)
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
if assert_grad_flag:
assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
else:
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
def split_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@parameterize("tp_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int=1):
dtype = torch.bfloat16
Args:
tp_model (MoeModule)
ep_model (MoeModule)
"""
for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
assert tp_name == ep_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, ep_param)
assert torch.allclose(tp_param.grad, ep_param.grad)
else:
tp_param.data.copy_(ep_param.data)
continue
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
seed_all(10086)
# get tp param
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
tp_rank = get_ep_rank(tp_param)
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
new_tp_param = all_param[tuple(tp_slice)]
if assert_grad_flag:
new_grad = all_grad[tuple(tp_slice)]
if assert_grad_flag:
assert torch.allclose(tp_param, new_tp_param)
assert torch.allclose(tp_param.grad, new_grad)
else:
tp_param.data.copy_(new_tp_param.data)
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param)
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
assert batch_size % world_size == 0
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
if enable_hierarchical_comm:
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
ep_model = SparseMLP(
num_experts=num_experts,
hidden_size=dim,
intermediate_size=dim * 2,
enable_hierarchical_comm=enable_hierarchical_comm,
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
ep_model = ep_model.to(get_accelerator().get_current_device())
tp_model = tp_model.to(get_accelerator().get_current_device())
local_model = local_model.to(get_accelerator().get_current_device())
torch_model = MixtralModel(config).to(dtype).cuda()
# sync ep param
sync_moe_model_param(ep_model)
dist_dict = MOE_MANAGER.parallel_info_dict
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
ep_grad_handler = MoeGradientHandler(ep_model)
# sync local param
sync_local_from_ep(local_model, ep_model)
# sync tp param
sync_tp_from_ep(tp_model, ep_model)
tp_grad_handler = MoeGradientHandler(tp_model)
rank = dist.get_rank()
input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
micro_batch_size = batch_size // world_size
index = rank * micro_batch_size
# NOTE: ep & tp takes in sharded data for each process
shard_data = input_data.detach()[index : index + micro_batch_size]
out_local = local_model(input_data)
MOE_MANAGER.reset_loss()
out_tp = tp_model(shard_data)
MOE_MANAGER.reset_loss()
out_ep = ep_model(shard_data)
MOE_MANAGER.reset_loss()
assert torch.allclose(
out_tp, out_ep, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
try:
out_local_slice = out_local[index : index + micro_batch_size]
assert torch.allclose(
out_ep, out_local_slice, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
except AssertionError:
"""
e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
However, in ep mode, there are 2 separate routers dealing with sharded data.
Assume router 0 handles token [01] and router 1 handles token [23].
Note that for each router the capacity is only 1 !!!
Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
"""
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
booster = Booster(
plugin=MoeHybridParallelPlugin(
tp_size=tp_size,
pp_size=1,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1
)
)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
out_local.mean().backward()
out_tp.mean().backward()
tp_grad_handler.handle_gradient()
out_ep.mean().backward()
ep_grad_handler.handle_gradient()
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
try:
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
except AssertionError:
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
booster = Booster(
plugin=HybridParallelPlugin(
tp_size=tp_size,
pp_size=1,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
)
)
hybrid_model, hybrid_optimizer, _, _, _ = booster.boost(torch_model, torch.optim.SGD(torch_model.parameters(), lr=1))
# create different input
seed_all(1453 + rank)
hybrid_model.train()
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
# torch-ddp forward
hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
loose_close(zero_output, hybrid_output, dtype=dtype)
# torch-ddp backward
hybrid_optimizer.backward(hybrid_output)
# check grad
name_to_p = {n: p for n, p in hybrid_model.named_parameters()}
for n, p in zero_model.named_parameters():
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n])
continue
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
hybrid_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize(
"config",
[
{"enable_hierarchical_comm": False},
{"enable_hierarchical_comm": True},
],
)
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
def test_moe_ep_tp(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
test_moe_ep_tp(world_size=4)

View File

@ -5,20 +5,20 @@ import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_moe.moe_utils import loose_close
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
NUM_BATCH=4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS=2
TOP_K = 2
def split_grad(grad, world_size):
@ -31,94 +31,87 @@ def split_grad(grad, world_size):
return splited_grad
@parameterize("stage", [1, 2])
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
dtype = torch.float16
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
tp_size=1,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1
)
booster = Booster(plugin=plugin)
seed_all(10086)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
)
orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
torch_model = MixtralModel(config).to(dtype).cuda()
ori_model = DDP(
orig_model.cuda(),
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
ddp_model = DDP(
torch_model.cuda(),
process_group=plugin.dp_group,
find_unused_parameters=True, # important for torch ddp, not all experts are routed
).cuda()
ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)
zero_model = deepcopy(orig_model).to(dtype)
zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []}
for p in zero_model.parameters():
if is_moe_tensor(p):
pg_param_list[plugin.moe_dp_group].append(p)
else:
pg_param_list[plugin.dp_group].append(p)
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
pg_to_param_list=pg_param_list,
master_weights=False,
initial_scale=1,
overlap_communication=True,
partition_grad=stage == 2,
)
ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
# create
# create different input
seed_all(1453 + rank)
ddp_model.train()
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
zero_output, _ = zero_model(input_data.to(dtype))
input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
# torch-ddp forward
ori_output, _ = ori_model(input_data.to(dtype))
loose_close(zero_output, ori_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
loose_close(zero_output, ddp_output, dtype=dtype)
# torch-ddp backward
ori_output.mean().backward()
ddp_output.backward()
# check grad
name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
name_to_p = {n: p for n, p in ddp_model.named_parameters()}
for n, p in zero_model.named_parameters():
print(f"rank {dist.get_rank()} {n}")
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
assert zero_grad is None
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
ori_optimizer.step()
ddp_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype)
loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@ -131,9 +124,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
def test_moe_ep_tp(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_model(world_size=4)
test_moe_ep_tp(world_size=4)

View File

@ -113,65 +113,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"ep_size": 1,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + tp(2)] + [moe_dp(4)]
{
"tp_size": 2,
"pp_size": 1,
"ep_size": 2,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + tp(2)] + [ep(2) + moe_dp(2)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"pp_size": 1,
"ep_size": 1,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_dp(4)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_dp(4)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 4,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [ep(4))]
{
"tp_size": 1,
"pp_size": 1,
"ep_size": 2,
"zero_stage": 2,
"precision": "fp32",
}, # [dp(4)] + [ep(2) + moe_tp(2)]
{
"tp_size": 1,
"pp_size": 1,
"ep_size": 4,
"zero_stage": 2,
"precision": "fp32"
}, # full dp for non-moe and full ep for moe
{
"tp_size": 1,
"pp_size": 1,
"ep_size": 1,
"zero_stage": 2,
"precision": "fp32"
}, # full dp for moe and non-moe
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 1,
# "zero_stage": 1,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [moe_dp(4)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 4,
# "zero_stage": 1,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [ep(4))]
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 2,
# "zero_stage": 0,
# "precision": "fp32",
# }, # [dp(4)] + [ep(2) + moe_tp(2)]
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "zero_stage": 0,
# "precision": "fp32"
# }, # full dp for non-moe and full ep for moe
],
)
def run_mixtral_test(test_config):