bind seed

pull/496/head
lijiaxing 2023-11-14 10:14:27 +08:00
parent c53667d70c
commit 0c94e429bb
1 changed files with 3 additions and 6 deletions

View File

@ -2,11 +2,9 @@
# -*- encoding: utf-8 -*-
import math
import random
from functools import wraps
from typing import Optional
import numpy as np
import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
@ -14,6 +12,7 @@ from torch import nn
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.context.random import _SEED_MANAGER
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.initialize.launch import GLOBAL_SEED
from internlm.model.embedding import Embedding1D
@ -418,10 +417,8 @@ class PackedFlashInternLm1D(nn.Module):
def fix_seed(func):
@wraps(func)
def wrapper(*args, **kwargs):
seed = GLOBAL_SEED
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
_SEED_MANAGER.reset()
gpc.set_seed(GLOBAL_SEED)
func(*args, **kwargs)
return wrapper