mirror of https://github.com/InternLM/InternLM
bind seed
parent
c53667d70c
commit
0c94e429bb
|
@ -2,11 +2,9 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import random
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flash_attn.modules.embedding import ParallelGPT2Embeddings
|
||||
from flash_attn.modules.mlp import ParallelFusedMLP
|
||||
|
@ -14,6 +12,7 @@ from torch import nn
|
|||
|
||||
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.core.context.random import _SEED_MANAGER
|
||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||
from internlm.initialize.launch import GLOBAL_SEED
|
||||
from internlm.model.embedding import Embedding1D
|
||||
|
@ -418,10 +417,8 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
def fix_seed(func):
    """Decorator that re-seeds every RNG to ``GLOBAL_SEED`` before calling *func*.

    Resets Python's ``random``, NumPy, and PyTorch generators, then resets the
    project seed manager and re-applies the seed to all parallel modes via
    ``gpc.set_seed``, so the decorated function runs with a reproducible RNG
    state regardless of how much randomness was consumed beforehand.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable. Its return value is that of *func*.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        seed = GLOBAL_SEED
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Reset the internal seed manager so parallel-mode seeds are
        # re-derived from scratch, then re-seed every parallel mode.
        _SEED_MANAGER.reset()
        gpc.set_seed(GLOBAL_SEED)
        # Bug fix: propagate the wrapped function's result. The original
        # discarded it, so every decorated function returned None.
        return func(*args, **kwargs)

    return wrapper
|
||||
|
|
Loading…
Reference in New Issue