From c53667d70caaf17223d19d0e4e8b57d4b0d14583 Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Mon, 13 Nov 2023 20:03:19 +0800
Subject: [PATCH] bind seed

---
 internlm/initialize/launch.py       |  4 ++++
 internlm/model/modeling_internlm.py | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 6614db0..299a72d 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -27,6 +27,7 @@ else:
     get_numa = True
 
 logger = get_logger(__file__)
+GLOBAL_SEED = 1024
 
 
 def get_default_parser():
@@ -531,6 +532,9 @@ def initialize_distributed_env(
     else:
         assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
 
+    global GLOBAL_SEED
+    GLOBAL_SEED = seed
+
     if args_check:
         args_sanity_check()
 
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index cbf425c..a2cc14f 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -2,8 +2,11 @@
 # -*- encoding: utf-8 -*-
 
 import math
+import random
+from functools import wraps
 from typing import Optional
 
+import numpy as np
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
@@ -12,6 +15,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -81,6 +85,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.use_flash_attn = use_flash_attn
 
         head_dim = hidden_size // num_attention_heads
+
         self.mixer = MHA(
             embed_dim=hidden_size,
             num_heads=num_attention_heads,
@@ -410,6 +415,18 @@ class PackedFlashInternLm1D(nn.Module):
         return hidden_states
 
 
+def fix_seed(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        seed = GLOBAL_SEED
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        func(*args, **kwargs)
+
+    return wrapper
+
+
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
     build generic model 1d
@@ -429,6 +446,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
         logger.info(f"The layer sharding is {all_parts}.")
 
     models = []
+    PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__)
 
     for start, end in parts:
         kwargs["num_layers"] = end - start
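
For reference, below is a minimal standalone sketch of the seed-binding pattern this patch applies: a decorator reseeds the Python, NumPy, and PyTorch RNGs from a module-level seed before the wrapped constructor runs, and the class's __init__ is rebound to the decorated version. DemoModel and the local GLOBAL_SEED are illustrative stand-ins, not part of the patch; the real change wraps PackedFlashInternLm1D.__init__ and reads the seed set by initialize_distributed_env.

import random
from functools import wraps

import numpy as np
import torch

GLOBAL_SEED = 1024  # stand-in for internlm.initialize.launch.GLOBAL_SEED


def fix_seed(func):
    """Reseed Python, NumPy, and PyTorch RNGs before calling func."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        seed = GLOBAL_SEED
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        return func(*args, **kwargs)  # returning the result keeps the wrapper general

    return wrapper


class DemoModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self, dim: int = 8):
        super().__init__()
        # Parameter initialization now runs under the fixed seed,
        # so two separately constructed instances start identical.
        self.proj = torch.nn.Linear(dim, dim)


# Rebind __init__ so every construction reseeds first, mirroring the patch.
DemoModel.__init__ = fix_seed(DemoModel.__init__)

a, b = DemoModel(), DemoModel()
assert torch.equal(a.proj.weight, b.proj.weight)

One caveat worth noting, as general Python behaviour rather than a claim about this patch: "from internlm.initialize.launch import GLOBAL_SEED" binds the value at import time, so if modeling_internlm is imported before initialize_distributed_env reassigns GLOBAL_SEED, the decorator still sees the default 1024. Reading the seed through the module object instead (import the launch module and access launch.GLOBAL_SEED inside wrapper) would pick up the updated value.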