refactor code

pull/182/head
zhanglei 2023-09-22 11:47:05 +08:00
parent 17bc5f562b
commit 80972ff314
6 changed files with 9 additions and 33 deletions

View File

@ -646,7 +646,8 @@ class PipelineScheduler(BaseScheduler):
return_loss (bool, optional): Whether returns the loss value. Default is true.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None.
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
"""
assert (
@ -1316,8 +1317,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
The loss would be returned only in the last stage.
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
"""
assert (
forward_only or return_loss

View File

@ -203,7 +203,7 @@ class Trainer:
**kwargs: Additional keyword arguments.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
"""
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss, moe_loss

View File

@ -9,12 +9,6 @@ from internlm.moe.experts import Experts
from internlm.moe.sharded_moe import MOELayer, TopKGate
from internlm.utils.logger import get_logger
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# global llm logger
logger = get_logger(__file__)

View File

@ -4,12 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Union, cast
import torch

View File

@ -4,13 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
import torch

View File

@ -538,8 +538,7 @@ class HybridZeroOptimizer(BaseOptimizer):
def _compute_norm_with_moe_group(self, group_id):
params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
# wo do not get the average grad for moe parameters, so we have to constuct the gradients list here.
# Maybe this can be optimized.
# we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
grads = [p.grad for p in params]
if len(params) == 0:
@ -696,14 +695,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
# Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
if self._is_norm_group(self.optim.param_groups[group_id]):
dist.all_reduce(
flat_fp32_avg_grads,
op=dist.ReduceOp.AVG,
group=gpc.get_group(ParallelMode.TENSOR),
)
if self._is_gate_group(self.optim.param_groups[group_id]):
is_tp_shared_params = (self._is_norm_group(self.optim.param_groups[group_id])
or self._is_gate_group(self.optim.param_groups[group_id]))
if is_tp_shared_params:
dist.all_reduce(
flat_fp32_avg_grads,
op=dist.ReduceOp.AVG,