Fix OOM issue in dataloaders

pull/545/head
zigzagcai 2023-12-15 14:08:40 +08:00
parent bbb5651582
commit 2dc8ddd582
1 changed file with 5 additions and 0 deletions

View File

@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import functools
import gc
import os
import pickle
import time
@ -262,6 +263,9 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
# Worker initialisation hook; it is passed to the training DataLoader below
# as `worker_init_fn`. Its only action is to switch Python's cyclic garbage
# collector on. The `worker_id` argument is accepted but not used.
# NOTE(review): the commit title says this fixes a dataloader OOM; presumably
# automatic GC is disabled somewhere before the workers start and they inherit
# that state — confirm against the code that spawns the workers.
def dl_worker_init(worker_id):
# Re-enable automatic garbage collection in the calling process.
gc.enable()
# Create the training data loader
train_dl = DataLoader(
@ -271,6 +275,7 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
pin_memory=True,
collate_fn=train_collate_fn,
persistent_workers=num_worker > 0,
worker_init_fn = dl_worker_init,
)
return train_dl, dataset_types