mirror of https://github.com/InternLM/InternLM

fix(model): add IS_SEQUENCE_PARALLEL check for norm module (#528)

* add IS_SEQUENCE_PARALLEL check for norm module
* fix lint
* remove comments
* replace named_children with named_modules
* fix lint
* fix the spelling bug and move the sequence-parallel check to training_internlm

pull/530/head
parent 2dbbab7418
commit c581cc4c02
internlm/train/training_internlm.py

@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    # check whether the norm module has IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
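Note that the check added above only verifies a flag that must have been attached to the norm parameters elsewhere during model construction. Below is a minimal, self-contained sketch (plain PyTorch, not InternLM code) of how such tagging could look; the string constant IS_SEQUENCE_PARALLEL and the helper tag_norm_params are illustrative assumptions, not the repository's API.

from torch import nn

# Assumed attribute name; InternLM defines its own constant in internlm.core.context.
IS_SEQUENCE_PARALLEL = "is_sequence_parallel"


def tag_norm_params(model: nn.Module) -> None:
    """Hypothetical helper: flag every LayerNorm parameter as sequence-parallel."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)


# Usage: tag a toy model, then assert the flag the same way the check above does.
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
tag_norm_params(toy)
for module in toy.modules():
    if isinstance(module, nn.LayerNorm):
        for param in module.parameters():
            assert hasattr(param, IS_SEQUENCE_PARALLEL)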
internlm/utils/parallel.py

@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn

-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform

+RMSNorm = try_import_RMSNorm()
+

 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
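The RMSNorm symbol used by the new check comes from try_import_RMSNorm(), which resolves an RMSNorm implementation at import time. A generic sketch of that try-import pattern is shown below; the fallback class and the apex import are illustrative assumptions, not InternLM's actual implementation.

import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Plain-PyTorch RMSNorm fallback: scale inputs by the inverse RMS of the last dim."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


def try_import_rmsnorm():
    """Illustrative resolver: prefer a fused kernel if its package is installed, else fall back."""
    try:
        from apex.normalization import FusedRMSNorm  # optional dependency; may be absent
        return FusedRMSNorm
    except ImportError:
        return SimpleRMSNorm


RMSNorm = try_import_rmsnorm()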
				
			
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
                     layer_param_name = f"{layer_name}-{param_name}"
                     param.__setattr__("layer_name", layer_name)
                     param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    check whether the norm module has IS_SEQUENCE_PARALLEL attribute.
+    when the sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
+    to illustrate the norm should conduct the all-reduce for its grad.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when the gpc.config.parallel.sequence parallel is True,"
+                        "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
+                    )
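The docstring above says the flag marks parameters whose gradients must be all-reduced under sequence parallelism. As a rough illustration of what that implies, here is a hedged sketch of such a reduction step in plain torch.distributed; the function name, the attribute string, and the choice of process group are assumptions, not InternLM's actual gradient-sync code.

import torch.distributed as dist
from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed attribute name


def allreduce_sequence_parallel_grads(model: nn.Module, group=None) -> None:
    """Hypothetical post-backward step: sum-reduce gradients of flagged (norm) parameters."""
    for param in model.parameters():
        if getattr(param, IS_SEQUENCE_PARALLEL, False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)

In a real run such a reduction would execute after backward and before the optimizer step, over the process group along which the sequence is split.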