mirror of https://github.com/InternLM/InternLM
feat(seed): set global seed for every model initialization (#496)
* bind seed
* bind seed

pull/507/head   v0.2.1dev20231121

parent 679ed3c8ca
commit eba2b859fc
internlm/initialize/launch.py

@@ -27,6 +27,7 @@ else:
     get_numa = True

 logger = get_logger(__file__)
+GLOBAL_SEED = 1024


 def get_default_parser():

@@ -543,6 +544,9 @@ def initialize_distributed_env(
     else:
         assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"

+    global GLOBAL_SEED
+    GLOBAL_SEED = seed
+
     if args_check:
         args_sanity_check()

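The two hunks above bind the run's seed at module scope: GLOBAL_SEED defaults to 1024, and initialize_distributed_env rebinds it to the caller's seed through a `global` statement so the rest of the package can pick it up. A minimal sketch of that rebinding pattern in isolation (initialize_env and current_seed are hypothetical stand-ins, not InternLM APIs):

GLOBAL_SEED = 1024  # module-level default, mirroring the new constant above


def initialize_env(seed: int = 1024):
    """Hypothetical stand-in for an env-init entry point that binds the run's seed."""
    global GLOBAL_SEED
    GLOBAL_SEED = seed


def current_seed() -> int:
    # read the module-level binding at call time, so the rebind is visible here
    return GLOBAL_SEED


initialize_env(seed=42)
assert current_seed() == 42
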
internlm/model/modeling_internlm.py

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import math
+from functools import wraps
 from typing import Optional

 import torch
@@ -11,7 +12,9 @@ from torch import nn

 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.random import _SEED_MANAGER
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -81,6 +84,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.use_flash_attn = use_flash_attn

         head_dim = hidden_size // num_attention_heads
+
         self.mixer = MHA(
             embed_dim=hidden_size,
             num_heads=num_attention_heads,
@@ -410,6 +414,16 @@ class PackedFlashInternLm1D(nn.Module):
         return hidden_states


+def fix_seed(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        _SEED_MANAGER.reset()
+        gpc.set_seed(GLOBAL_SEED)
+        func(*args, **kwargs)
+
+    return wrapper
+
+
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
     build generic model 1d
@@ -429,6 +443,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
         logger.info(f"The layer sharding is {all_parts}.")

     models = []
+    PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__)

     for start, end in parts:
         kwargs["num_layers"] = end - start
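The fix_seed decorator and the patched constructor above make every pipeline chunk of PackedFlashInternLm1D begin its weight initialization from a freshly reset RNG state seeded with GLOBAL_SEED. Below is a minimal, self-contained sketch of the same decorator-plus-patch idea; plain random/NumPy/torch seeding stands in for InternLM's _SEED_MANAGER.reset() and gpc.set_seed(), and ChunkModel is a hypothetical stand-in for one model chunk:

import random
from functools import wraps

import numpy as np
import torch

GLOBAL_SEED = 1024  # stand-in for the value bound in internlm.initialize.launch


def fix_seed(func):
    """Re-seed every RNG immediately before the wrapped constructor runs."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        random.seed(GLOBAL_SEED)
        np.random.seed(GLOBAL_SEED)
        torch.manual_seed(GLOBAL_SEED)  # seeds the CPU generator (and CUDA generators on recent PyTorch)
        func(*args, **kwargs)

    return wrapper


class ChunkModel(torch.nn.Module):
    """Hypothetical stand-in for one pipeline chunk of a larger model."""

    def __init__(self, num_layers: int):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(num_layers))


# Patch the constructor once, before the chunk-building loop, as the diff does.
ChunkModel.__init__ = fix_seed(ChunkModel.__init__)

# Each construction resets the RNG first, so layer 0 of every chunk starts from the
# same random state no matter how many layers earlier chunks already consumed.
models = [ChunkModel(num_layers=n) for n in (2, 3)]
assert torch.equal(models[0].blocks[0].weight, models[1].blocks[0].weight)

With the constructor patched once before the build loop, chunks holding different layer ranges still draw their parameters from an identical starting RNG state, which ties per-chunk initialization back to the single GLOBAL_SEED bound in launch.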