mirror of https://github.com/InternLM/InternLM
fix set layer name
parent
7d68509c4f
commit
7920168179
|
@ -6,6 +6,7 @@ from torch import nn
|
||||||
|
|
||||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||||
from internlm.core.context import global_context as gpc
|
from internlm.core.context import global_context as gpc
|
||||||
|
from internlm.core.naive_amp import NaiveAMPModel
|
||||||
|
|
||||||
|
|
||||||
def is_model_parallel_parameter(p):
|
def is_model_parallel_parameter(p):
|
||||||
|
@ -70,20 +71,21 @@ def set_model_params_layer_name(model):
|
||||||
Args:
|
Args:
|
||||||
model (:class:`torch.nn.Module`): A PyTorch model on whose parameters you check the consistency.
|
model (:class:`torch.nn.Module`): A PyTorch model on whose parameters you check the consistency.
|
||||||
"""
|
"""
|
||||||
if isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
_chunk = model[0]
|
model = [model]
|
||||||
else:
|
|
||||||
_chunk = model
|
|
||||||
|
|
||||||
# Create a unique layer name based on the block's class name and index
|
for _chunk in model:
|
||||||
for name, children in _chunk.named_children():
|
if isinstance(_chunk, NaiveAMPModel):
|
||||||
if isinstance(children, nn.ModuleList):
|
_chunk = _chunk.model
|
||||||
for idx, block in enumerate(children):
|
# Create a unique layer name based on the block's class name and index
|
||||||
for param in block.parameters():
|
for name, children in _chunk.named_children():
|
||||||
layer_name = f"{block.__class__.__name__}.{idx}"
|
if isinstance(children, nn.ModuleList):
|
||||||
gpc.layer_names.add(layer_name)
|
for idx, block in enumerate(children):
|
||||||
param.__setattr__("layer_name", layer_name)
|
for param in block.parameters():
|
||||||
else:
|
layer_name = f"{block.__class__.__name__}.{idx}"
|
||||||
for param in children.parameters():
|
gpc.layer_names.add(layer_name)
|
||||||
gpc.layer_names.add(name)
|
param.__setattr__("layer_name", layer_name)
|
||||||
param.__setattr__("layer_name", name)
|
else:
|
||||||
|
for param in children.parameters():
|
||||||
|
gpc.layer_names.add(name)
|
||||||
|
param.__setattr__("layer_name", name)
|
||||||
|
|
Loading…
Reference in New Issue