[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models

* [hotfix] fix bugs in unit test
HELSON 2023-03-31 09:20:33 +08:00 committed by GitHub
parent fee2af8610
commit 1a1d68b053
5 changed files with 517 additions and 384 deletions
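
Taken together, the new checkpoint API is just two functions exported from `colossalai.nn.layer.moe`. A minimal usage sketch, mirroring the unit test added at the bottom of this commit; it assumes `colossalai.launch(...)` and `MOE_CONTEXT.setup(...)` have already run on every rank, reuses the toy `MoeModel` from the existing MoE tests, and the checkpoint path is just an example:

```python
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel   # toy MoE model used by the tests

# every rank builds the model; each rank only holds its own slice of experts
with ColoInitContext(device=get_current_device()):
    model = MoeModel(checkpoint=True)
save_moe_model(model, 'moe_ckpt.pth')         # rank 0 writes the gathered full state dict

# reload into a fresh model with the same parallel layout
with ColoInitContext(device=get_current_device()):
    other_model = MoeModel(checkpoint=True)
load_moe_model(other_model, 'moe_ckpt.pth')   # each rank keeps only the experts it owns
```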

View File

@@ -1,3 +1,4 @@
+from .checkpoint import load_moe_model, save_moe_model
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, MoeModule
 from .routers import MoeRouter, Top1Router, Top2Router
@@ -5,5 +6,5 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
 ]

View File

@@ -0,0 +1,40 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from .experts import MoeExperts
+
+
+def save_moe_model(model: nn.Module, save_path: str):
+    state_dict = model.state_dict()
+    if dist.get_rank() == 0:
+        torch.save(state_dict, save_path)
+    dist.barrier()
+
+
+def load_moe_model(model: nn.Module, load_path: str):
+    state_dict = torch.load(load_path)
+    for prefix, module in model.named_modules():
+        if prefix.endswith('.moe_layer.experts'):
+            # this module should be an Experts instance
+            assert isinstance(module, MoeExperts)
+
+            ep_rank = dist.get_rank(module.dist_info.ep_group)
+            num_local = module.num_local_experts
+
+            for i in range(num_local):
+                expert_id = ep_rank * num_local + i
+                for name, _ in module.experts[i].named_parameters():
+                    cur_key = f'{prefix}.experts.{i}.{name}'
+                    param_key = f'{prefix}.experts.{expert_id}.{name}'
+                    load_param = state_dict[param_key]
+                    state_dict[cur_key] = load_param
+
+            for name, _ in module.experts[0].named_parameters():
+                pop_pre = f'{prefix}.experts.'
+                pop_suf = f'.{name}'
+                for i in range(num_local, module.num_total_experts):
+                    pop_key = f'{pop_pre}{i}{pop_suf}'
+                    state_dict.pop(pop_key)
+
+    model.load_state_dict(state_dict)
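
The key remapping in `load_moe_model` is easiest to see with concrete numbers. Below is an illustration in plain Python, with no distributed setup; the 8-expert / 4-rank layout, the module prefix, and the parameter name are all made up for the example:

```python
# Illustration only: how load_moe_model rewrites checkpoint keys on one rank.
# Hypothetical layout: 8 experts in total, expert-parallel world size 4,
# so each rank holds num_local = 2 experts.
num_total_experts = 8
ep_world_size = 4
num_local = num_total_experts // ep_world_size

prefix = 'module.moe_layer.experts'   # hypothetical module prefix
param_name = 'wi'                     # hypothetical parameter name inside one expert

for ep_rank in range(ep_world_size):
    kept = []
    for i in range(num_local):
        expert_id = ep_rank * num_local + i      # global expert owned by this rank
        kept.append((f'{prefix}.experts.{i}.{param_name}',           # key the local module expects
                     f'{prefix}.experts.{expert_id}.{param_name}'))  # key stored in the checkpoint
    print(f'ep_rank {ep_rank} copies (local key <- checkpoint key): {kept}')

# Keys for experts the rank does not hold locally (indices num_local .. num_total_experts-1)
# are popped before model.load_state_dict is called, so the strict load succeeds.
```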

View File

@@ -1,12 +1,15 @@
 import math
+from copy import deepcopy
+from typing import Type
+
 import torch
+import torch.distributed as dist
 import torch.nn as nn
+
 from colossalai.context import ParallelMode, seed
-from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.utils import get_current_device
 from colossalai.zero.init_ctx import no_shard_zero_decrator
-from typing import Type


 class MoeExperts(nn.Module):
@@ -20,6 +23,7 @@ class MoeExperts(nn.Module):
         assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\n Please use Experts build function."
         self.comm_name = comm_name
+        self.num_total_experts = num_experts

         # Get the configuration of experts' deployment and parallel information from moe contex
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@@ -61,6 +65,33 @@ class Experts(MoeExperts):
         output = torch.cat(expert_output, dim=1).contiguous()
         return output

+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        assert keep_vars == False, "Only support keep_vars=False now"
+        dp_rank = dist.get_rank(self.dist_info.dp_group)
+        ep_rank = dist.get_rank(self.dist_info.ep_group)
+
+        submodule_dict = dict()
+        example_submodule = None
+        for name, subm in self.experts.named_modules():
+            if subm is self.experts:
+                continue
+            module_number = self.num_local_experts * ep_rank + int(name)
+            submodule_dict[module_number] = subm
+            example_submodule = subm
+
+        if dp_rank == 0:
+            local_prefix = prefix + 'experts.'
+            buffer_module = deepcopy(example_submodule)
+            for i in range(self.num_total_experts):
+                source_rank = i // self.num_local_experts
+                current_prefix = local_prefix + str(i) + '.'
+                comm_module = submodule_dict.get(i, buffer_module)
+                for name, param in comm_module.named_parameters():
+                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
+                    if ep_rank == 0:
+                        destination[current_prefix + name] = param.data.cpu()
+
+        dist.barrier()
+

 class FFNExperts(MoeExperts):
     """Use torch.bmm to speed up for multiple experts.

View File

@@ -1,17 +1,24 @@
 import math
+from typing import Optional, Tuple, Type
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.utils import get_current_device
-from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
-    ReduceScatter, MoeDispatch, MoeCombine
-from colossalai.nn.layer.moe.experts import MoeExperts, Experts
-from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.nn.layer.moe._operation import (
+    COL_MOE_KERNEL_FLAG,
+    AllGather,
+    AllToAll,
+    MoeCombine,
+    MoeDispatch,
+    ReduceScatter,
+)
+from colossalai.nn.layer.moe.experts import Experts, MoeExperts
 from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
+from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
+from colossalai.utils import get_current_device
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from typing import Optional, Type, Tuple


 @no_shard_zero_decrator(is_replicated=True)
@@ -178,16 +185,16 @@ class MoeModule(nn.Module):
         self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
-            self.experts = expert_instance
+            my_experts = expert_instance
         else:
             assert expert_cls is not None, \
                 "Expert class can't be None when experts instance is not given"
-            self.experts = Experts(expert_cls, num_experts, **expert_args)
+            my_experts = Experts(expert_cls, num_experts, **expert_args)

         self.moe_layer = MoeLayer(dim_model=dim_model,
                                   num_experts=num_experts,
                                   router=self.moe_router,
-                                  experts=self.experts)
+                                  experts=my_experts)

     def forward(self, inputs: torch.Tensor):
         moe_output, l_aux = self.moe_layer(inputs)
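
This `layers.py` change looks cosmetic but matters for checkpointing: by keeping the expert container in a local variable (`my_experts`) instead of assigning it to `self.experts`, the experts are registered only once, under `moe_layer.experts`, so their parameters appear a single time in the state dict and under the `.moe_layer.experts` prefix that `load_moe_model` matches. A minimal standalone sketch of how attribute assignment drives key duplication in PyTorch (the `Inner`/`Outer` classes are invented for illustration):

```python
import torch.nn as nn


class Inner(nn.Module):
    # stands in for MoeLayer: it registers the experts it is given
    def __init__(self, experts: nn.Module):
        super().__init__()
        self.experts = experts


class Outer(nn.Module):
    # stands in for MoeModule
    def __init__(self, register_twice: bool):
        super().__init__()
        experts = nn.Linear(4, 4)
        if register_twice:
            self.experts = experts       # old behaviour: extra top-level 'experts.*' keys
        self.inner = Inner(experts)      # new behaviour: only 'inner.experts.*' keys


print(sorted(Outer(register_twice=True).state_dict().keys()))
# ['experts.bias', 'experts.weight', 'inner.experts.bias', 'inner.experts.weight']
print(sorted(Outer(register_twice=False).state_dict().keys()))
# ['inner.experts.bias', 'inner.experts.weight']
```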

View File

@@ -0,0 +1,54 @@
+import os
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer.moe import load_moe_model, save_moe_model
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.test_moe.test_moe_zero_init import MoeModel
+from tests.test_tensor.common_utils import debug_print
+from tests.test_zero.common import CONFIG
+
+
+def exam_moe_checkpoint():
+    with ColoInitContext(device=get_current_device()):
+        model = MoeModel(checkpoint=True)
+    save_moe_model(model, 'temp_path.pth')
+
+    with ColoInitContext(device=get_current_device()):
+        other_model = MoeModel(checkpoint=True)
+    load_moe_model(other_model, 'temp_path.pth')
+
+    state_0 = model.state_dict()
+    state_1 = other_model.state_dict()
+    for k, v in state_0.items():
+        u = state_1.get(k)
+        assert torch.equal(u.data, v.data)
+
+    if dist.get_rank() == 0:
+        os.remove('temp_path.pth')
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    exam_moe_checkpoint()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_if_address_is_in_use()
+def test_moe_checkpoint(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_checkpoint(world_size=4)