From b65e8cb802cc276abad4178bfef134c03f6273ad Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 14 Nov 2023 17:29:06 +0800 Subject: [PATCH] init_seed --- internlm/train/training_internlm.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 71a47a0..e36d468 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -21,7 +21,7 @@ from torch.utils.data import ConcatDataset, DataLoader from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.random import set_mode +from internlm.core.context.random import get_states, set_mode, set_seed_states from internlm.core.naive_amp import NaiveAMPModel from internlm.core.trainer import TrainState from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader @@ -73,8 +73,23 @@ def initialize_model(): torch.nn.Module: The neural network model to be trained or evaluated. """ + rng_states = get_states(copy=True) + init_seed = gpc.config.get("seed", 1001) + tp_offset = gpc.get_local_rank(ParallelMode.TENSOR) + pp_offset = gpc.get_local_rank(ParallelMode.PIPELINE) + init_seed = init_seed + tp_offset + pp_offset * 32 + + torch.manual_seed(init_seed) + torch.cuda.manual_seed(init_seed) + + dist.barrier() model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) + dist.barrier() + + for mode, rng in rng_states.items(): + set_seed_states(mode, rng) + if isinstance(model, nn.ModuleList): model = nn.ModuleList( [