diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 474bfd2..db18951 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import functools
+import gc
 import os
 import pickle
 import time
@@ -262,6 +263,10 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
         data_world_size=gpc.get_world_size(ParallelMode.DATA),
     )
     train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
+
+    def dl_worker_init(worker_id):
+        """Turn on automatic garbage collection inside a DataLoader worker; ``worker_id`` is unused."""
+        gc.enable()
 
     # Create the training data loader
     train_dl = DataLoader(
@@ -271,6 +276,7 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
         pin_memory=True,
         collate_fn=train_collate_fn,
         persistent_workers=num_worker > 0,
+        worker_init_fn=dl_worker_init,
     )
 
     return train_dl, dataset_types