update utils and fwd bwd

pull/5190/head
Xuanlei Zhao 2023-12-28 18:54:56 +08:00
parent a5580e6289
commit 0a3aae509b
3 changed files with 73 additions and 77 deletions

View File

@ -149,7 +149,14 @@ class Top1Router(MoeRouter):
low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
).rsample ).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).

View File

@ -2,12 +2,21 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
from tests.test_moe.moe_utils import MoeModel
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
class MoeModel(nn.Module): class MoeModel(nn.Module):
@ -85,6 +94,58 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1): for i in range(world_size - 1):
a = tensor_list[i] a = tensor_list[i]
b = tensor_list[i + 1] b = tensor_list[i + 1]
assert not torch.allclose(a, b), \ assert not torch.allclose(a, b), (
(f"expected tensors on rank {i} and {i + 1} not to be equal " f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
f"but they are, {a} vs {b}") )
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)

View File

@ -1,85 +1,13 @@
import pytest import pytest
import torch import torch
import torch.distributed as dist
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def delete_moe_info(model):
for name, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def run_zero_test(local_rank, stage=1): def run_zero_test(local_rank, stage=1):