mirror of https://github.com/hpcaitech/ColossalAI
[booster] update prepare dataloader method for plugin (#3706)
* [booster] add prepare dataloader method for plugin
* [booster] update examples and docstrings

Branch: pull/3713/head
Parent: f83ea813f5
Commit: 3bf09efe74
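In short, the plugin-level dataloader helper is renamed from `prepare_train_dataloader` to `prepare_dataloader` and declared as an abstract method on the `Plugin` base class. A minimal before/after sketch of a call site (assuming `plugin` and `train_dataset` already exist, as in the docstring examples below):

# old name (removed in this change)
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
# new name, same arguments
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)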
@@ -20,21 +20,19 @@ class DPPluginBase(Plugin):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
 
-    def prepare_train_dataloader(self,
+    def prepare_dataloader(self,
                            dataset,
                            batch_size,
                            shuffle=False,
                            seed=1024,
                            drop_last=False,
                            pin_memory=False,
                            num_workers=0,
                            **kwargs):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
         `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
 
-        Note:
-            1. Evaluation datasets should not be passed to this function.
 
         Args:
             dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
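For readers unfamiliar with the wrapping the docstring mentions, here is a minimal, hypothetical sketch of what "wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`" typically looks like; it is not the plugin's actual implementation, and the helper name is illustrative only:

# Hedged sketch only -- not ColossalAI's implementation.
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def sketch_prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024,
                              drop_last=False, pin_memory=False, num_workers=0, **kwargs):
    # Each rank reads a disjoint shard of the dataset; the fixed seed keeps the
    # shuffling order consistent across ranks.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(),
                                 shuffle=shuffle,
                                 seed=seed)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      drop_last=drop_last,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      **kwargs)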
@@ -156,7 +156,7 @@ class GeminiPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = GeminiPlugin()
 
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
@@ -95,7 +95,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = LowLevelZeroPlugin()
 
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
@@ -4,7 +4,7 @@ from typing import Callable, List, Tuple, Union
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from colossalai.checkpoint_io import CheckpointIO
 from colossalai.interface import OptimizerWrapper
@@ -59,3 +59,18 @@ class Plugin(ABC):
         Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
         """
         pass
+
+    @abstractmethod
+    def prepare_dataloader(self,
+                           dataset: Dataset,
+                           batch_size: int,
+                           shuffle: bool = False,
+                           seed: int = 1024,
+                           drop_last: bool = False,
+                           pin_memory: bool = False,
+                           num_workers: int = 0,
+                           **kwargs):
+        """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader`
+        """
+        pass
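Because `prepare_dataloader` is now declared with `@abstractmethod` on `Plugin`, every concrete plugin must provide an implementation. A standalone sketch of that contract, using a stand-in ABC rather than redefining ColossalAI's `Plugin` (both class names below are hypothetical):

from abc import ABC, abstractmethod
from torch.utils.data import DataLoader, Dataset

class PluginSketch(ABC):
    # mirrors the abstract signature added in the hunk above
    @abstractmethod
    def prepare_dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                           seed: int = 1024, drop_last: bool = False, pin_memory: bool = False,
                           num_workers: int = 0, **kwargs):
        ...

class NaivePluginSketch(PluginSketch):
    # a deliberately simple, single-process implementation, only to satisfy the interface
    def prepare_dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                           seed: int = 1024, drop_last: bool = False, pin_memory: bool = False,
                           num_workers: int = 0, **kwargs) -> DataLoader:
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                          pin_memory=pin_memory, num_workers=num_workers, **kwargs)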
@@ -72,7 +72,7 @@ class TorchDDPPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = TorchDDPPlugin()
 
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
@@ -49,14 +49,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
                                      download=True)
 
     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
     return train_dataloader, test_dataloader
 
 
@@ -63,14 +63,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
                                      download=True)
 
     # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
     return train_dataloader, test_dataloader
 
 
@@ -84,26 +84,26 @@ class GLUEDataBuilder:
         AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
 
     def train_dataloader(self):
-        return self.plugin.prepare_train_dataloader(self.dataset["train"],
+        return self.plugin.prepare_dataloader(self.dataset["train"],
                                               batch_size=self.train_batch_size,
                                               shuffle=True,
                                               drop_last=True)
 
     def val_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]
 
@@ -55,7 +55,7 @@ def check_dataloader_sharding():
 
     # create a custom dasetset with 0 to 10
     dataset = TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
 
     # get the first batch of data
     batch = next(iter(train_dataloader))[0].cuda()