mirror of https://github.com/InternLM/InternLM

commit aa7645a831

@@ -646,7 +646,8 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """
 
         assert (

@@ -1316,8 +1317,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             return_output_label (bool, optional): If False, the output and label won't be returned.
 
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
-                The loss would be returned only in the last stage.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
+                The loss would be returned only in the last stage. And the moe_loss is accumulated from all stages.
         """
         assert (
             forward_only or return_loss
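
The updated docstrings above describe the new return contract of the pipeline schedulers: the task loss still only exists on the last stage, while the MoE auxiliary loss is accumulated from all stages. A rough, hypothetical sketch of that accumulation (names such as accumulate_losses and microbatches are illustrative and not InternLM's actual scheduler code):

# Hypothetical sketch: accumulate an MoE auxiliary loss over micro-batches and
# sum it across pipeline stages; only the last stage owns the task loss.
import torch
import torch.distributed as dist


def accumulate_losses(model, criterion, microbatches, is_last_stage, pipeline_group=None):
    accum_loss = torch.zeros(1)
    accum_moe_loss = torch.zeros(1)
    for inputs, label in microbatches:
        output, moe_loss = model(inputs)  # assume the model also returns its gate loss
        accum_moe_loss += moe_loss.detach()
        if is_last_stage:
            accum_loss += criterion(output, label).detach()
    if dist.is_available() and dist.is_initialized():
        # every stage holds a local gate loss, so sum it over the pipeline group
        dist.all_reduce(accum_moe_loss, group=pipeline_group)
    return (accum_loss if is_last_stage else None), accum_moe_loss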

@@ -203,7 +203,7 @@ class Trainer:
             **kwargs: Additional keyword arguments.
 
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
         """
         output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
         return output, label, loss, moe_loss
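
With the Trainer docstring now matching the code beneath it, callers unpack four values from execute_schedule. A minimal, hypothetical caller-side sketch (the trainer and train_iter objects, and the keyword arguments forwarded through **kwargs, are assumptions about the surrounding training script):

# Hypothetical usage sketch of the four-value return of Trainer.execute_schedule.
def train_step(trainer, train_iter):
    output, label, loss, moe_loss = trainer.execute_schedule(
        train_iter, forward_only=False, return_loss=True, return_output_label=False
    )
    # loss is only present on the last pipeline stage; moe_loss is aggregated over stages.
    if loss is not None:
        print(f"loss={loss.item():.4f}  moe_loss={moe_loss.item():.4f}")
    return loss, moe_loss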

@@ -9,12 +9,6 @@ from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
 
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-
 # global llm logger
 logger = get_logger(__file__)
 

@@ -4,12 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
  Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
  We retain the following license from the original files:
 """
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
 from typing import Union, cast
 
 import torch

@@ -4,13 +4,6 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
  Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
  We retain the following license from the original files:
 """
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 
 import torch

@@ -538,8 +538,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
     def _compute_norm_with_moe_group(self, group_id):
         params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
-        # wo do not get the average grad for moe parameters, so we have to constuct the gradients list here.
-        # Maybe this can be optimized.
+        # we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
         grads = [p.grad for p in params]
 
         if len(params) == 0:
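
As the comment above notes, gradients for MoE parameters are not averaged with the rest, so _compute_norm_with_moe_group builds its gradient list by hand before computing a norm. A generic, standalone sketch of an L2 norm over such a list (this is not the optimizer's actual norm routine, which also reduces across parallel ranks):

# Hypothetical sketch: L2 norm over an explicitly constructed gradient list.
import torch


def l2_norm_of_grads(params):
    grads = [p.grad for p in params if p.grad is not None]
    if len(grads) == 0:
        return 0.0
    per_grad_norms = torch.stack([g.detach().norm(2) for g in grads])
    return per_grad_norms.norm(2).item()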

@@ -696,14 +695,11 @@ class HybridZeroOptimizer(BaseOptimizer):
 
             # Parameters shared within a TP group, such as norm and moe gate, have precision inconsistency in gradients.
             # Therefore, it is recommended to synchronize gradients within the TP group to eliminate accumulated errors.
-            if self._is_norm_group(self.optim.param_groups[group_id]):
-                dist.all_reduce(
-                    flat_fp32_avg_grads,
-                    op=dist.ReduceOp.AVG,
-                    group=gpc.get_group(ParallelMode.TENSOR),
-                )
-
-            if self._is_gate_group(self.optim.param_groups[group_id]):
+            is_tp_sync_groups = (
+                self._is_norm_group(self.optim.param_groups[group_id]),
+                self._is_gate_group(self.optim.param_groups[group_id]),
+            )
+            if any(is_tp_sync_groups):
                 dist.all_reduce(
                     flat_fp32_avg_grads,
                     op=dist.ReduceOp.AVG,
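
This hunk removes a duplicated branch: instead of issuing the same TP-group all-reduce separately for the norm group and the gate group, the two membership checks are collected into a tuple and the all-reduce runs once if any of them holds. The same pattern in isolation, with illustrative standalone names:

# Hypothetical illustration of the refactor: fold several boolean checks into a
# tuple and trigger the shared action once via any().
def sync_grads_if_shared(param_group, flat_grads, is_norm_group, is_gate_group, all_reduce_avg):
    is_tp_sync_groups = (
        is_norm_group(param_group),
        is_gate_group(param_group),
    )
    if any(is_tp_sync_groups):
        # one shared all-reduce replaces two identical branches
        all_reduce_avg(flat_grads)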