mirror of https://github.com/hpcaitech/ColossalAI
154 lines
6.6 KiB
Python
154 lines
6.6 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Dataloaders."""
|
||
|
|
||
|
import torch
|
||
|
import random
|
||
|
from colossalai.core import global_context as gpc
|
||
|
from colossalai.context import ParallelMode
|
||
|
|
||
|
|
||
|
def build_pretraining_data_loader(dataset,
                                  consumed_samples,
                                  micro_batch_size,
                                  dataloader_type='single',
                                  num_workers=0,
                                  pin_memory=True):
    """Build a ``torch.utils.data.DataLoader`` driven by a Megatron batch sampler.

    Args:
        dataset: dataset to wrap; must support ``len()``. ``None`` is passed through.
        consumed_samples: number of samples already consumed; the sampler starts after them.
        micro_batch_size: number of samples per batch on each data-parallel rank.
        dataloader_type: ``'single'`` selects ``MegatronPretrainingSampler`` (indices in
            order); ``'cyclic'`` selects ``MegatronPretrainingRandomSampler`` (indices
            shuffled per epoch).
        num_workers: worker process count forwarded to ``DataLoader``.
        pin_memory: forwarded to ``DataLoader`` (default ``True``).

    Returns:
        The ``DataLoader``, or ``None`` if ``dataset`` is ``None``.

    Raises:
        ValueError: if ``dataloader_type`` is neither ``'single'`` nor ``'cyclic'``.
            ``ValueError`` subclasses ``Exception``, so broad handlers still catch it.
    """
    if dataset is None:
        return None

    # Resolve the sampler class first so an unsupported type fails before any
    # parallel-context query is made.
    if dataloader_type == 'single':
        sampler_cls = MegatronPretrainingSampler
    elif dataloader_type == 'cyclic':
        sampler_cls = MegatronPretrainingRandomSampler
    else:
        raise ValueError('{} dataloader type is not supported.'.format(dataloader_type))

    # Both samplers accept the same constructor keywords; the sequential sampler
    # keeps its default drop_last=True.
    batch_sampler = sampler_cls(total_samples=len(dataset),
                                consumed_samples=consumed_samples,
                                micro_batch_size=micro_batch_size,
                                data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
                                data_parallel_size=gpc.get_world_size(ParallelMode.DATA))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=pin_memory)
|
||
|
|
||
|
|
||
|
class MegatronPretrainingSampler:
    """Sequential batch sampler that shards every global batch across data-parallel ranks.

    Indices ``consumed_samples .. total_samples - 1`` are walked in order and grouped
    into global batches of ``micro_batch_size * data_parallel_size`` indices. From each
    global batch this rank receives the contiguous slice selected by
    ``get_start_end_idx``.

    Args:
        total_samples: number of samples in the dataset (must be > 0).
        consumed_samples: number of leading samples to skip (must be < total_samples).
        micro_batch_size: indices yielded per batch on this rank (must be > 0).
        data_parallel_rank: this rank's index (must be < data_parallel_size).
        data_parallel_size: number of data-parallel ranks (must be > 0).
        drop_last: when True (default) a trailing partial global batch is discarded;
            otherwise this rank's slice of it is yielded, which may be shorter than
            ``micro_batch_size`` or even empty.
    """

    def __init__(self,
                 total_samples,
                 consumed_samples,
                 micro_batch_size,
                 data_parallel_rank,
                 data_parallel_size,
                 drop_last=True):
        # Store the configuration for use during iteration.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        # Number of indices in one global batch (all ranks combined).
        global_batch = micro_batch_size * data_parallel_size
        self.micro_batch_times_data_parallel_size = global_batch
        self.drop_last = drop_last

        # Validate the configuration.
        assert total_samples > 0, 'no sample to consume: {}'.format(total_samples)
        assert consumed_samples < total_samples, \
            'no samples left to consume: {}, {}'.format(consumed_samples, total_samples)
        assert micro_batch_size > 0
        assert data_parallel_size > 0
        assert data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, {}'.format(
                data_parallel_rank, data_parallel_size)

    def __len__(self):
        # NOTE(review): this is the sample count, not the number of batches that
        # __iter__ yields.
        return self.total_samples

    def get_start_end_idx(self):
        """Return the ``(start, end)`` slice bounds of a global batch owned by this rank."""
        first = self.micro_batch_size * self.data_parallel_rank
        return first, first + self.micro_batch_size

    def __iter__(self):
        step = self.micro_batch_times_data_parallel_size
        cursor = self.consumed_samples
        stop = self.total_samples

        # Emit this rank's slice of every complete global batch.
        while cursor + step <= stop:
            lo, hi = self.get_start_end_idx()
            yield list(range(cursor, cursor + step))[lo:hi]
            cursor += step

        # Whatever is left forms a partial global batch; keep it only if asked to.
        if cursor < stop and not self.drop_last:
            lo, hi = self.get_start_end_idx()
            yield list(range(cursor, stop))[lo:hi]
|
||
|
|
||
|
|
||
|
class MegatronPretrainingRandomSampler:
    """Shuffling batch sampler that gives each data-parallel rank its own contiguous shard.

    Let ``G = micro_batch_size * data_parallel_size``. The first
    ``total_samples - total_samples % G`` indices are split into ``data_parallel_size``
    contiguous buckets, and rank ``r`` owns bucket ``r``. Each epoch, the bucket is
    permuted with a ``torch.Generator`` seeded by the epoch number and consumed in
    batches of ``micro_batch_size``. Indices past the last full global batch are never
    sampled, and a trailing incomplete batch is dropped.

    ``consumed_samples`` is advanced by ``G`` for each yielded batch, so a new
    iteration resumes where the previous one stopped. Once the current epoch has been
    exhausted, the next iteration moves on to the next epoch.

    Args:
        total_samples: number of samples in the dataset (must be > 0).
        consumed_samples: number of samples already consumed across all ranks; must
            be such that the position within the current epoch is a multiple of G.
        micro_batch_size: indices yielded per batch on this rank (must be > 0).
        data_parallel_rank: this rank's index (must be < data_parallel_size).
        data_parallel_size: number of data-parallel ranks (must be > 0).
    """

    def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        # Samples left over after the last full global batch; never sampled.
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        # NOTE(review): this is the sample count, not the number of batches that
        # __iter__ yields.
        return self.total_samples

    def __iter__(self):
        # Samples covered by one epoch: all full global batches.
        active_total_samples = self.total_samples - self.last_batch_size
        if active_total_samples == 0:
            # Without this check the division/modulo below raises ZeroDivisionError.
            raise ValueError(
                'total_samples ({}) is smaller than micro_batch_size * data_parallel_size '
                '({}); cannot form a single full batch'.format(
                    self.total_samples, self.micro_batch_times_data_parallel_size))

        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        # Resuming is only supported on global-batch boundaries.
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        # Each rank owns a contiguous bucket; bucket_size * data_parallel_size equals
        # active_total_samples.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        # Samples this rank has already consumed in the current epoch.
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        # The permutation depends only on (epoch, bucket_size), so it is reproducible
        # when resuming from consumed_samples.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                # Count the whole global batch (all ranks) as consumed.
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
|