Merge pull request #4 from blankde/feature_add_moe_refactor_zl

refactor code
Ryan (张磊) authored on 2023-09-22 14:22:45 +08:00, committed by GitHub
commit aa7645a831
6 changed files with 11 additions and 33 deletions

@@ -646,7 +646,8 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """
         assert (
@@ -1316,8 +1317,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
-                The loss would be returned only in the last stage.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """
         assert (
             forward_only or return_loss
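
The two docstring updates above describe the new return contract: the schedulers hand back (output, label, loss, moe_loss), with loss produced only on the last pipeline stage and moe_loss accumulated on every stage. The following is a minimal, self-contained sketch of that accumulation pattern, not the InternLM scheduler itself; the function names and the dummy forward are illustrative assumptions.

import torch

def run_microbatches(microbatches, forward_fn, is_last_stage: bool):
    """Sketch only: forward_fn is assumed to return (output, loss, moe_loss)."""
    accum_moe_loss = torch.zeros(1)              # every stage accumulates its own MoE loss
    accum_loss = torch.zeros(1) if is_last_stage else None
    outputs, labels = [], []
    for data, label in microbatches:
        output, loss, moe_loss = forward_fn(data, label)
        accum_moe_loss += moe_loss.detach()      # accumulated across all micro-batches / stages
        if is_last_stage:                        # the main loss only exists on the last stage
            accum_loss += loss.detach()
            outputs.append(output)
            labels.append(label)
    return outputs, labels, accum_loss, accum_moe_loss

# Tiny usage example with a dummy forward function.
if __name__ == "__main__":
    def dummy_forward(data, label):
        out = data * 2.0
        return out, ((out - label) ** 2).mean(), torch.tensor(0.01)

    mbs = [(torch.randn(4), torch.randn(4)) for _ in range(3)]
    _, _, loss, moe_loss = run_microbatches(mbs, dummy_forward, is_last_stage=True)
    print(loss.item(), moe_loss.item())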

@@ -203,7 +203,7 @@ class Trainer:
             **kwargs: Additional keyword arguments.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
         """
         output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
         return output, label, loss, moe_loss
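
Because forward_backward_step now yields a four-element tuple, callers of the Trainer wrapper have to unpack moe_loss as well. The sketch below assumes a wrapper method named trainer.execute_schedule with the forward_only/return_loss keywords from the scheduler docstrings; that name and the way the two losses are combined are assumptions, not part of this diff.

from typing import Optional, Tuple
import torch

def train_step(trainer, data_iter) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Sketch: run one scheduled step and surface both losses.

    `trainer.execute_schedule` is an assumed name for the Trainer wrapper shown
    above; only the last pipeline stage gets a non-None `loss`, while `moe_loss`
    is accumulated from all stages.
    """
    _, _, loss, moe_loss = trainer.execute_schedule(data_iter, forward_only=False, return_loss=True)
    total = None if loss is None else loss + moe_loss
    return total, moe_loss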

@@ -9,12 +9,6 @@ from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
 # global llm logger
 logger = get_logger(__file__)

@@ -4,12 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
 from typing import Union, cast
 import torch

@@ -4,13 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 import torch

@@ -538,8 +538,7 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _compute_norm_with_moe_group(self, group_id):
         params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
-        # wo do not get the average grad for moe parameters, so we have to constuct the gradients list here.
-        # Maybe this can be optimized.
+        # we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
         grads = [p.grad for p in params]
         if len(params) == 0:
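
The retained comment explains that MoE parameters have no pre-averaged gradient, so the gradient list is built explicitly before the norm is taken. A hedged sketch of that kind of norm computation follows; the helper name and the 2-norm default are assumptions rather than the body of _compute_norm_with_moe_group.

import torch

def grad_norm_from_params(params, norm_type: float = 2.0) -> float:
    # Build the gradient list explicitly, skipping parameters without gradients.
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return 0.0
    # The norm of the per-tensor norms equals the norm over all gradient elements.
    return torch.norm(torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type).item()

# Quick check of the sketch:
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 2.0)
print(grad_norm_from_params([p]))  # 2 * sqrt(3), roughly 3.46
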
@@ -696,14 +695,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
             # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
-            if self._is_norm_group(self.optim.param_groups[group_id]):
-                dist.all_reduce(
-                    flat_fp32_avg_grads,
-                    op=dist.ReduceOp.AVG,
-                    group=gpc.get_group(ParallelMode.TENSOR),
-                )
-
-            if self._is_gate_group(self.optim.param_groups[group_id]):
+            is_tp_sync_groups = (
+                self._is_norm_group(self.optim.param_groups[group_id]),
+                self._is_gate_group(self.optim.param_groups[group_id]),
+            )
+            if any(is_tp_sync_groups):
                 dist.all_reduce(
                     flat_fp32_avg_grads,
                     op=dist.ReduceOp.AVG,
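
The final hunk collapses two nearly identical branches into one: the norm-group and gate-group checks are gathered into a tuple and a single dist.all_reduce is gated with any(). Below is a stand-alone sketch of that pattern, using placeholder booleans and a placeholder sync callable instead of the optimizer internals.

from typing import Callable, Sequence

def sync_if_needed(payload, checks: Sequence[bool], sync: Callable) -> bool:
    """Run the (expensive) sync at most once if any check holds."""
    if any(checks):
        sync(payload)
        return True
    return False

if __name__ == "__main__":
    synced = []
    # Placeholder booleans standing in for _is_norm_group(...) and _is_gate_group(...).
    is_tp_sync_groups = (False, True)
    sync_if_needed("flat_fp32_avg_grads", is_tp_sync_groups, synced.append)
    print(synced)  # ['flat_fp32_avg_grads']: one sync, even though two conditions were checked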