mirror of https://github.com/InternLM/InternLM
refactor code
parent 17bc5f562b
commit 80972ff314
@@ -646,7 +646,8 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """

         assert (
@@ -1316,8 +1317,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             return_output_label (bool, optional): If False, the output and label won't be returned.

         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
-                The loss would be returned only in the last stage.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """
         assert (
             forward_only or return_loss
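Note on the docstring change in both schedulers: the diff only documents the new return value, not how it is produced. Below is a minimal sketch, not the InternLM implementation, of how a moe_loss "accumulated from all stages" could be formed; the helper name reduce_moe_loss, the list stage_moe_losses of per-micro-batch auxiliary losses, and the pipeline_group argument are all illustrative assumptions.

import torch
import torch.distributed as dist

def reduce_moe_loss(stage_moe_losses, pipeline_group=None):
    # Sum the auxiliary (load-balancing) loss produced by this stage's MoE
    # layers over all micro-batches of the current step.
    if len(stage_moe_losses) == 0:
        return torch.zeros(1)
    moe_loss = torch.stack(stage_moe_losses).sum()
    # Accumulate the contributions of every pipeline stage, matching
    # "the moe_loss is accumulated from all stages" in the docstrings above.
    if pipeline_group is not None and dist.is_initialized():
        dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=pipeline_group)
    return moe_loss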
@@ -203,7 +203,7 @@ class Trainer:
             **kwargs: Additional keyword arguments.

         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
         """
         output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
         return output, label, loss, moe_loss
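For callers, the practical effect of the Trainer change is one extra value to unpack. A hedged usage sketch follows, assuming the enclosing method is Trainer.execute_schedule and that the trainer and data iterator are built elsewhere; only the four-value unpacking comes from this diff.

def train_step(trainer, train_iter):
    # `trainer` is assumed to expose execute_schedule(); the keyword arguments
    # mirror the scheduler docstrings above.
    output, label, loss, moe_loss = trainer.execute_schedule(
        train_iter,
        forward_only=False,
        return_loss=True,
        return_output_label=False,
    )
    # loss is only populated on the last pipeline stage, while moe_loss is
    # accumulated from all stages, so guard against None before combining.
    if loss is not None and moe_loss is not None:
        return loss + moe_loss
    return loss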
@@ -9,12 +9,6 @@ from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger

-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-
 # global llm logger
 logger = get_logger(__file__)
@@ -4,12 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
 from typing import Union, cast

 import torch
@@ -4,13 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

 import torch
@@ -538,8 +538,7 @@ class HybridZeroOptimizer(BaseOptimizer):

     def _compute_norm_with_moe_group(self, group_id):
         params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
-        # wo do not get the average grad for moe parameters, so we have to constuct the gradients list here.
-        # Maybe this can be optimized.
+        # we do not get the average grad for moe parameters, so we have to construct the gradients list here.
         grads = [p.grad for p in params]

         if len(params) == 0:
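A self-contained sketch of the pattern in _compute_norm_with_moe_group: because no averaged flat gradient buffer exists for MoE parameters, their gradients are collected one by one and a norm is taken over the list. The helper moe_grad_norm below is a stand-in, not InternLM's internal API.

import torch

def moe_grad_norm(params):
    # Collect each parameter's gradient directly, mirroring
    # `grads = [p.grad for p in params]` in the hunk above.
    grads = [p.grad for p in params if p.grad is not None]
    if len(grads) == 0:
        return 0.0
    # L2 norm over the whole gradient list.
    return torch.stack([g.detach().norm(2) for g in grads]).norm(2).item()

# Tiny usage example with a dummy parameter.
w = torch.nn.Parameter(torch.ones(4))
w.grad = torch.full((4,), 0.5)
print(moe_grad_norm([w]))  # the single gradient has norm 1.0, so the total is 1.0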
@@ -696,14 +695,9 @@ class HybridZeroOptimizer(BaseOptimizer):

         # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
         # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
-        if self._is_norm_group(self.optim.param_groups[group_id]):
-            dist.all_reduce(
-                flat_fp32_avg_grads,
-                op=dist.ReduceOp.AVG,
-                group=gpc.get_group(ParallelMode.TENSOR),
-            )
-
-        if self._is_gate_group(self.optim.param_groups[group_id]):
+        is_tp_shared_params = (self._is_norm_group(self.optim.param_groups[group_id])
+                               or self._is_gate_group(self.optim.param_groups[group_id]))
+        if is_tp_shared_params:
             dist.all_reduce(
                 flat_fp32_avg_grads,
                 op=dist.ReduceOp.AVG,
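A hedged sketch of what the consolidated branch above does: parameters replicated inside a tensor-parallel group (layer norms, the MoE gate) can accumulate slightly different gradients on each TP rank, so their flattened fp32 gradient is averaged across that group. The names flat_grads and tp_group are illustrative stand-ins for the optimizer's internals.

import torch
import torch.distributed as dist

def sync_tp_shared_grads(flat_grads: torch.Tensor, tp_group) -> None:
    # Average the flattened fp32 gradients of TP-shared parameters across the
    # tensor-parallel group to eliminate accumulated precision errors.
    # ReduceOp.AVG requires a backend that supports it (e.g. NCCL).
    if dist.is_initialized() and dist.get_world_size(group=tp_group) > 1:
        dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=tp_group)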