bind seed

pull/496/head
lijiaxing 2023-11-14 10:14:27 +08:00
parent c53667d70c
commit 0c94e429bb
1 changed file with 3 additions and 6 deletions

@@ -2,11 +2,9 @@
 # -*- encoding: utf-8 -*-

 import math
-import random
 from functools import wraps
 from typing import Optional

-import numpy as np
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
@@ -14,6 +12,7 @@ from torch import nn

 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.random import _SEED_MANAGER
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -418,10 +417,8 @@ class PackedFlashInternLm1D(nn.Module):
 def fix_seed(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
-        seed = GLOBAL_SEED
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
+        _SEED_MANAGER.reset()
+        gpc.set_seed(GLOBAL_SEED)
         func(*args, **kwargs)

     return wrapper
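
For reference, after this change the decorator no longer seeds random, numpy, and torch directly; it resets InternLM's seed manager and re-seeds through the parallel context, so seeding is bound to the framework's own RNG bookkeeping. A minimal annotated sketch of the post-change decorator (not part of the commit, assuming the imports shown in the diff above — wraps, gpc, _SEED_MANAGER, GLOBAL_SEED — are in scope):

# Sketch only: the decorator as it reads after this commit, with comments.
def fix_seed(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Drop whatever seed states the manager is currently tracking ...
        _SEED_MANAGER.reset()
        # ... then re-seed everything from the single global seed through the
        # parallel context, instead of calling random.seed / np.random.seed /
        # torch.manual_seed directly as the old code did.
        gpc.set_seed(GLOBAL_SEED)
        func(*args, **kwargs)

    return wrapper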