diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 06f225a..e96d2d9 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -27,6 +27,7 @@ else: get_numa = True logger = get_logger(__file__) +GLOBAL_SEED = 1024 def get_default_parser(): @@ -543,6 +544,9 @@ def initialize_distributed_env( else: assert launcher in ["slurm", "torch"], "launcher only support slurm or torch" + global GLOBAL_SEED + GLOBAL_SEED = seed + if args_check: args_sanity_check() diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index cbf425c..38fb9ca 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import math +from functools import wraps from typing import Optional import torch @@ -11,7 +12,9 @@ from torch import nn from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context.random import _SEED_MANAGER from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal +from internlm.initialize.launch import GLOBAL_SEED from internlm.model.embedding import Embedding1D from internlm.model.linear import ( FeedForward, @@ -81,6 +84,7 @@ class PackedFlashBaseLayer1D(nn.Module): self.use_flash_attn = use_flash_attn head_dim = hidden_size // num_attention_heads + self.mixer = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, @@ -410,6 +414,16 @@ class PackedFlashInternLm1D(nn.Module): return hidden_states +def fix_seed(func): + @wraps(func) + def wrapper(*args, **kwargs): + _SEED_MANAGER.reset() + gpc.set_seed(GLOBAL_SEED) + func(*args, **kwargs) + + return wrapper + + def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): """ build generic model 1d @@ -429,6 +443,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), logger.info(f"The layer sharding is {all_parts}.") models = [] + PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__) for start, end in parts: kwargs["num_layers"] = end - start