feat(seed): set global seed for every model initialization (#496)

* bind seed

* bind seed
commit eba2b859fc (parent 679ed3c8ca)
jiaxingli, 2023-11-17 14:42:50 +08:00, committed by GitHub
2 changed files with 19 additions and 0 deletions

internlm/initialize/launch.py

@@ -27,6 +27,7 @@ else:
     get_numa = True
 
 logger = get_logger(__file__)
+GLOBAL_SEED = 1024
 
 
 def get_default_parser():
@@ -543,6 +544,9 @@ def initialize_distributed_env(
     else:
         assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
 
+    global GLOBAL_SEED
+    GLOBAL_SEED = seed
+
     if args_check:
         args_sanity_check()
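
The launch module now keeps a module-level GLOBAL_SEED (defaulting to 1024) that initialize_distributed_env rebinds to the seed from the training config, so code elsewhere can read the seed the run was actually launched with. A minimal standalone sketch of this rebinding pattern, with an illustrative function name rather than the real initialize_distributed_env signature:

    GLOBAL_SEED = 1024  # module-level default, as in the launch module


    def initialize_env(seed: int = 1024) -> None:
        # `global` makes the assignment rebind the module-level name
        # instead of creating a function-local variable.
        global GLOBAL_SEED
        GLOBAL_SEED = seed


    initialize_env(seed=42)
    print(GLOBAL_SEED)  # -> 42

Note that a `from <module> import GLOBAL_SEED` executed before the rebind keeps the old value, because `from ... import` copies the binding at import time; reading `<module>.GLOBAL_SEED` through the module object always observes the most recent assignment.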

internlm/model/modeling_internlm.py

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 import math
+from functools import wraps
 from typing import Optional
 
 import torch
@@ -11,7 +12,9 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.random import _SEED_MANAGER
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
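
Both new imports serve the fix_seed decorator added further down: functools.wraps keeps the patched __init__ carrying the original's metadata, while _SEED_MANAGER and GLOBAL_SEED supply the re-seeding itself. A quick self-contained illustration of what wraps preserves (toy names, unrelated to InternLM):

    from functools import wraps


    def traced(func):
        @wraps(func)  # copies __name__, __doc__, __module__ and sets __wrapped__
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


    @traced
    def build_model():
        """Construct the model."""


    print(build_model.__name__)  # 'build_model', not 'wrapper'
    print(build_model.__doc__)   # 'Construct the model.'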
@@ -81,6 +84,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.use_flash_attn = use_flash_attn
 
         head_dim = hidden_size // num_attention_heads
+
         self.mixer = MHA(
             embed_dim=hidden_size,
             num_heads=num_attention_heads,
@@ -410,6 +414,16 @@ class PackedFlashInternLm1D(nn.Module):
         return hidden_states
 
 
+def fix_seed(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        _SEED_MANAGER.reset()
+        gpc.set_seed(GLOBAL_SEED)
+        func(*args, **kwargs)
+
+    return wrapper
+
+
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
     build generic model 1d
@@ -429,6 +443,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
         logger.info(f"The layer sharding is {all_parts}.")
 
     models = []
+    PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__)
 
     for start, end in parts:
         kwargs["num_layers"] = end - start