mirror of https://github.com/InternLM/InternLM
fix(*)/all-reduce for norm in sequence parallel (#443)

* fix all-reduce norm grad
* change the order of dp and sp all-reduce
* fix lint
parent 949a0a1d55
commit 1d7e2d04ec
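Why the extra all-reduce is needed: under sequence parallelism the norm weights are replicated on every tensor-parallel rank, but each rank back-propagates through only its own slice of the sequence, so the gradient it accumulates for a norm weight is only a partial sum. Summing the per-rank gradients over the tensor-parallel group recovers the full gradient, which is what this commit wires up. A minimal single-process sketch of that identity, not part of the commit (the toy LayerNorm, the sum-reduced loss, and the 2-way split standing in for two sequence-parallel ranks are all illustrative):

    import torch

    ln = torch.nn.LayerNorm(4)
    x = torch.randn(6, 4)  # 6 "tokens", hidden size 4

    # gradient of the norm weight over the full sequence
    ln.zero_grad()
    ln(x).sum().backward()
    g_full = ln.weight.grad.clone()

    # gradient each of two "sequence-parallel ranks" would accumulate on its shard
    ln.zero_grad()
    ln(x[:3]).sum().backward()
    g_rank0 = ln.weight.grad.clone()

    ln.zero_grad()
    ln(x[3:]).sum().backward()
    g_rank1 = ln.weight.grad.clone()

    # an all-reduce (SUM) over the shards recovers the full gradient
    assert torch.allclose(g_full, g_rank0 + g_rank1)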
@@ -1,4 +1,5 @@
 from .parallel_context import (
+    IS_SEQUENCE_PARALLEL,
     IS_TENSOR_PARALLEL,
     Config,
     ParallelContext,
@@ -29,6 +30,7 @@ from .random import (
 __all__ = [
     "Config",
     "IS_TENSOR_PARALLEL",
+    "IS_SEQUENCE_PARALLEL",
     "global_context",
     "ParallelContext",
     "ParallelMode",
@@ -25,6 +25,7 @@ from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
 
 IS_TENSOR_PARALLEL = "is_tensor_parallel"
+IS_SEQUENCE_PARALLEL = "is_sequence_parallel"
 
 logger = get_logger(__file__)
 
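These module-level strings are attribute keys, not values: the model code below tags each norm parameter with setattr(param, IS_SEQUENCE_PARALLEL, True), and the optimizer later checks the tag with hasattr/getattr before deciding how to reduce that parameter's gradient. A tiny self-contained sketch of the convention (the parameter here is a stand-in, not repository code):

    import torch

    IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # attribute key

    weight = torch.nn.Parameter(torch.ones(8))
    setattr(weight, IS_SEQUENCE_PARALLEL, True)  # tag the parameter at model-build time

    # later, e.g. in the optimizer, the tag decides the reduction path
    if getattr(weight, IS_SEQUENCE_PARALLEL, False):
        print("this gradient needs an extra all-reduce over the tensor-parallel group")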
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
@@ -134,6 +134,12 @@ class PackedFlashBaseLayer1D(nn.Module):
         for _, param in self.mlp.named_parameters():
             if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                 setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
 
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
@@ -356,6 +362,10 @@ class PackedFlashInternLm1D(nn.Module):
                 normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+            for param in self.norm.parameters():
+                if gpc.config.parallel.sequence_parallel is True:
+                    setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context import Config, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
@@ -297,6 +297,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                             reduce_rank=reduce_rank,
                         )
 
+                        def reduction_sp_func():
+                            handle = reduce_tensor(
+                                param.grad,
+                                dtype=None,
+                                dst_rank=reduce_rank,
+                                parallel_mode=ParallelMode.TENSOR,
+                            )
+                            handle.wait()
+
                         # define hook
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here is not grad, but allow_unreachable and accumulate_grad
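reduce_tensor is defined elsewhere in the repository and its body is not part of this diff; judging only from the call above (it returns a handle that is waited on) and from the commit message (the norm gradient is all-reduced across the tensor-parallel group), a rough stand-in on top of plain torch.distributed might look like the sketch below. The function name, the explicit group argument, and the async all-reduce are assumptions for illustration, not the repository's implementation:

    import torch.distributed as dist

    def reduce_norm_grad_sp(grad, tp_group):
        # Hypothetical stand-in for what reduction_sp_func delegates to:
        # sum the locally accumulated norm gradient over the tensor-parallel group.
        handle = dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_group, async_op=True)
        handle.wait()  # block until the communication has finished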
@@ -304,6 +313,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                             if self.skip_grad_reduce is False:
                                 reduction_func()
 
+                        # define hook for sequence_parallel
+                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                            if self.skip_grad_reduce is False:
+                                reduction_sp_func()
+
+                        # if sequence_parallel is True, the grad of norm
+                        # should be all-reduced across the tp process group
+                        if gpc.config.parallel.sequence_parallel is True:
+                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                                accum_grad_obj_sp = get_grad_accumulate_object(param)
+                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+
                         accum_grad_obj.register_hook(reduce_grad_hook)
 
                     _define_and_attach(param, reduce_rank)
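get_grad_accumulate_object is also defined outside this diff; it presumably returns the parameter's AccumulateGrad autograd node, so that a hook can run right after the gradient has been written into param.grad. Registering reduce_grad_hook_sp before reduce_grad_hook lets the sequence-parallel all-reduce fire ahead of the usual data-parallel reduction, which seems to correspond to the "change the order of dp and sp all-reduce" item in the commit message. A self-contained sketch of the hook pattern (the helper name, the expand_as trick, and the group argument are illustrative assumptions, not the repository's code):

    import torch
    import torch.distributed as dist

    def attach_sp_allreduce_hook(param: torch.nn.Parameter, tp_group) -> None:
        # Reach the AccumulateGrad node through a trivial view of the leaf tensor;
        # this common (if unofficial) trick stands in for get_grad_accumulate_object.
        accum_grad = param.expand_as(param).grad_fn.next_functions[0][0]

        def hook(*args):  # the hook receives the node's inputs/outputs, not the grad
            if param.grad is not None:
                dist.all_reduce(param.grad, group=tp_group)

        accum_grad.register_hook(hook)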