mirror of https://github.com/hpcaitech/ColossalAI
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import math
|
|
|
|
import numpy as np
|
|
|
|
|
|
class DistributedSampler:
    """Shard a dataset across replicas and draw random batches from one shard.

    Each of ``num_replicas`` shards keeps every ``num_replicas``-th index,
    starting at its own ``rank``. Trailing items that would make the dataset
    length not evenly divisible by ``num_replicas`` are dropped, so every rank
    holds exactly ``len(dataset) // num_replicas`` indices.

    Attributes:
        dataset: The sized, integer-indexable collection being sharded.
        num_replicas: Number of shards.
        rank: Index of the shard owned by this instance.
        num_samples: Number of indices in this shard.
        total_size: ``num_samples * num_replicas``; the length of the dataset
            prefix that is actually distributed (the remainder is dropped).
        indices: List of dataset indices belonging to this shard.
    """

    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
        """Build the index shard for ``rank``.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
            num_replicas: Total number of shards; must be >= 1.
            rank: This shard's index, in ``[0, num_replicas)``.

        Raises:
            ValueError: If ``num_replicas`` < 1 or ``rank`` is out of range.
        """
        # Validate explicitly: previously a zero replica count surfaced as a
        # ZeroDivisionError and a bad rank only tripped an ``assert`` (which is
        # stripped under ``python -O``), or passed silently in some cases.
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        # Drop the tail so every rank gets the same count. Exact integer floor
        # division gives the same result as the former ceil((n - r) / r) /
        # ceil(n / r) pair, without a float round-trip.
        self.num_samples = len(self.dataset) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas

        # Same contents as list(range(len(dataset)))[:total_size][rank::num_replicas],
        # but built directly so the full index list is never materialised.
        # Its length is num_samples by construction (0 <= rank < num_replicas).
        self.indices = list(range(self.rank, self.total_size, self.num_replicas))

    def sample(self, batch_size: int) -> list:
        """Return ``batch_size`` distinct items drawn at random from this shard.

        Sampling uses NumPy's global random state via ``np.random.choice``
        without replacement, so results are reproducible only through
        ``np.random.seed``.

        Args:
            batch_size: Number of items to draw.

        Returns:
            A list of ``dataset`` items whose indices belong to this shard.

        Raises:
            ValueError: (from NumPy) if ``batch_size`` exceeds ``num_samples``.
        """
        picked = np.random.choice(self.indices, batch_size, replace=False)
        return [self.dataset[idx] for idx in picked]
|