diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index 4021b3175..d5da5938b 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -20,21 +20,19 @@ class DPPluginBase(Plugin):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
 
-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
+    def prepare_dataloader(self,
+                           dataset,
+                           batch_size,
+                           shuffle=False,
+                           seed=1024,
+                           drop_last=False,
+                           pin_memory=False,
+                           num_workers=0,
+                           **kwargs):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
         `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
 
-        Note:
-            1. Evaluation datasets should not be passed to this function.
 
         Args:
             dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index fde8912a6..4850b52de 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -156,7 +156,7 @@ class GeminiPlugin(DPPluginBase):
 
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = GeminiPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 828d8b274..f0f576856 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -95,7 +95,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = LowLevelZeroPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 7a222022c..eb5478595 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -4,7 +4,7 @@ from typing import Callable, List, Tuple, Union
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from colossalai.checkpoint_io import CheckpointIO
 from colossalai.interface import OptimizerWrapper
@@ -59,3 +59,18 @@
         Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
         """
         pass
+
+    @abstractmethod
+    def prepare_dataloader(self,
+                           dataset: Dataset,
+                           batch_size: int,
+                           shuffle: bool = False,
+                           seed: int = 1024,
+                           drop_last: bool = False,
+                           pin_memory: bool = False,
+                           num_workers: int = 0,
+                           **kwargs):
+        """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader`.
+        """
+        pass
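Reviewer note: the new abstract method pins down a single dataloader-building contract for every plugin. For reference, a minimal sketch of how a concrete implementation of this contract could look, assuming the usual `DistributedSampler` pattern; this is illustrative, not the exact body of `DPPluginBase.prepare_dataloader`:

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


def prepare_dataloader(dataset: Dataset,
                       batch_size: int,
                       shuffle: bool = False,
                       seed: int = 1024,
                       drop_last: bool = False,
                       pin_memory: bool = False,
                       num_workers: int = 0,
                       **kwargs) -> DataLoader:
    # Requires an initialized process group (e.g. via colossalai.launch).
    # Shard the dataset so each rank iterates over a disjoint slice.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=shuffle,
                                 seed=seed)
    # Shuffling is delegated to the sampler, so `shuffle` is not forwarded
    # to the DataLoader (the two options are mutually exclusive).
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      drop_last=drop_last,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      **kwargs)
```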
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index d30d266c0..76906d844 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -72,7 +72,7 @@ class TorchDDPPlugin(DPPluginBase):
 
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = TorchDDPPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py
index e64e95fc2..a96a4b640 100644
--- a/examples/tutorial/new_api/cifar_resnet/train.py
+++ b/examples/tutorial/new_api/cifar_resnet/train.py
@@ -49,14 +49,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
                                     download=True)
 
     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
 
     return train_dataloader, test_dataloader
 
diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py
index fee53df07..2405fdfc6 100644
--- a/examples/tutorial/new_api/cifar_vit/train.py
+++ b/examples/tutorial/new_api/cifar_vit/train.py
@@ -63,14 +63,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
                                     download=True)
 
     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
 
     return train_dataloader, test_dataloader
 
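With the old docstring note (evaluation datasets should not be passed) removed, the renamed method explicitly serves both training and evaluation loaders, as the example updates above show. An end-to-end usage sketch mirroring the plugin docstrings; model and dataset construction are elided and the variable names are illustrative:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# model, train_dataset, test_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()

# One method now covers both loaders; only the sampler-level settings differ.
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=8, shuffle=False, drop_last=False)

booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```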
diff --git a/examples/tutorial/new_api/glue_bert/data.py b/examples/tutorial/new_api/glue_bert/data.py
index e43312aeb..981cedcca 100644
--- a/examples/tutorial/new_api/glue_bert/data.py
+++ b/examples/tutorial/new_api/glue_bert/data.py
@@ -84,26 +84,26 @@ class GLUEDataBuilder:
         AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
 
     def train_dataloader(self):
-        return self.plugin.prepare_train_dataloader(self.dataset["train"],
-                                                    batch_size=self.train_batch_size,
-                                                    shuffle=True,
-                                                    drop_last=True)
+        return self.plugin.prepare_dataloader(self.dataset["train"],
+                                              batch_size=self.train_batch_size,
+                                              shuffle=True,
+                                              drop_last=True)
 
     def val_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py
index a2b94ba6c..eab949828 100644
--- a/tests/test_booster/test_plugin/test_dp_plugin_base.py
+++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -55,7 +55,7 @@ def check_dataloader_sharding():
     # create a custom dasetset with 0 to 10
     dataset = TensorDataset(torch.arange(0, 10))
 
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
 
     # get the first batch of data
     batch = next(iter(train_dataloader))[0].cuda()
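For context on what `check_dataloader_sharding` verifies: the wrapped loader should hand each rank a disjoint shard of the ten-element dataset. A small sketch of that expectation, assuming a two-rank process group and a `plugin` built as in the test; this is illustrative and not part of this PR:

```python
import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset

# `plugin` is any DPPluginBase implementation, constructed as in the test.
dataset = TensorDataset(torch.arange(0, 10))
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)

# With shuffle=False, DistributedSampler deals indices round-robin across ranks,
# so rank 0's first batch is [0, 2] and rank 1's is [1, 3], never the same.
batch = next(iter(train_dataloader))[0]
print(f"rank {dist.get_rank()}: {batch.tolist()}")
```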