add comments for moe

pull/182/head
Wenwen Qu 2023-08-25 19:03:31 +08:00
parent aa2612edc4
commit 629e6a5ad1
7 changed files with 23 additions and 24 deletions

View File

@ -349,7 +349,7 @@ class Initializer_Zero1(ProcessGroupInitializer):
class Initializer_Expert(ProcessGroupInitializer):
"""A ProcessGroupInitializer for zero-1 parallelism.
"""A ProcessGroupInitializer for expert parallelism.
Args:
rank (int): The rank of current process.
@ -406,7 +406,7 @@ class Initializer_Expert(ProcessGroupInitializer):
class Initializer_Expert_Data(ProcessGroupInitializer):
"""A ProcessGroupInitializer for zero-1 parallelism.
"""A ProcessGroupInitializer for expert data parallelism.
Args:
rank (int): The rank of current process.
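For readers new to the two initializers, here is a minimal sketch of how the two kinds of rank groups relate; the function names and the simple consecutive-rank layout are illustrative assumptions, not the actual group-building logic of Initializer_Expert / Initializer_Expert_Data.

def expert_parallel_groups(world_size, expert_parallel_size):
    # e.g. world_size=8, expert_parallel_size=4 -> [[0, 1, 2, 3], [4, 5, 6, 7]]
    # the experts of one MoE layer are sharded across the ranks of each group
    assert world_size % expert_parallel_size == 0
    return [
        list(range(start, start + expert_parallel_size))
        for start in range(0, world_size, expert_parallel_size)
    ]

def expert_data_parallel_groups(world_size, expert_parallel_size):
    # e.g. world_size=8, expert_parallel_size=4 -> [[0, 4], [1, 5], [2, 6], [3, 7]]
    # ranks holding the same expert shard all-reduce that expert's gradients here
    assert world_size % expert_parallel_size == 0
    num_groups = world_size // expert_parallel_size
    return [
        [offset + g * expert_parallel_size for g in range(num_groups)]
        for offset in range(expert_parallel_size)
    ]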

View File

@ -105,6 +105,7 @@ class NonPipelineScheduler(BaseScheduler):
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer
output, moe_losses = self._call_engine(engine, data)
self._call_hooks("after_forward", output)

View File

@ -278,6 +278,7 @@ class PipelineScheduler(BaseScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer in current stage
output_obj, moe_losses = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
@ -345,6 +346,9 @@ class PipelineScheduler(BaseScheduler):
else:
# scale the latent loss
moe_loss = moe_loss * engine.optimizer.loss_scale
# we perform the chain rule here by projecting the grad to the direction of
# [output_obj_grad, 1]. Because moe_loss has no relation to the subsequent
# layers, its grad is set to None (it will be regarded as 1).
engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
# Collect the grad of the input_obj.
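The rewritten comment describes the chain rule applied by backward_by_grad: the stage output receives the gradient from the next stage, while the scalar moe_loss is paired with None, which autograd treats as a gradient of 1. A small self-contained check of that behavior with plain torch.autograd.backward:

import torch

x = torch.randn(4, requires_grad=True)
output_obj = x * 2.0              # stands in for the stage output
moe_loss = (x ** 2).sum()         # scalar auxiliary loss of this stage
output_obj_grad = torch.ones(4)   # gradient received from the next stage

# None for the scalar moe_loss is interpreted as a gradient of 1
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad)                     # 2.0 * output_obj_grad + 2 * x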

View File

@ -175,6 +175,7 @@ class PackedFlashBaseLayer1D(nn.Module):
]
)
# residual MoE network, see https://arxiv.org/pdf/2201.05596.pdf; seems useful for convergence
if moe_use_residual:
residual_mlp = FeedForward(
hidden_size,
@ -186,6 +187,7 @@ class PackedFlashBaseLayer1D(nn.Module):
dtype=torch.float,
)
# replace mlp by MoE module. The expert in MoE is a FeedForward module.
self.mlp = MoE(
hidden_size=hidden_size,
experts=experts,
@ -291,9 +293,9 @@ class PackedFlashBaseLayer1D(nn.Module):
# MLP.
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
if self.num_experts <= 1:
if self.num_experts <= 1: # dense mlp output
hidden_states = self.mlp(hidden_states)
else:
else: # MoE output
hidden_states, moe_loss, _ = self.mlp(hidden_states)
return hidden_states + residual, moe_loss
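The new comments point to residual MoE (https://arxiv.org/pdf/2201.05596.pdf), where a dense FeedForward path is mixed with the gated expert output. Below is a sketch of that combination, assuming a learned two-way mixing coefficient as in the paper; module and attribute names are hypothetical, not InternLM's exact implementation.

import torch
from torch import nn

class ResidualMoESketch(nn.Module):
    def __init__(self, moe, residual_mlp, hidden_size):
        super().__init__()
        self.moe = moe                                # gated expert layer
        self.residual_mlp = residual_mlp              # dense FeedForward path
        self.coefficient = nn.Linear(hidden_size, 2)  # learned mixing weights

    def forward(self, hidden_states):
        expert_out, moe_loss, _ = self.moe(hidden_states)
        dense_out = self.residual_mlp(hidden_states)
        coef = torch.softmax(self.coefficient(hidden_states), dim=-1)
        out = expert_out * coef[..., 0:1] + dense_out * coef[..., 1:2]
        return out, moe_loss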

View File

@ -92,7 +92,6 @@ class HybridZeroOptimizer(BaseOptimizer):
cpu_offload=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
has_moe: bool = False,
param_bcast_sync_handler: ParamBcastSyncHandler = None,
):
# DynamicGradScaler related args
@ -115,8 +114,6 @@ class HybridZeroOptimizer(BaseOptimizer):
super().__init__(optim=optimizer)
self.has_moe = has_moe
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._cpu_offload = cpu_offload
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@ -273,12 +270,6 @@ class HybridZeroOptimizer(BaseOptimizer):
def num_param_groups(self):
return len(self._fp16_param_groups)
def _get_real_dp_process_group(self, param_groups):
if "moe" in param_groups.keys() and param_groups["moe"]:
return ParallelMode.EXPERT_DATA
else:
return ParallelMode.DATA
def _partition_param_list(self, param_group):
no_params_ranks = []
params_per_rank = [[] for _ in range(self._zero_world_size)]
@ -287,7 +278,8 @@ class HybridZeroOptimizer(BaseOptimizer):
param_list = param_group["params"]
if self._is_moe_group(param_group):
# just add current params to params_per_rank[_zero_local_rank]
# for the moe group, we do not need to partition the params; just add the
# current params to params_per_rank[_zero_local_rank]
params_per_rank[self._zero_local_rank] = list(param_list)
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
no_params_ranks = list(range(self._zero_world_size))
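The expanded comment explains that the MoE parameter group is exempt from ZeRO partitioning: every rank keeps its own expert parameters whole, since experts are already sharded across expert-parallel ranks and reduced over the EXPERT_DATA group. A simplified sketch of that branching, with names shortened for illustration:

def partition_params(param_list, is_moe_group, zero_world_size, zero_local_rank):
    params_per_rank = [[] for _ in range(zero_world_size)]
    if is_moe_group:
        # expert params are not sharded by ZeRO: assign them all to the local rank
        params_per_rank[zero_local_rank] = list(param_list)
    else:
        # dense params are spread across ZeRO ranks (simple round-robin here;
        # the real optimizer balances the assignment by parameter size)
        for i, param in enumerate(param_list):
            params_per_rank[i % zero_world_size].append(param)
    return params_per_rank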
@ -538,8 +530,8 @@ class HybridZeroOptimizer(BaseOptimizer):
def _compute_norm_with_moe_group(self, group_id):
parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
# wo do not get the average grad for moe parameters, so we have to constuct
# the gradients list hear. Maybe this can be optimized.
# we do not get the average grad for moe parameters, so we have to construct the gradients list here.
# Maybe this can be optimized.
gradients = [p.grad for p in parameters]
norm = compute_norm(
gradients=gradients,
@ -666,8 +658,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# get the global norm
global_norm_groups = []
if self._clip_grad_norm > 0:
for group_id in range(self.num_param_groups):
global_norm_groups.append(norms[group_id] ** 0.5)
for norm in norms:
global_norm_groups.append(norm**0.5)
# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
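The loop now takes norm ** 0.5 for each entry of norms, which indicates those entries are already-reduced sums of squared gradient values; the global L2 norm per group is their square root. A toy check of that identity (the two-rank split is only for illustration):

import math

local_sq_norms = [sum(g * g for g in [0.3, -0.4]),   # "rank 0" gradients
                  sum(g * g for g in [1.2])]          # "rank 1" gradients
global_sq_norm = sum(local_sq_norms)                  # what the all-reduce yields
print(math.sqrt(global_sq_norm))                      # 1.3, the true L2 norm of [0.3, -0.4, 1.2]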

View File

@ -24,7 +24,7 @@ from internlm.data.packed_dataset import (
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.moe import create_moe_param_groups, has_moe_layers
from internlm.model.moe import create_moe_param_groups
from internlm.monitor import set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
@ -99,6 +99,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
param_bcast_sync_handler = None
adam_cfg = gpc.config.adam
# split the moe parameters into different groups
if gpc.config.model.num_experts > 1:
params = create_moe_param_groups(model, adam_cfg.weight_decay)
else:
@ -110,12 +111,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
eps=adam_cfg.adam_eps,
)
has_moe = has_moe_layers(model)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
has_moe=has_moe,
param_bcast_sync_handler=param_bcast_sync_handler,
)
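With has_moe removed from the HybridZeroOptimizer constructor, MoE handling is driven by the parameter groups that create_moe_param_groups builds. A hedged sketch of what such a grouping could look like: the ".experts." name check and the group layout are assumptions, though the "moe" flag mirrors the marker checked by the removed _get_real_dp_process_group.

from torch import nn

def split_moe_param_groups(model: nn.Module, weight_decay):
    # hypothetical sketch; the real create_moe_param_groups may use a
    # different criterion than parameter names
    dense_params, expert_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (expert_params if ".experts." in name else dense_params).append(param)
    return [
        {"name": "default", "params": dense_params, "weight_decay": weight_decay},
        # the "moe" flag lets the optimizer route this group to the
        # EXPERT_DATA process group and skip ZeRO partitioning for it
        {"name": "moe", "params": expert_params, "weight_decay": weight_decay, "moe": True},
    ]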
@ -377,7 +376,7 @@ def record_current_batch_training_metrics(
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
"grad_norm": grad_norm, # TODO: not scalar
}
infos["micro_num"] = len(batch[1])

View File

@ -73,7 +73,7 @@ def save_model_checkpoint(folder, model):
"""
states = model.state_dict()
# get non-moe parameters
# get non-expert parameters
states = get_non_moe_state_dict(states)
topo = get_model_topology(model)
@ -98,7 +98,7 @@ def save_model_checkpoint(folder, model):
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)
# move the judgement logic into save_moe_checkpoint(.)
# try to save expert parameters to separate files if the model has moe layers
try_save_moe_checkpoint(folder, model)
torch.distributed.barrier()
@ -147,6 +147,7 @@ def load_model_checkpoint(folder, model):
print("load: ", states[key].float(),flush=True)
"""
# try to load expert parameters from separate files if the model has moe layers
try_load_moe_checkpoint(folder, model, states)
missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
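The corrected comments describe the split checkpoint layout: expert parameters are stripped from the dense state dict and saved/loaded per expert rank. A sketch of the filtering step; the ".experts." key pattern is an assumption about the parameter naming, not the exact logic of get_non_moe_state_dict.

def filter_non_expert_state(states):
    # drop expert weights from the dense checkpoint so that
    # try_save_moe_checkpoint can store them in separate per-expert files
    return {k: v for k, v in states.items() if ".experts." not in k}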