mirror of https://github.com/InternLM/InternLM
revert worker_init_fn since we move gc disable after dataloader launching
parent
df0acdee43
commit
f8eaf618af
|
@ -2,7 +2,6 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import gc
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
@ -264,9 +263,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
|
||||||
)
|
)
|
||||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||||
|
|
||||||
def dl_worker_init(worker_id): # pylint: disable=unused-argument
|
|
||||||
gc.enable()
|
|
||||||
|
|
||||||
# Create the training data loader
|
# Create the training data loader
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
dataset=train_ds,
|
dataset=train_ds,
|
||||||
|
@ -275,7 +271,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=train_collate_fn,
|
collate_fn=train_collate_fn,
|
||||||
persistent_workers=num_worker > 0,
|
persistent_workers=num_worker > 0,
|
||||||
worker_init_fn=dl_worker_init,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dl, dataset_types
|
return train_dl, dataset_types
|
||||||
|
|
Loading…
Reference in New Issue