[moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine
Xuanlei Zhao 2023-11-08 23:07:03 +08:00 committed by GitHub
parent 67f5331754
commit f71e63b0f3
20 changed files with 738 additions and 150 deletions
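
For orientation, a minimal usage sketch of the optimizer checkpointing added here, pieced together from the plugin and test changes further down. The model class, launch call and plugin arguments are illustrative stand-ins; only the save_optimizer/load_optimizer calls mirror what the tests in this commit exercise.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER

colossalai.launch_from_torch(config=dict())           # assumes a torchrun-style launch
MOE_MANAGER.setup(parallel="EP")                      # seed argument removed by this commit

model = MyMoeModel()                                  # hypothetical model containing MoE layers
optimizer = torch.optim.Adam(model.parameters())
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=2)   # illustrative sizes
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)

# ... train for a few steps ...

# sharded optimizer checkpoint (new in this commit)
booster.save_optimizer(optimizer, "./optim_ckpt", shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, "./optim_ckpt")

# unsharded optimizer checkpoint
booster.save_optimizer(optimizer, "./optim_ckpt.pth")
booster.load_optimizer(optimizer, "./optim_ckpt.pth")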

View File

@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoeCheckpintIO
from colossalai.moe import MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@ -322,8 +322,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
**_kwargs,
)
def get_checkpoint_io(self) -> MoeCheckpintIO:
self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def get_checkpoint_io(self) -> MoECheckpintIO:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(
@ -359,9 +359,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info

View File

@ -79,13 +79,15 @@ class TPInferEngine:
self.multi_query_group_num = model.config.num_attention_heads
# default to attention_heads
self.multi_query_attention = model.config.multi_query_attention
if hasattr(model.config, "multi_query_attention"):
self.multi_query_attention = getattr(model.config, "multi_query_attention")
if hasattr(model.config, "multi_query_group_num"):
self.multi_query_group_num = model.config.multi_query_group_num
self.multi_query_group_num = getattr(model.config, "multi_query_group_num")
if hasattr(model.config, "num_key_value_heads"):
self.multi_query_group_num = model.config.num_key_value_heads
self.multi_query_group_num = getattr(model.config, "num_key_value_heads")
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
@ -108,7 +110,7 @@ class TPInferEngine:
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
if self.multi_query_attention:
if hasattr(self, "multi_query_attention"):
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0

View File

@ -1,4 +1,4 @@
from .checkpoint import MoeCheckpintIO
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@ -13,5 +13,5 @@ __all__ = [
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoeCheckpintIO",
"MoECheckpintIO",
]

View File

@ -1,32 +1,46 @@
import copy
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.tensor.moe_tensor.api import (
get_dp_group,
get_dp_rank,
get_dp_size,
get_ep_group,
get_ep_rank,
get_ep_size,
is_moe_tensor,
)
class MoeCheckpintIO(HybridParallelCheckpointIO):
class MoECheckpintIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
@ -55,7 +69,7 @@ class MoeCheckpintIO(HybridParallelCheckpointIO):
ep_size = get_ep_size(model_param)
expert_num = param.shape[0] // ep_size
assert param.shape[0] % ep_size == 0
param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
state_dict[name] = param
dist.barrier()
return state_dict
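A toy view of the expert slicing above: experts are stacked along dim 0 and each EP rank keeps its contiguous block (sizes chosen only for illustration).

import torch

ep_size, ep_rank = 4, 2
param = torch.randn(8, 16, 32)            # 8 experts stacked along dim 0, e.g. (num_experts, in, out)
expert_num = param.shape[0] // ep_size    # 2 experts per EP rank
local = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]   # rank 2 keeps experts 4..5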
@ -156,7 +170,7 @@ class MoeCheckpintIO(HybridParallelCheckpointIO):
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [deepcopy(param) for _ in range(ep_size)]
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
@ -245,30 +259,523 @@ class MoeCheckpintIO(HybridParallelCheckpointIO):
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
raise NotImplementedError()
def pre_load_optim(
self,
state: OrderedDict,
working_param,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
raise NotImplementedError()
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
is_moe_tensor_flag = is_moe_tensor(working_param)
if is_moe_tensor_flag:
ep_rank = get_ep_rank(working_param)
ep_size = get_ep_size(working_param)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
if is_moe_tensor_flag:
with torch.no_grad():
expert_num = v.shape[0] // ep_size
assert v.shape[0] % ep_size == 0
v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
else:
# Shard state along data parallel group when using Zero.
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
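As a toy illustration of the ZeRO branch just above, with self.dp_size and self.dp_rank replaced by plain ints and values chosen so the padding is visible.

import torch
import torch.nn.functional as F

dp_size, dp_rank = 4, 1                       # stand-ins for self.dp_size / self.dp_rank
v = torch.arange(10, dtype=torch.float32)     # a full optimizer state loaded from checkpoint

padding_size = (dp_size - v.numel() % dp_size) % dp_size   # 2 padding elements here
v = v.flatten()
if padding_size > 0:
    v = F.pad(v, [0, padding_size])           # length 12, now a multiple of dp_size
slice_size = v.numel() // dp_size             # 3 elements per rank
local = v.split(slice_size, dim=0)[dp_rank]   # rank 1 keeps elements 3..5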
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If the file containing this param's states has already been loaded, skip it.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
dist.barrier()
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
if id(working_param) in optimizer.param_info["param2id"]:
return optimizer.param_info["param2id"][id(working_param)]
else:
return None
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id is not None:
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()
def pre_save_optim(
self,
state: OrderedDict,
param: torch.Tensor,
inplace: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
if is_moe_tensor(param):
moe_dp_group = get_dp_group(param)
moe_dp_size = get_dp_size(param)
moe_ep_group = get_ep_group(param)
moe_ep_size = get_ep_size(param)
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# moe param
if is_moe_tensor(param):
# dp gather
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
dist.all_gather(gather_tensor, v, group=moe_dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# ep gather
gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
dist.all_gather(gather_tensor, v, group=moe_ep_group)
v = torch.cat(gather_tensor, dim=0)
else:
# global dp
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
dist.all_gather(gather_tensor, v, group=self.dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
state_[k] = v.detach().clone().to(device)
return state_
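Correspondingly, a toy sketch of the save-side stitching in the non-MoE branch above, with the dist.all_gather result mocked as a plain Python list so no process group is needed; shapes mirror the load-side example earlier.

import torch

param = torch.empty(2, 5)                     # original parameter, 10 elements
dp_size = 4
# what dist.all_gather would produce: one padded 3-element slice per rank
gather_tensor = [torch.arange(i * 3, i * 3 + 3, dtype=torch.float32) for i in range(dp_size)]

v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# v now holds elements 0..9 laid out as a 2x5 tensor, with the padding discarded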
def _optimizer_sharder(
self,
optimizer: OptimizerWrapper,
size_per_shard: int = 1024,
):
# An internal method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
state_ = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_optimizer(
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
raise NotImplementedError()
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, the filenames are in the form of "pytorch_optim.<prefix>-000XX.bin".
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
raise NotImplementedError()
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the states.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = self._optimizer_sharder(
optimizer,
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
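Assuming the index file follows the usual metadata/weight_map layout written by CheckpointIndexFile, a small sketch of inspecting the files that save_sharded_optimizer describes above; the directory path is hypothetical.

import json
import os

ckpt_dir = "./optim_ckpt"  # hypothetical directory passed to save_optimizer(..., shard=True)
with open(os.path.join(ckpt_dir, "pytorch_optim.bin.index.json")) as f:
    index = json.load(f)

print(index["metadata"]["total_size"])    # summed size of all state shards
print(index["metadata"]["param_groups"])  # name of the param group file, e.g. pytorch_optim_group.bin
for param_id, shard_file in index["weight_map"].items():
    print(param_id, "->", shard_file)     # which shard file stores each parameter's optimizer states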
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# optimizer states of the parameters kept by the local device (i.e. the local pipeline stage)
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
local_states[param_id] = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(states_list, local_states, self.pp_group)
# Only the master rank does the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
dist.barrier()

View File

@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@ -53,7 +53,8 @@ class MLPExperts(nn.Module):
# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False)
num_experts, use_tp=True if expert_parallel == "TP" else False
)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
@ -87,7 +88,7 @@ class MLPExperts(nn.Module):
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
@ -99,10 +100,10 @@ class MLPExperts(nn.Module):
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
def forward(
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
) -> torch.Tensor:
"""
forward: hidden_size --> intermediate_size --> hidden_size
@ -129,7 +130,7 @@ class MLPExperts(nn.Module):
mask = torch.sum(mask, dim=-1)
x_list = []
for i in range(e):
x_list.append(x[i, :mask[i]])
x_list.append(x[i, : mask[i]])
x = x_list
if self.gated:

View File

@ -8,14 +8,13 @@ from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
class MoeManager(metaclass=SingletonMeta):
class MoEManager(metaclass=SingletonMeta):
"""MoE manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.parallel = None
self.seed = None
self.mode = None
self.use_ep_inside = None
self.world_size = None
@ -48,7 +47,6 @@ class MoeManager(metaclass=SingletonMeta):
def setup(
self,
seed: int,
parallel: str = None,
mode: str = "dynamic",
max_ep_size: int = 8,
@ -73,10 +71,9 @@ class MoeManager(metaclass=SingletonMeta):
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
"""
assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
self.seed = seed + dist.get_rank()
self.parallel = parallel
self.use_ep_inside = use_ep_inside
self.world_size = dist.get_world_size()
@ -87,10 +84,12 @@ class MoeManager(metaclass=SingletonMeta):
if self.mode == "dynamic":
self.max_ep_size = min(max_ep_size, self.world_size)
else:
assert (fixed_dp_size > 0 and fixed_ep_size > 0
and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
assert (
fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
), "dp_size, ep_size and pp_size should be greater than 0"
assert (
isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
), "dp_size, ep_size and pp_size should be int"
self.ep_size = fixed_ep_size
self.dp_size = fixed_dp_size
self.pp_size = fixed_pp_size
@ -112,10 +111,12 @@ class MoeManager(metaclass=SingletonMeta):
"""
if self.mode == "dynamic":
gt_flag = (num_experts % self.max_ep_size == 0) # check whether num_experts is greater
lt_flag = (self.max_ep_size % num_experts == 0) # check whether num_experts is less
assert gt_flag or lt_flag, ("Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa.")
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, (
"Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa."
)
dp_size = 1 if gt_flag else self.world_size // num_experts
ep_size = min(self.world_size // dp_size, self.max_ep_size)
dp_size = self.world_size // ep_size
@ -159,4 +160,4 @@ class MoeManager(metaclass=SingletonMeta):
return self.parallel
MOE_MANAGER = MoeManager()
MOE_MANAGER = MoEManager()
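
Since setup() no longer takes a seed, a minimal sketch of the two configurations exercised by the tests below; it assumes a distributed context has already been launched, and the sizes are illustrative.

from colossalai.moe.manager import MOE_MANAGER

# dynamic mode (default): expert placement is derived from num_experts and max_ep_size
MOE_MANAGER.setup(parallel="EP", max_ep_size=2, use_ep_inside=False)

# fixed mode: dp/ep/pp sizes are pinned explicitly, as in the "hybrid" test case
MOE_MANAGER.__init__()  # reset the singleton between configurations, as the tests do
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=2, fixed_pp_size=2)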

View File

@ -72,6 +72,19 @@ def get_ep_size(tensor: torch.Tensor) -> int:
return tensor.moe_info.ep_size
def get_dp_size(tensor: torch.Tensor) -> int:
"""
Get the data parallel size of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
int: The data parallel size of the given tensor.
"""
return tensor.moe_info.dp_size
def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
"""
Get the data parallel group of the given tensor.

View File

@ -155,9 +155,7 @@ def main():
"precision": "bf16",
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(

View File

@ -41,7 +41,7 @@ def fsdp_main(rank, world_size, args):
# initialize the process group
dist.init_process_group("nccl")
MOE_MANAGER.setup(seed=42, parallel=None)
MOE_MANAGER.setup(parallel=None)
dp_size = dist.get_world_size()
dataset = RandomDataset(

View File

@ -1,5 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers >= 4.20.0
transformers >= 4.20.0, <= 4.34.0
sentencepiece
datasets

View File

@ -213,9 +213,7 @@ def main():
"precision": args.precision,
"zero_stage": args.zero_stage,
}
mgr_dict = {
"seed": 42,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(

View File

@ -6,10 +6,9 @@ import torch.nn as nn
import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group
from tests.test_moe.moe_utils import MoeGradientHandler
BATCH_SIZE = 4
DIM = 16
@ -25,7 +24,7 @@ def run_test(rank, world_size, port):
backend="nccl",
)
MOE_MANAGER.setup(42, parallel="EP") # MOE initialization
MOE_MANAGER.setup(parallel="EP") # MOE initialization
num_experts_list = [1, 2, 4]
layer_list = []
for num_experts in num_experts_list:
@ -41,15 +40,6 @@ def run_test(rank, world_size, port):
model = nn.ModuleList(layer_list)
model = model.to(get_current_device())
dist_dict = MOE_MANAGER.parallel_info_dict
assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
sync_moe_model_param(model)
assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)

View File

@ -20,21 +20,23 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
local_rank = dist.get_rank()
MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization
MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
MOE_MANAGER.reset_loss()
torch.manual_seed(rs + local_rank) # set each process has different random seed
torch.manual_seed(rs + local_rank) # set each process has different random seed
# get randomized data
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
layer = SparseMLP(hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0)
layer = SparseMLP(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0,
)
layer = layer.to(get_current_device())
if data_type == torch.float16:
layer = layer.half()
@ -55,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
layer.gate_weight.grad.zero_()
layer.enable_kernel = True
new_out = layer(tokens) # get outputs through colossal kernel
new_out = layer(tokens) # get outputs through colossal kernel
if data_type == torch.float32:
check_equal(old_out, new_out)
@ -90,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, topk):
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
if __name__ == '__main__':
if __name__ == "__main__":
test_moe_kernel(2, 256, torch.float16, 2)

View File

@ -12,53 +12,112 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
sys.path.append(os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"examples/language/openmoe",
))
sys.path.append(
os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"examples/language/openmoe",
)
)
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
}
def run_fwd_bwd(
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
model.train()
if pipeline:
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
y = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = y["loss"]
else:
if criterion:
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
if optimizer is not None:
optimizer.backward(loss)
else:
loss.backward()
return y
def get_config():
config = LlamaConfig(
vocab_size=300,
hidden_size=16,
intermediate_size=32,
num_hidden_layers=4,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=4,
dropout_rate=0.0,
hidden_act="swiglu",
)
set_openmoe_args(config, num_experts=16, moe_layer_interval=1)
set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
return config
def get_model(parallel):
config = get_config()
model = OpenMoeForCausalLM(config)
optim = torch.optim.Adam(model.parameters())
if parallel == None:
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
zero_stage=0,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "zero_ep":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep_zero":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=2,
zero_stage=1,
@ -66,54 +125,77 @@ def get_model(parallel):
custom_policy=OpenMoeForCausalLMPolicy(),
)
booster = Booster(plugin=plugin)
model, _, _, _, _ = booster.boost(model=model)
return model, booster
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(parallel, shard):
def _test_moe_checkpoint(rank, parallel):
if parallel == None:
MOE_MANAGER.setup(
seed=42,
parallel=None,
)
elif parallel == "zero2_ep":
elif parallel == "ep":
MOE_MANAGER.setup(
seed=42,
parallel="EP",
)
elif parallel == "ep_zero":
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=2,
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
seed=42,
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1 = get_model(parallel)
model2, booster2 = get_model(parallel)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
if shard:
booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt")
# param ckpt
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# unshard
booster1.save_model(model1, "./tmp_ckpt1.pth")
booster3.load_model(model3, "./tmp_ckpt1.pth")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
data = torch.randint(0, 4, (2, 4)).cuda()
label = torch.randint(0, 4, (2,)).cuda()
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
booster1.save_model(model1, "tmp_ckpt.pth")
booster2.load_model(model2, "tmp_ckpt.pth")
state1 = model1.state_dict()
state2 = model2.state_dict()
for k, v in state1.items():
u = state2.get(k)
assert torch.equal(u.data, v.data)
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# unshard
booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
if dist.get_rank() == 0:
if shard:
shutil.rmtree("./tmp_ckpt")
else:
os.remove("tmp_ckpt.pth")
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
os.remove("./tmp_ckpt1.pth")
os.remove("./tmp_ckpt2.pth")
def _run_dist(rank, world_size, port, parallel, shard):
def _run_dist(rank, world_size, port, parallel):
colossalai.launch(
config=dict(),
rank=rank,
@ -122,17 +204,16 @@ def _run_dist(rank, world_size, port, parallel, shard):
port=port,
backend="nccl",
)
_test_moe_checkpoint(parallel, shard)
_test_moe_checkpoint(rank, parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"])
@pytest.mark.parametrize("shard", [True, False])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel, shard):
spawn(_run_dist, world_size, parallel=parallel, shard=shard)
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True)
test_moe_checkpoint(world_size=4, parallel="hybrid")

View File

@ -14,16 +14,16 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, syn
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
assert batch_size % world_size == 0
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel=None)
MOE_MANAGER.setup(parallel=None)
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel="EP")
MOE_MANAGER.setup(parallel="EP")
ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed, parallel="TP")
MOE_MANAGER.setup(parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
ep_model = ep_model.to(get_current_device())
tp_model = tp_model.to(get_current_device())
@ -44,7 +44,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
torch.cuda.manual_seed(seed)
tp_data = torch.randn(batch_size, dim, device=get_current_device())
micro_batch_size = batch_size // world_size
ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)]
ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)]
out_local = local_model(tp_data)
MOE_MANAGER.reset_loss()
@ -52,8 +52,8 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
MOE_MANAGER.reset_loss()
out_ep = ep_model(ep_data)
MOE_MANAGER.reset_loss()
assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)])
assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)])
assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)])
assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)])
out_local.mean().backward()
out_tp.mean().backward()
@ -77,5 +77,5 @@ def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
if __name__ == '__main__':
if __name__ == "__main__":
test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)

View File

@ -15,7 +15,7 @@ INTERMEDIATE_SIZE = 8
def run_moe_init(expert_parallel):
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=expert_parallel)
MOE_MANAGER.setup(parallel=expert_parallel)
expert_args = dict(
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,

View File

@ -35,13 +35,13 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
label = torch.randint(0, 4, (16,)).cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=None)
MOE_MANAGER.setup(parallel=None)
torch_model = MoeModel()
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
MOE_MANAGER.setup(max_ep_size=2, use_ep_inside=False, parallel="EP")
zero_model = MoeModel()
extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)

View File

@ -45,7 +45,6 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
MOE_MANAGER.__init__()
MOE_MANAGER.setup(
seed=42,
parallel="EP",
)
zero_model = MoeModel(enable_load_balance=True)
@ -55,7 +54,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel="EP")
MOE_MANAGER.setup(parallel="EP")
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
@ -94,7 +93,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_optimizer.step()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}"
assert torch.allclose(zero_out, torch_out, atol=3e-5), f"zero_out:{zero_out}\ntorch_out{torch_out}"
def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
@ -103,14 +102,13 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
label = torch.randint(0, 4, (16,)).cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(seed=42, parallel=None)
MOE_MANAGER.setup(parallel=None)
torch_model = MoeModel()
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(
seed=42,
max_ep_size=2,
use_ep_inside=False,
parallel="EP",

View File

@ -88,7 +88,7 @@ def run_zero_test(local_rank, world_size, stage=1):
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(seed=42, parallel="EP")
MOE_MANAGER.setup(parallel="EP")
seed_all(42 + rank)
run_zero_test(rank, world_size, stage=1)
run_zero_test(rank, world_size, stage=2)

View File

@ -76,7 +76,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(seed=42, parallel="EP")
MOE_MANAGER.setup(parallel="EP")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)