[booster] update prepare dataloader method for plugin (#3706)

* [booster] add prepare dataloader method for plug

* [booster] update examples and docstr
Hongxin Liu 2 years ago committed by GitHub
parent f83ea813f5
commit 3bf09efe74

@@ -20,7 +20,7 @@ class DPPluginBase(Plugin):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()

-    def prepare_train_dataloader(self,
+    def prepare_dataloader(self,
                            dataset,
                            batch_size,
                            shuffle=False,
@@ -33,8 +33,6 @@ class DPPluginBase(Plugin):
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
         `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-        Note:
-            1. Evaluation datasets should not be passed to this function.
         Args:
             dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
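For readers unfamiliar with the wrapping the docstring describes, the sketch below shows how a DistributedSampler-backed dataloader is typically built. It is a simplified illustration, not the plugin's actual implementation; in particular, how the real method uses its `seed` argument is not shown here.

# Minimal sketch of the kind of wrapping prepare_dataloader performs.
# Illustration only, not ColossalAI's actual implementation.
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def prepare_dataloader(dataset, batch_size, shuffle=False, drop_last=False,
                       pin_memory=False, num_workers=0, **kwargs):
    # Shard the dataset across processes so every rank sees a distinct slice.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=shuffle)
    # Shuffling is delegated to the sampler, so it is not passed to DataLoader.
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      drop_last=drop_last,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      **kwargs)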

@@ -156,7 +156,7 @@ class GeminiPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = GeminiPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

@@ -95,7 +95,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = LowLevelZeroPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

@@ -4,7 +4,7 @@ from typing import Callable, List, Tuple, Union
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from colossalai.checkpoint_io import CheckpointIO
 from colossalai.interface import OptimizerWrapper

@@ -59,3 +59,18 @@ class Plugin(ABC):
         Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
         """
         pass
+
+    @abstractmethod
+    def prepare_dataloader(self,
+                           dataset: Dataset,
+                           batch_size: int,
+                           shuffle: bool = False,
+                           seed: int = 1024,
+                           drop_last: bool = False,
+                           pin_memory: bool = False,
+                           num_workers: int = 0,
+                           **kwargs):
+        """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader`
+        """
+        pass

@@ -72,7 +72,7 @@ class TorchDDPPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = TorchDDPPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

@@ -49,14 +49,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
                                          download=True)

     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
     return train_dataloader, test_dataloader

@@ -63,14 +63,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
                                          download=True)

     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
     return train_dataloader, test_dataloader

@@ -84,26 +84,26 @@ class GLUEDataBuilder:
         AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

     def train_dataloader(self):
-        return self.plugin.prepare_train_dataloader(self.dataset["train"],
+        return self.plugin.prepare_dataloader(self.dataset["train"],
                                               batch_size=self.train_batch_size,
                                               shuffle=True,
                                               drop_last=True)

     def val_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]

     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
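Because val_dataloader and test_dataloader return either a single DataLoader or a list of them depending on how many evaluation splits the GLUE task has, calling code typically normalizes the return value before iterating. A hedged usage sketch follows; `data_builder` is a hypothetical, already constructed GLUEDataBuilder instance.

# Hypothetical usage; `data_builder` stands in for a constructed GLUEDataBuilder.
val_loaders = data_builder.val_dataloader()
if not isinstance(val_loaders, list):
    val_loaders = [val_loaders]          # single split -> wrap in a list
for loader in val_loaders:
    for batch in loader:
        ...                              # run evaluation on this split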

@@ -55,7 +55,7 @@ def check_dataloader_sharding():
     # create a custom dasetset with 0 to 10
     dataset = TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
     # get the first batch of data
     batch = next(iter(train_dataloader))[0].cuda()
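With the 10-element TensorDataset and batch_size=2, the sampler underneath prepare_dataloader deals indices round-robin across ranks, so each rank's first batch is a distinct slice; this is presumably what the sharding test compares across ranks. Below is a standalone sketch of that behaviour, assuming 2 ranks and no shuffling, and constructing the sampler directly rather than going through a plugin.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(0, 10))
# Rank 0 of 2: DistributedSampler yields indices 0, 2, 4, 6, 8.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print(next(iter(loader)))    # [tensor([0, 2])]; rank 1 would see [tensor([1, 3])]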
