mirror of https://github.com/hpcaitech/ColossalAI

[moe] add checkpoint for moe models (#3354)

* [moe] add checkpoint for moe models
* [hotfix] fix bugs in unit test

Branch: pull/3367/head
Parent: fee2af8610
Commit: 1a1d68b053
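The commit exposes two new helpers, save_moe_model and load_moe_model, from colossalai.nn.layer.moe. A minimal usage sketch, assuming distributed training has already been launched and `model` is a module built around MoeModule layers (the checkpoint path is illustrative):

    from colossalai.nn.layer.moe import load_moe_model, save_moe_model

    # Collective call: Experts.state_dict() gathers expert weights across the
    # expert-parallel group, and global rank 0 writes the merged checkpoint to disk.
    save_moe_model(model, 'moe_checkpoint.pth')

    # Also collective: every rank reads the full checkpoint, then keeps only the
    # expert entries that belong to its own expert-parallel rank.
    load_moe_model(model, 'moe_checkpoint.pth')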
colossalai/nn/layer/moe/__init__.py (path inferred from the relative imports):

@@ -1,3 +1,4 @@
+from .checkpoint import load_moe_model, save_moe_model
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, MoeModule
 from .routers import MoeRouter, Top1Router, Top2Router
@@ -5,5 +6,5 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
 ]
colossalai/nn/layer/moe/checkpoint.py (new module; name taken from the `from .checkpoint import ...` line above):

@@ -0,0 +1,40 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from .experts import MoeExperts
+
+
+def save_moe_model(model: nn.Module, save_path: str):
+    state_dict = model.state_dict()
+    if dist.get_rank() == 0:
+        torch.save(state_dict, save_path)
+    dist.barrier()
+
+
+def load_moe_model(model: nn.Module, load_path: str):
+    state_dict = torch.load(load_path)
+
+    for prefix, module in model.named_modules():
+        if prefix.endswith('.moe_layer.experts'):
+            # this module should be an Experts instance
+            assert isinstance(module, MoeExperts)
+
+            ep_rank = dist.get_rank(module.dist_info.ep_group)
+            num_local = module.num_local_experts
+            for i in range(num_local):
+                expert_id = ep_rank * num_local + i
+                for name, _ in module.experts[i].named_parameters():
+                    cur_key = f'{prefix}.experts.{i}.{name}'
+                    param_key = f'{prefix}.experts.{expert_id}.{name}'
+                    load_param = state_dict[param_key]
+                    state_dict[cur_key] = load_param
+
+            for name, _ in module.experts[0].named_parameters():
+                pop_pre = f'{prefix}.experts.'
+                pop_suf = f'.{name}'
+                for i in range(num_local, module.num_total_experts):
+                    pop_key = f'{pop_pre}{i}{pop_suf}'
+                    state_dict.pop(pop_key)
+
+    model.load_state_dict(state_dict)
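The key step in load_moe_model is re-indexing: the checkpoint stores experts under their global ids, while each rank only instantiates num_local_experts experts indexed from 0. The mapping is expert_id = ep_rank * num_local + i; a tiny standalone sketch with made-up sizes (8 experts split over a hypothetical expert-parallel group of 4):

    num_total_experts = 8                       # experts in the checkpoint, global ids 0..7
    ep_size = 4                                 # hypothetical expert-parallel world size
    num_local = num_total_experts // ep_size    # 2 experts per rank

    for ep_rank in range(ep_size):
        for i in range(num_local):
            expert_id = ep_rank * num_local + i
            print(f"ep_rank {ep_rank}: local key 'experts.{i}' <- checkpoint key 'experts.{expert_id}'")

So on ep_rank 1, local slot experts.0 is filled from checkpoint entry experts.2, and the remaining global entries are popped before load_state_dict is called.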
colossalai/nn/layer/moe/experts.py (path inferred from `from .experts import ...`; indentation and blank lines reconstructed):

@@ -1,12 +1,15 @@
 import math
+from copy import deepcopy
+from typing import Type
+
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.context import ParallelMode, seed
-from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.utils import get_current_device
 from colossalai.zero.init_ctx import no_shard_zero_decrator
-from typing import Type
 
 
 class MoeExperts(nn.Module):
@@ -20,6 +23,7 @@ class MoeExperts(nn.Module):
         assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\n Please use Experts build function."
         self.comm_name = comm_name
+        self.num_total_experts = num_experts
         # Get the configuration of experts' deployment and parallel information from moe contex
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
 
@@ -61,6 +65,33 @@ class Experts(MoeExperts):
         output = torch.cat(expert_output, dim=1).contiguous()
         return output
 
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        assert keep_vars == False, "Only support keep_vars=False now"
+        dp_rank = dist.get_rank(self.dist_info.dp_group)
+        ep_rank = dist.get_rank(self.dist_info.ep_group)
+        submodule_dict = dict()
+        example_submodule = None
+        for name, subm in self.experts.named_modules():
+            if subm is self.experts:
+                continue
+            module_number = self.num_local_experts * ep_rank + int(name)
+            submodule_dict[module_number] = subm
+            example_submodule = subm
+
+        if dp_rank == 0:
+            local_prefix = prefix + 'experts.'
+            buffer_module = deepcopy(example_submodule)
+            for i in range(self.num_total_experts):
+                source_rank = i // self.num_local_experts
+                current_prefix = local_prefix + str(i) + '.'
+                comm_module = submodule_dict.get(i, buffer_module)
+                for name, param in comm_module.named_parameters():
+                    dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
+                    if ep_rank == 0:
+                        destination[current_prefix + name] = param.data.cpu()
+
+        dist.barrier()
+
 
 class FFNExperts(MoeExperts):
     """Use torch.bmm to speed up for multiple experts.
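The new Experts.state_dict makes saving a collective operation: each submodule's global id is recovered as num_local_experts * ep_rank + int(name), every owner broadcasts its expert parameters over the expert-parallel group, and only ep_rank 0 copies the received tensors into destination. A reduced sketch of the same broadcast-gather pattern, assuming torch.distributed is already initialised on the default group and all shards share one shape (gather_sharded_params is an illustrative name, not part of the commit):

    import torch
    import torch.distributed as dist

    def gather_sharded_params(local_params, num_local, num_total):
        # local_params: the `num_local` tensors owned by this rank, in local order.
        rank = dist.get_rank()
        full = {}
        buffer = torch.empty_like(local_params[0])
        for global_id in range(num_total):
            src = global_id // num_local                    # owning rank (cf. source_rank)
            tensor = local_params[global_id % num_local] if src == rank else buffer
            dist.broadcast(tensor, src=src)                 # owner sends, everyone else receives
            if rank == 0:
                full[global_id] = tensor.cpu().clone()      # only rank 0 keeps the merged dict
        return full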
colossalai/nn/layer/moe/layers.py (path inferred from `from .layers import ...`; indentation reconstructed):

@@ -1,17 +1,24 @@
 import math
+from typing import Optional, Tuple, Type
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.utils import get_current_device
-from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \
-    ReduceScatter, MoeDispatch, MoeCombine
-from colossalai.nn.layer.moe.experts import MoeExperts, Experts
-from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.nn.layer.moe._operation import (
+    COL_MOE_KERNEL_FLAG,
+    AllGather,
+    AllToAll,
+    MoeCombine,
+    MoeDispatch,
+    ReduceScatter,
+)
+from colossalai.nn.layer.moe.experts import Experts, MoeExperts
 from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
+from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
+from colossalai.utils import get_current_device
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from typing import Optional, Type, Tuple
 
 
 @no_shard_zero_decrator(is_replicated=True)
@@ -178,16 +185,16 @@ class MoeModule(nn.Module):
         self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
 
         if expert_instance is not None:
-            self.experts = expert_instance
+            my_experts = expert_instance
         else:
             assert expert_cls is not None, \
                 "Expert class can't be None when experts instance is not given"
-            self.experts = Experts(expert_cls, num_experts, **expert_args)
+            my_experts = Experts(expert_cls, num_experts, **expert_args)
 
         self.moe_layer = MoeLayer(dim_model=dim_model,
                                   num_experts=num_experts,
                                   router=self.moe_router,
-                                  experts=self.experts)
+                                  experts=my_experts)
 
     def forward(self, inputs: torch.Tensor):
         moe_output, l_aux = self.moe_layer(inputs)
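In MoeModule.__init__ the expert module is now held in a local variable (my_experts) instead of being assigned to self.experts, so it is registered only once, under self.moe_layer. Presumably this keeps the parameters from appearing twice in state_dict() ('experts.*' in addition to 'moe_layer.experts.*'), which matters because load_moe_model only rewrites keys whose prefix ends with '.moe_layer.experts'; this is a reading of the loader's key scheme, not something stated in the commit. A quick plain-PyTorch illustration of the double registration (no MoE code involved):

    import torch.nn as nn

    class Inner(nn.Module):
        def __init__(self, experts):
            super().__init__()
            self.experts = experts

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            experts = nn.Linear(4, 4)
            self.experts = experts      # old pattern: the same module is registered twice
            self.inner = Inner(experts)

    print(list(Outer().state_dict().keys()))
    # ['experts.weight', 'experts.bias', 'inner.experts.weight', 'inner.experts.bias']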
New unit test (exact file path not shown in this capture; the imports point to tests/test_moe):

@@ -0,0 +1,54 @@
+import os
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer.moe import load_moe_model, save_moe_model
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.test_moe.test_moe_zero_init import MoeModel
+from tests.test_tensor.common_utils import debug_print
+from tests.test_zero.common import CONFIG
+
+
+def exam_moe_checkpoint():
+    with ColoInitContext(device=get_current_device()):
+        model = MoeModel(checkpoint=True)
+    save_moe_model(model, 'temp_path.pth')
+
+    with ColoInitContext(device=get_current_device()):
+        other_model = MoeModel(checkpoint=True)
+    load_moe_model(other_model, 'temp_path.pth')
+
+    state_0 = model.state_dict()
+    state_1 = other_model.state_dict()
+    for k, v in state_0.items():
+        u = state_1.get(k)
+        assert torch.equal(u.data, v.data)
+
+    if dist.get_rank() == 0:
+        os.remove('temp_path.pth')
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    exam_moe_checkpoint()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_if_address_is_in_use()
+def test_moe_checkpoint(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_checkpoint(world_size=4)