mirror of https://github.com/InternLM/InternLM
refactor code
parent bba9b01c0e
commit 9665321745
@@ -491,7 +491,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-        if self.config.model.num_experts > 1:
+        if self.config.model.get("num_experts", 1) > 1:
             initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()
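A minimal sketch of why the hunk above switches to .get("num_experts", 1): with a dict-backed config (assumed here for illustration), reading a key that older, non-MoE configs never set would raise, while .get() falls back to a single-expert default. model_cfg is an illustrative name, not the repo's config object.

model_cfg = {"hidden_size": 4096}               # an older config with no "num_experts" entry

num_experts = model_cfg.get("num_experts", 1)   # -> 1, no KeyError/AttributeError
if num_experts > 1:
    print("initialize expert-parallel process groups")
else:
    print("dense model, skip expert-parallel groups")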
@@ -15,13 +15,10 @@ from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
-from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
 from .base_scheduler import BaseScheduler, SchedulerHook
 
-logger = get_logger(__file__)
-
 
 def get_tensor_shape():
     if hasattr(gpc.config, "TENSOR_SHAPE"):
@@ -1347,8 +1344,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)
 
-        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
-
         dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
 
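A minimal sketch of the reduction kept by the hunk above: summing an accumulated auxiliary MoE loss across the ranks of a pipeline process group. The helper name and the group argument are illustrative, not part of the diff.

import torch
import torch.distributed as dist

def reduce_moe_loss(accum_moe_loss: torch.Tensor, group=None) -> torch.Tensor:
    """Sum a scalar loss tensor over all ranks in `group` (in place) and return it."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(accum_moe_loss, op=dist.ReduceOp.SUM, group=group)
    return accum_moe_loss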
@@ -501,7 +501,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
     parts = all_parts[pipeline_rank]
     if gpc.is_rank_for_log():
-        logger.info(f"The layer sharding is {all_parts}.")  # pylint: disable=W1203
+        logger.info(f"The layer sharding is {all_parts}.")
 
     models = []
 
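A minimal illustrative sketch (not InternLM's actual partition_uniform) of the kind of layer sharding logged above: num_layers split as evenly as possible across pipeline_size stages, with each stage holding num_chunks contiguous chunks. Function and variable names are assumptions for illustration.

from typing import List, Tuple

def uniform_layer_parts(num_layers: int, pipeline_size: int, num_chunks: int = 1) -> List[List[Tuple[int, int]]]:
    total_parts = pipeline_size * num_chunks
    base, remainder = divmod(num_layers, total_parts)
    parts: List[List[Tuple[int, int]]] = [[] for _ in range(pipeline_size)]
    start = 0
    for chunk_idx in range(total_parts):
        # Early chunks absorb the remainder so chunk sizes differ by at most one layer.
        size = base + (1 if chunk_idx < remainder else 0)
        parts[chunk_idx % pipeline_size].append((start, start + size))
        start += size
    return parts

# e.g. uniform_layer_parts(32, pipeline_size=4, num_chunks=2)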
@@ -69,10 +69,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
-
 
 # Based on https://github.com/pytorch/pytorch/pull/40762
 class _AllToAll(torch.autograd.Function):
     """
@@ -477,10 +473,10 @@ class MOELayer(Base):
 
         combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
 
-        a = combined_output.reshape(inputs[0].shape)
+        out = combined_output.reshape(inputs[0].shape)
 
         if self.wall_clock_breakdown:
             timer("moe").stop()
             self.time_moe = timer("moe").elapsed(reset=False)
 
-        return a
+        return out
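A minimal sketch of the combine step in the hunk above, assuming the GShard-style shapes: combine_weights is (s)equence x (e)xperts x (c)apacity and expert_output is experts x capacity x (m)odel dim, so the einsum gathers each token's expert outputs back into a single (s, m) tensor. The concrete sizes below are illustrative only.

import torch

s, e, c, m = 8, 4, 2, 16                  # tokens, experts, capacity, model dim
combine_weights = torch.rand(s, e, c)     # routing weights (mostly zero in practice)
expert_output = torch.randn(e, c, m)      # per-expert, per-capacity-slot outputs

combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
print(combined.shape)                     # torch.Size([8, 16])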
@@ -592,7 +592,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_norms = []
         for group_id in range(self.num_param_groups):
             if self._is_moe_group(self.optim.param_groups[group_id]):
-                groups_norms.append([])
+                groups_norms.append(None)
             else:
                 groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
 
@@ -9,12 +9,23 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
     Compatiable with muiltiple param groups, each should have a name
 
     Args:
-        param_groups (Tuple[Dict]):
-            The list of parameter groups to split
+        param_groups (Tuple[Dict]): The list of parameter groups to split
+            Output Example:
+            >>> (
+            >>>     {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
+            >>>     ...,
+            >>> )
 
     Returns:
-        Tuple[Dict]:
-            list of MoE/non-MoE groups for optimizer
+        Tuple[Dict]: list of params groups for optimizer
+            Output Example:
+            >>> (
+            >>>     {'name': 'default','params': [tensor],'weight_decay' :xxx},
+            >>>     {'name': 'norm', 'norm': True, 'params': [tensor],'weight_decay' :xxx},
+            >>>     {'name': 'gate', 'gate': True, 'params': [tensor],'weight_decay' :xxx},
+            >>>     {'name': 'moe_ep_size_4', 'moe': True, 'params': [tensor],'weight_decay' :xxx},
+            >>>     ...,
+            >>> )
     """
     if isinstance(param_groups, tuple):
         param_groups = list(param_groups)  # Tuple cannot be modified
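A minimal illustrative sketch (not InternLM's actual implementation) of splitting one optimizer param group into the named sub-groups shown in the docstring's output example above. The is_norm_param / is_gate_param / is_moe_param predicates are hypothetical stand-ins for whatever parameter checks the real code uses.

from typing import Dict, List

def split_param_group(group: Dict, is_norm_param, is_gate_param, is_moe_param) -> List[Dict]:
    # Copy shared hyperparameters (e.g. weight_decay) into every sub-group.
    base = {k: v for k, v in group.items() if k != "params"}
    buckets = {"default": [], "norm": [], "gate": [], "moe": []}
    for p in group["params"]:
        if is_moe_param(p):
            buckets["moe"].append(p)
        elif is_gate_param(p):
            buckets["gate"].append(p)
        elif is_norm_param(p):
            buckets["norm"].append(p)
        else:
            buckets["default"].append(p)
    out = [dict(base, name=group.get("name", "default"), params=buckets["default"])]
    if buckets["norm"]:
        out.append(dict(base, name="norm", norm=True, params=buckets["norm"]))
    if buckets["gate"]:
        out.append(dict(base, name="gate", gate=True, params=buckets["gate"]))
    if buckets["moe"]:
        out.append(dict(base, name="moe", moe=True, params=buckets["moe"]))
    return out

Keeping the split as a list of dicts mirrors the docstring's output shape, so the result can be passed straight to a torch optimizer constructor.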
|