diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index a2cc14f..38fb9ca 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -2,11 +2,9 @@
 # -*- encoding: utf-8 -*-
 
 import math
-import random
 from functools import wraps
 from typing import Optional
 
-import numpy as np
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
@@ -14,6 +12,7 @@ from torch import nn
 
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.random import _SEED_MANAGER
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -418,10 +417,22 @@ class PackedFlashInternLm1D(nn.Module):
 def fix_seed(func):
+    """Re-seed with ``GLOBAL_SEED`` before every call to ``func``.
+
+    Each call resets ``_SEED_MANAGER`` and then calls
+    ``gpc.set_seed(GLOBAL_SEED)`` before delegating to ``func``.
+
+    Args:
+        func: Callable to wrap.
+
+    Returns:
+        A wrapper with ``func``'s metadata (via ``functools.wraps``) that
+        forwards all arguments and returns ``func``'s result.
+    """
+
     @wraps(func)
     def wrapper(*args, **kwargs):
-        seed = GLOBAL_SEED
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        func(*args, **kwargs)
+        _SEED_MANAGER.reset()
+        gpc.set_seed(GLOBAL_SEED)
+        # Hand back func's result; the wrapper must not swallow it.
+        return func(*args, **kwargs)
 
     return wrapper