diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 71a47a0..e36d468 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -21,7 +21,7 @@ from torch.utils.data import ConcatDataset, DataLoader from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.random import set_mode +from internlm.core.context.random import get_states, set_mode, set_seed_states from internlm.core.naive_amp import NaiveAMPModel from internlm.core.trainer import TrainState from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader @@ -73,8 +73,23 @@ def initialize_model(): torch.nn.Module: The neural network model to be trained or evaluated. """ + rng_states = get_states(copy=True) + init_seed = gpc.config.get("seed", 1001) + tp_offset = gpc.get_local_rank(ParallelMode.TENSOR) + pp_offset = gpc.get_local_rank(ParallelMode.PIPELINE) + init_seed = init_seed + tp_offset + pp_offset * 32 + + torch.manual_seed(init_seed) + torch.cuda.manual_seed(init_seed) + + dist.barrier() model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) + dist.barrier() + + for mode, rng in rng_states.items(): + set_seed_states(mode, rng) + if isinstance(model, nn.ModuleList): model = nn.ModuleList( [