mirror of https://github.com/InternLM/InternLM
Fix OOM issue in dataloaders
parent
bbb5651582
commit
2dc8ddd582
|
@ -2,6 +2,7 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
@ -262,6 +263,9 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
|
|||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
def dl_worker_init(worker_id):
    """Enable the garbage collector in the current process.

    This function is handed to the training ``DataLoader`` as its
    ``worker_init_fn``, so it runs once in each data-loading worker.

    NOTE(review): presumably the collector is switched off in the parent
    process and the workers inherit that state, letting memory pile up in
    the workers -- confirm against the trainer setup.

    Args:
        worker_id: Index of the worker, supplied by the ``DataLoader``.
            It is not used here; the parameter exists only because the
            ``worker_init_fn`` callback is invoked with it.
    """
    gc.enable()
|
||||
|
||||
# Create the training data loader
|
||||
train_dl = DataLoader(
|
||||
|
@ -271,6 +275,7 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
|
|||
pin_memory=True,
|
||||
collate_fn=train_collate_fn,
|
||||
persistent_workers=num_worker > 0,
|
||||
worker_init_fn = dl_worker_init,
|
||||
)
|
||||
|
||||
return train_dl, dataset_types
|
||||
|
|
Loading…
Reference in New Issue