mirror of https://github.com/InternLM/InternLM

refactor code

parent bba9b01c0e
commit 9665321745

@ -491,7 +491,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-        if self.config.model.num_experts > 1:
+        if self.config.model.get("num_experts", 1) > 1:
             initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()

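The switch from self.config.model.num_experts to self.config.model.get("num_experts", 1) keeps the expert process group optional: a config that never mentions num_experts no longer fails on attribute lookup and simply falls back to the dense (one-expert) case. Below is a minimal standalone sketch of the same guard, using a hypothetical dict-backed config object rather than InternLM's real gpc.config:

    # Sketch only, not InternLM code: a dict-backed config that also allows attribute access.
    class Config(dict):
        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError as exc:
                raise AttributeError(name) from exc

    model_cfg = Config(dtype="bf16")      # note: no "num_experts" key at all

    # Old-style access would fail when the key is missing:
    #   model_cfg.num_experts  -> AttributeError
    # The refactored guard degrades gracefully to the dense default instead:
    if model_cfg.get("num_experts", 1) > 1:
        print("initialize expert parallel group")
    else:
        print("no MoE configured: skip Initializer_Expert")
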
@ -15,13 +15,10 @@ from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout

from .base_scheduler import BaseScheduler, SchedulerHook

logger = get_logger(__file__)


def get_tensor_shape():
    if hasattr(gpc.config, "TENSOR_SHAPE"):

@ -1347,8 +1344,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
        else:
            output, label = (None, None)

-       logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
-
        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
        accum_moe_loss = self._accum_moe_loss

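The surviving lines sum the accumulated MoE auxiliary loss across the pipeline process group once per step rather than logging it from every rank. A self-contained sketch of that reduction pattern, with a one-process gloo group standing in for gpc.get_group(ParallelMode.PIPELINE):

    import torch
    import torch.distributed as dist

    # Sketch only: a single-process gloo group stands in for the real pipeline group.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1)

    accum_moe_loss = torch.zeros(1)
    for step_loss in (0.12, 0.08, 0.10):      # pretend per-microbatch auxiliary losses
        accum_moe_loss += step_loss

    # Sum the accumulated loss over every rank in the group (here: just one rank).
    dist.all_reduce(accum_moe_loss, op=dist.ReduceOp.SUM)
    print(accum_moe_loss.item())

    dist.destroy_process_group()
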
@ -501,7 +501,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    if gpc.is_rank_for_log():
-        logger.info(f"The layer sharding is {all_parts}.") # pylint: disable=W1203
+        logger.info(f"The layer sharding is {all_parts}.")

    models = []

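partition_uniform decides which transformer layers each pipeline rank (and chunk) owns; the change here only drops the now-unneeded pylint suppression on the log line. As a rough illustration of what such a "layer sharding" looks like, here is a hypothetical, much-simplified partitioner (not InternLM's partition_uniform) that returns per-rank lists of (start, end) layer ranges:

    def uniform_layer_parts(num_layers, pipeline_size, num_chunks=1):
        """Hypothetical simplification: split layers into pipeline_size * num_chunks
        contiguous blocks and hand each rank num_chunks of them."""
        total = pipeline_size * num_chunks
        base, rem = divmod(num_layers, total)
        bounds, start = [], 0
        for i in range(total):
            end = start + base + (1 if i < rem else 0)
            bounds.append((start, end))
            start = end
        # rank r gets blocks r, r + pipeline_size, r + 2 * pipeline_size, ...
        return [bounds[r::pipeline_size] for r in range(pipeline_size)]

    all_parts = uniform_layer_parts(num_layers=32, pipeline_size=4, num_chunks=2)
    print(f"The layer sharding is {all_parts}.")
    # e.g. rank 0 owns layers [0, 4) and [16, 20) when interleaving two chunks per rank
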
@ -69,10 +69,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    return gumbel(shape)


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """

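The _AllToAll class referenced above follows the autograd trick from the linked PyTorch PR: the backward pass of an all-to-all is simply another all-to-all applied to the gradients. A minimal sketch of that general pattern (not necessarily InternLM's exact implementation), assuming a process group has already been initialized:

    import torch
    import torch.distributed as dist


    class AllToAllSketch(torch.autograd.Function):
        """Autograd-aware all-to-all: gradients are routed back the way activations came."""

        @staticmethod
        def forward(ctx, group, inputs):
            ctx.group = group
            inputs = inputs.contiguous()
            output = torch.empty_like(inputs)
            # Exchange equal-sized shards of `inputs` between all ranks in `group`.
            dist.all_to_all_single(output, inputs, group=group)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # No gradient for the group argument; all-to-all the incoming gradient.
            return None, AllToAllSketch.apply(ctx.group, grad_output)
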
@ -477,10 +473,10 @@ class MOELayer(Base):

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)

-       a = combined_output.reshape(inputs[0].shape)
+       out = combined_output.reshape(inputs[0].shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

-       return a
+       return out

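Renaming a to out does not change the math: combine weights of shape (s, e, c) recombine the per-expert outputs of shape (e, c, m) into one (s, m) tensor per token, which is then reshaped back to the original input shape. A tiny shape check of that einsum, with illustrative sizes:

    import torch

    s, e, c, m = 8, 4, 2, 16                  # tokens, experts, capacity, model dim
    combine_weights = torch.rand(s, e, c)
    expert_output = torch.rand(e, c, m)

    # (s,e,c) x (e,c,m) -> (s,m): each token gathers its expert outputs weighted by the gate.
    combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
    out = combined_output.reshape(s, m)       # in the layer this restores inputs[0].shape
    print(out.shape)                          # torch.Size([8, 16])
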
@ -592,7 +592,7 @@ class HybridZeroOptimizer(BaseOptimizer):
        groups_norms = []
        for group_id in range(self.num_param_groups):
            if self._is_moe_group(self.optim.param_groups[group_id]):
-                groups_norms.append([])
+                groups_norms.append(None)
            else:
                groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

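Appending None rather than an empty list marks MoE parameter groups whose norms are not computed at this point; None is an unambiguous "not computed here" sentinel for the later pass that handles expert parameters separately. A small sketch of that control flow, with hypothetical stand-ins for _is_moe_group and _compute_norm_with_stage:

    import torch

    # Hypothetical stand-ins for the optimizer's helpers; only the control flow matters here.
    def is_moe_group(group):                  # assumption: MoE groups carry a "moe" flag
        return group.get("moe", False)

    def compute_norm(group):                  # placeholder for _compute_norm_with_stage
        return sum(float(p.norm()) ** 2 for p in group["params"])

    param_groups = [
        {"name": "default", "params": [torch.randn(4)]},
        {"name": "moe_ep_size_4", "moe": True, "params": [torch.randn(4)]},
    ]

    groups_norms = []
    for group in param_groups:
        # MoE groups get a None placeholder; their norm still needs an extra
        # reduction over the expert-parallel ranks, so it is filled in later.
        groups_norms.append(None if is_moe_group(group) else compute_norm(group))

    print(groups_norms)                       # [<float>, None]
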
@ -9,12 +9,23 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
    Compatiable with muiltiple param groups, each should have a name

    Args:
-       param_groups (Tuple[Dict]):
-           The list of parameter groups to split
+       param_groups (Tuple[Dict]): The list of parameter groups to split
+           Output Example:
+           >>> (
+           >>>     {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
+           >>>     ...,
+           >>> )

    Returns:
-       Tuple[Dict]:
-           list of MoE/non-MoE groups for optimizer
+       Tuple[Dict]: list of params groups for optimizer
+       Output Example:
+       >>> (
+       >>>     {'name': 'default','params': [tensor],'weight_decay' :xxx},
+       >>>     {'name': 'norm', 'norm': True, 'params': [tensor],'weight_decay' :xxx},
+       >>>     {'name': 'gate', 'gate': True, 'params': [tensor],'weight_decay' :xxx},
+       >>>     {'name': 'moe_ep_size_4', 'moe': True, 'params': [tensor],'weight_decay' :xxx},
+       >>>     ...,
+       >>> )
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups) # Tuple cannot be modified

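The updated docstring promises separate 'default', 'norm', 'gate', and 'moe_ep_size_*' groups in the optimizer's output. A hypothetical, much-simplified version of such a splitter, assuming expert parameters are tagged with an is_expert attribute (the real routing criteria live elsewhere in InternLM):

    import torch


    def split_param_groups_sketch(param_groups):
        """Hypothetical simplification: route each parameter into a named sub-group."""
        new_groups = []
        for pg in param_groups:
            base = {k: v for k, v in pg.items() if k != "params"}
            default = dict(base, name="default", params=[])
            moe = dict(base, name="moe_ep_size_4", moe=True, params=[])
            for p in pg["params"]:
                # assumption: MoE/expert parameters are tagged by the model, e.g. p.is_expert = True
                (moe if getattr(p, "is_expert", False) else default)["params"].append(p)
            new_groups.extend(g for g in (default, moe) if g["params"])
        return tuple(new_groups)


    dense = torch.nn.Parameter(torch.randn(2, 2))
    expert = torch.nn.Parameter(torch.randn(2, 2))
    expert.is_expert = True
    groups = split_param_groups_sketch([{"params": [dense, expert], "weight_decay": 0.01}])
    print([g["name"] for g in groups])        # ['default', 'moe_ep_size_4']
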