add comments for moe

pull/182/head
Wenwen Qu 2023-08-25 19:03:31 +08:00
parent aa2612edc4
commit 629e6a5ad1
7 changed files with 23 additions and 24 deletions

View File

@@ -349,7 +349,7 @@ class Initializer_Zero1(ProcessGroupInitializer):
 class Initializer_Expert(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for zero-1 parallelism.
+    """A ProcessGroupInitializer for expert parallelism.
     Args:
         rank (int): The rank of current process.
@@ -406,7 +406,7 @@ class Initializer_Expert(ProcessGroupInitializer):
 class Initializer_Expert_Data(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for zero-1 parallelism.
+    """A ProcessGroupInitializer for expert data parallelism.
     Args:
         rank (int): The rank of current process.

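These initializers only get corrected docstrings in this hunk; the grouping they build is not shown here. As a rough, illustrative sketch (sizes and layout are assumptions, not the repo's code), expert-parallel groups tile consecutive ranks while expert-data groups collect the ranks that hold the same expert shard:

# assumed layout, for illustration only
world_size, expert_parallel_size = 8, 4
expert_groups = [
    list(range(start, start + expert_parallel_size))
    for start in range(0, world_size, expert_parallel_size)
]  # [[0, 1, 2, 3], [4, 5, 6, 7]]
expert_data_groups = [
    list(range(offset, world_size, expert_parallel_size))
    for offset in range(expert_parallel_size)
]  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(expert_groups, expert_data_groups)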
View File

@@ -105,6 +105,7 @@ class NonPipelineScheduler(BaseScheduler):
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             self._call_hooks("before_forward", data)
+            # moe_losses contains the moe loss of each layer
             output, moe_losses = self._call_engine(engine, data)
             self._call_hooks("after_forward", output)

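The per-layer values returned in moe_losses are typically the gate's auxiliary (load-balancing) losses; how they are folded into the training loss is not part of this hunk. A toy sketch of one plausible aggregation (the names and the absence of a scaling coefficient are assumptions):

import torch

loss = torch.tensor(2.3, requires_grad=True)            # stand-in for the task loss
moe_losses = [torch.tensor(0.01), torch.tensor(0.02)]   # one auxiliary loss per moe layer

moe_loss = sum(moe_losses)      # assumed: simple sum over layers
total_loss = loss + moe_loss    # a coefficient would usually scale the aux term
print(total_loss)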
View File

@@ -278,6 +278,7 @@ class PipelineScheduler(BaseScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
         self._call_hooks("before_forward", data)
+        # moe_losses contains the moe loss of each layer in the current stage
         output_obj, moe_losses = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)
@@ -345,6 +346,9 @@ class PipelineScheduler(BaseScheduler):
             else:
                 # scale the latent loss
                 moe_loss = moe_loss * engine.optimizer.loss_scale
+            # we apply the chain rule here by projecting the grad onto the direction of
+            # [output_obj_grad, 1]; because moe_loss has no relation to subsequent
+            # layers, we set its grad to None (it will be regarded as 1)
             engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
             # Collect the grad of the input_obj.

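The new comment describes the trick inside engine.backward_by_grad. A minimal sketch of the same chain rule with plain torch.autograd.backward (the tensors below are toy stand-ins, not the scheduler's real objects):

import torch

x = torch.randn(3, requires_grad=True)
output_obj = x * 2.0                           # stands in for the activation sent to the next stage
moe_loss = 0.01 * (x ** 2).sum()               # stands in for the scalar auxiliary moe loss
output_obj_grad = torch.ones_like(output_obj)  # grad received back from the next stage

# None for the scalar moe_loss makes autograd use an implicit gradient of 1, so
# x.grad = output_obj_grad * d(output_obj)/dx + 1 * d(moe_loss)/dx.
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad)  # elementwise: 2.0 + 0.02 * x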
View File

@@ -175,6 +175,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 ]
             )
+            # residual network, see https://arxiv.org/pdf/2201.05596.pdf; seems useful for convergence
             if moe_use_residual:
                 residual_mlp = FeedForward(
                     hidden_size,
@@ -186,6 +187,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                     dtype=torch.float,
                 )
+            # replace the mlp with an MoE module; each expert in the MoE is a FeedForward module
             self.mlp = MoE(
                 hidden_size=hidden_size,
                 experts=experts,
@@ -291,9 +293,9 @@ class PackedFlashBaseLayer1D(nn.Module):
         # MLP.
         moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
-        if self.num_experts <= 1:
+        if self.num_experts <= 1:  # dense mlp output
             hidden_states = self.mlp(hidden_states)
-        else:
+        else:  # MoE output
             hidden_states, moe_loss, _ = self.mlp(hidden_states)
         return hidden_states + residual, moe_loss

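The residual branch added above follows the PR-MoE idea from the linked paper: a small dense FeedForward runs alongside the experts and the two outputs are mixed. A rough sketch of that mixing step, modeled on DeepSpeed's residual MoE (the learned two-way coefficient and the module names are assumptions, not this repo's exact code):

import torch
from torch import nn
import torch.nn.functional as F

class ResidualMoESketch(nn.Module):
    """Illustration of the residual-MoE mixing only, not the repo's MoE class."""

    def __init__(self, hidden_size, moe, residual_mlp):
        super().__init__()
        self.moe = moe                                 # assumed to return (output, aux_loss, ...)
        self.residual_mlp = residual_mlp               # plain dense FeedForward
        self.coefficient = nn.Linear(hidden_size, 2)   # learned 2-way mixing weights

    def forward(self, hidden_states):
        moe_out, moe_loss, _ = self.moe(hidden_states)
        mlp_out = self.residual_mlp(hidden_states)
        coef = F.softmax(self.coefficient(hidden_states), dim=-1)
        out = moe_out * coef[..., 0:1] + mlp_out * coef[..., 1:]
        return out, moe_loss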
View File

@@ -92,7 +92,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         cpu_offload=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
-        has_moe: bool = False,
         param_bcast_sync_handler: ParamBcastSyncHandler = None,
     ):
         # DynamicGradScaler related args
@@ -115,8 +114,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)
-        self.has_moe = has_moe
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -273,12 +270,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
-    def _get_real_dp_process_group(self, param_groups):
-        if "moe" in param_groups.keys() and param_groups["moe"]:
-            return ParallelMode.EXPERT_DATA
-        else:
-            return ParallelMode.DATA
     def _partition_param_list(self, param_group):
         no_params_ranks = []
         params_per_rank = [[] for _ in range(self._zero_world_size)]
@@ -287,7 +278,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         param_list = param_group["params"]
         if self._is_moe_group(param_group):
-            # just add current params to params_per_rank[_zero_local_rank]
+            # for the moe group, we do not need to partition the params; just add the current
+            # params to params_per_rank[_zero_local_rank]
             params_per_rank[self._zero_local_rank] = list(param_list)
             self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
             no_params_ranks = list(range(self._zero_world_size))
@@ -538,8 +530,8 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _compute_norm_with_moe_group(self, group_id):
         parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
-        # wo do not get the average grad for moe parameters, so we have to constuct
-        # the gradients list hear. Maybe this can be optimized.
+        # we do not average the grads of moe parameters, so we have to construct the gradients list here.
+        # Maybe this can be optimized.
         gradients = [p.grad for p in parameters]
         norm = compute_norm(
             gradients=gradients,
@@ -666,8 +658,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # get the global norm
         global_norm_groups = []
         if self._clip_grad_norm > 0:
-            for group_id in range(self.num_param_groups):
-                global_norm_groups.append(norms[group_id] ** 0.5)
+            for norm in norms:
+                global_norm_groups.append(norm**0.5)
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:

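The moe branch in _partition_param_list deliberately skips ZeRO sharding: expert weights already differ across expert(-data) ranks, so each rank simply keeps its own experts, while ordinary groups are still balanced across the zero world size. A simplified sketch of that contrast (function and argument names are illustrative, not the optimizer's API):

def partition_params(param_list, is_moe_group, zero_world_size, zero_local_rank):
    params_per_rank = [[] for _ in range(zero_world_size)]
    if is_moe_group:
        # expert params are not sharded by ZeRO; keep the whole group on the local rank
        params_per_rank[zero_local_rank] = list(param_list)
        return params_per_rank
    # ordinary groups: greedily balance parameters across ranks by element count
    numel_per_rank = [0] * zero_world_size
    for p in sorted(param_list, key=lambda t: t.numel(), reverse=True):
        rank = numel_per_rank.index(min(numel_per_rank))
        params_per_rank[rank].append(p)
        numel_per_rank[rank] += p.numel()
    return params_per_rank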
View File

@@ -24,7 +24,7 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.moe import create_moe_param_groups, has_moe_layers
+from internlm.model.moe import create_moe_param_groups
 from internlm.monitor import set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -99,6 +99,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         param_bcast_sync_handler = None
     adam_cfg = gpc.config.adam
+    # split the moe parameters into different groups
     if gpc.config.model.num_experts > 1:
         params = create_moe_param_groups(model, adam_cfg.weight_decay)
     else:
@@ -110,12 +111,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
-    has_moe = has_moe_layers(model)
     optimizer = HybridZeroOptimizer(
         naive_optimizer,
         grad_scal_cfg=gpc.config.grad_scaler,
         zero_cfg=gpc.config.hybrid_zero_optimizer,
-        has_moe=has_moe,
         param_bcast_sync_handler=param_bcast_sync_handler,
     )
@@ -377,7 +376,7 @@ def record_current_batch_training_metrics(
         "tgs (tokens/gpu/second)": tk_per_gpu,
         "lr": lr,
         "loss_scale": scaler,
-        "grad_norm": grad_norm,
+        "grad_norm": grad_norm,  # TODO: not scalar
     }
     infos["micro_num"] = len(batch[1])

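create_moe_param_groups (imported above) separates expert parameters into their own optimizer group so HybridZeroOptimizer can recognize them via _is_moe_group. Roughly, the returned groups look like the sketch below (the is_expert marker and the exact group layout are assumptions, not the helper's real implementation):

def split_moe_param_groups(model, weight_decay):
    """Illustrative only: put expert parameters into a group flagged as moe."""
    dense_params, expert_params = [], []
    for param in model.parameters():
        # assumed marker; the real code may tag expert parameters differently
        if getattr(param, "is_expert", False):
            expert_params.append(param)
        else:
            dense_params.append(param)
    return [
        {"params": dense_params, "weight_decay": weight_decay},
        {"params": expert_params, "weight_decay": weight_decay, "moe": True},
    ]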
View File

@@ -73,7 +73,7 @@ def save_model_checkpoint(folder, model):
     """
     states = model.state_dict()
-    # get non-moe parameters
+    # get non-expert parameters
     states = get_non_moe_state_dict(states)
     topo = get_model_topology(model)
@@ -98,7 +98,7 @@ def save_model_checkpoint(folder, model):
         topo_fp = os.path.join(folder, topo_fn)
         llm_save(topo_fp, saved_obj=topo)
-    # move the judgement logic into save_moe_checkpoint(.)
+    # try to save expert parameters to separate files if the model has moe layers
     try_save_moe_checkpoint(folder, model)
     torch.distributed.barrier()
@@ -147,6 +147,7 @@ def load_model_checkpoint(folder, model):
         print("load: ", states[key].float(),flush=True)
     """
+    # try to load expert parameters from separate files if the model has moe layers
    try_load_moe_checkpoint(folder, model, states)
    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
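get_non_moe_state_dict strips expert weights from the regular checkpoint, and try_save_moe_checkpoint / try_load_moe_checkpoint handle them in separate per-rank files. A minimal sketch of the filtering step (the helper name and the key pattern are assumptions):

def strip_expert_params(state_dict, marker="moe_layer.experts"):
    # keep only entries whose names do not contain the expert marker;
    # expert weights are saved and loaded separately per expert-parallel rank
    return {name: tensor for name, tensor in state_dict.items() if marker not in name}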