mirror of https://github.com/InternLM/InternLM
fix layer grad_norm with pp
parent 7920168179
commit 6ce78a4e09

@@ -150,7 +150,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
-        self.layer_names = {"unknown", "embedding", "norm", "head"}
+        self.layer_names = ["unknown"]
 
     @property
     def config(self):

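Note (editor's sketch, not part of the diff): the hunk above switches gpc.layer_names from a fixed set {"unknown", "embedding", "norm", "head"} to the list ["unknown"], so layer names can be appended while the model is walked (possibly with duplicates) and then deduplicated and sorted once, giving every pipeline-parallel rank the same ordered list for per-layer grad-norm reporting. The class names below are illustrative assumptions, not necessarily the actual InternLM modules.

# Minimal Python sketch of the collect-then-normalize pattern (assumed names).
layer_names = ["unknown"]                       # new default in ParallelContext

# Names are appended per block and per parameter, so duplicates are expected.
for idx in range(4):                            # pretend the model has 4 blocks
    for _ in range(3):                          # pretend 3 params per block
        layer_names.append(f"PackedFlashBaseLayer1D.{idx}")
layer_names += ["Embedding1D", "RMSNorm"]       # hypothetical first/last modules

# A final pass makes the list unique and deterministically ordered on every rank.
layer_names = sorted(set(layer_names))
print(layer_names)
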
@@ -24,6 +24,7 @@ from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
 from internlm.utils.logger import get_logger
+from internlm.utils.parallel import set_model_params_layer_name
 from internlm.utils.registry import MODEL_INITIALIZER
 
 MODEL_TYPE = "INTERNLM"

@@ -418,6 +419,21 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     if gpc.is_rank_for_log():
         logger.info(f"The layer sharding is {all_parts}.")
 
+    # config gpc.layer_name
+    # get names of first and last layers
+    kwargs["num_layers"] = 1
+    kwargs["device"] = device
+    kwargs["first"] = True
+    kwargs["last"] = True
+    kwargs["start_layer_idx"] = 0
+    tmp_chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).cpu()
+    # get names of middle layers
+    for idx in range(num_layers):
+        layer_name = f"{PackedFlashBaseLayer1D.__name__}.{idx}"
+        gpc.layer_names.append(layer_name)
+    set_model_params_layer_name(tmp_chunk)
+    torch.cuda.empty_cache()
+
     models = []
 
     for start, end in parts:

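Note (editor's sketch, not part of the diff): the added block builds a throwaway one-layer chunk on the CPU with first=True and last=True, so that even a pipeline rank that only owns middle layers records the names of the embedding/norm/head modules; the middle transformer blocks are then named by index and the temporary memory is released. A rough, self-contained illustration of that idea follows, using hypothetical stand-in classes instead of PackedFlashInternLm1D / PackedFlashBaseLayer1D.

import torch
import torch.nn as nn

class Block(nn.Module):                              # stand-in transformer block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class TinyModel(nn.Module):                          # stand-in for the full model
    def __init__(self, num_layers, first, last):
        super().__init__()
        if first:
            self.embedding = nn.Embedding(16, 8)
        self.blocks = nn.ModuleList(Block() for _ in range(num_layers))
        if last:
            self.head = nn.Linear(8, 16)

layer_names = ["unknown"]

# Throwaway 1-layer model on CPU, used only to harvest first/last layer names.
tmp_chunk = TinyModel(num_layers=1, first=True, last=True).cpu()
for _, child in tmp_chunk.named_children():
    if isinstance(child, nn.ModuleList):
        for idx, block in enumerate(child):
            layer_names.append(f"{block.__class__.__name__}.{idx}")
    else:
        layer_names.append(child.__class__.__name__)

# Middle layers this rank may never instantiate still get a name by index.
num_layers = 4
for idx in range(num_layers):
    layer_names.append(f"{Block.__name__}.{idx}")

layer_names = sorted(set(layer_names))
print(layer_names)

del tmp_chunk
torch.cuda.empty_cache()                             # mirrors the cleanup in the diff
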
@@ -26,6 +26,7 @@ from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
 from internlm.utils.logger import get_logger
+from internlm.utils.parallel import set_model_params_layer_name
 from internlm.utils.registry import MODEL_INITIALIZER
 
 MODEL_TYPE = "INTERNLM_MoE"

@@ -516,6 +517,21 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     if gpc.is_rank_for_log():
         logger.info(f"The layer sharding is {all_parts}.")
 
+    # config gpc.layer_name
+    # get names of first and last layers
+    kwargs["num_layers"] = 1
+    kwargs["device"] = device
+    kwargs["first"] = True
+    kwargs["last"] = True
+    kwargs["start_layer_idx"] = 0
+    tmp_chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).cpu()
+    # get names of middle layers
+    for idx in range(num_layers):
+        layer_name = f"{PackedFlashBaseLayer1D.__name__}.{idx}"
+        gpc.layer_names.append(layer_name)
+    set_model_params_layer_name(tmp_chunk)
+    torch.cuda.empty_cache()
+
     models = []
 
     for start, end in parts:

@@ -78,14 +78,16 @@ def set_model_params_layer_name(model):
         if isinstance(_chunk, NaiveAMPModel):
             _chunk = _chunk.model
         # Create a unique layer name based on the block's class name and index
-        for name, children in _chunk.named_children():
+        for _, children in _chunk.named_children():
             if isinstance(children, nn.ModuleList):
                 for idx, block in enumerate(children):
                     for param in block.parameters():
                         layer_name = f"{block.__class__.__name__}.{idx}"
-                        gpc.layer_names.add(layer_name)
+                        gpc.layer_names.append(layer_name)
                         param.__setattr__("layer_name", layer_name)
             else:
                 for param in children.parameters():
-                    gpc.layer_names.add(name)
-                    param.__setattr__("layer_name", name)
+                    layer_name = f"{children.__class__.__name__}"
+                    gpc.layer_names.append(layer_name)
+                    param.__setattr__("layer_name", layer_name)
+    gpc.layer_names = sorted(set(gpc.layer_names))

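Note (editor's sketch, not part of the diff): a small usage illustration for the reworked set_model_params_layer_name, with a toy module standing in for the real model. The point of the change is that every parameter carries a layer_name attribute (children inside an nn.ModuleList get "ClassName.idx", everything else gets its class name), and gpc.layer_names ends up as a sorted, duplicate-free list, so the per-layer grad-norm code can group parameters consistently even when layers are split across pipeline stages. The Toy class below is an assumption for illustration only.

import torch.nn as nn

class Toy(nn.Module):                       # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))
        self.norm = nn.LayerNorm(4)

layer_names = []
model = Toy()
for _, child in model.named_children():
    if isinstance(child, nn.ModuleList):
        for idx, block in enumerate(child):
            for param in block.parameters():
                name = f"{block.__class__.__name__}.{idx}"
                layer_names.append(name)
                param.__setattr__("layer_name", name)
    else:
        for param in child.parameters():
            name = child.__class__.__name__
            layer_names.append(name)
            param.__setattr__("layer_name", name)

layer_names = sorted(set(layer_names))
print(layer_names)                          # ['Embedding', 'LayerNorm', 'Linear.0', 'Linear.1']
for param in model.parameters():
    print(tuple(param.shape), param.layer_name)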