mirror of https://github.com/hpcaitech/ColossalAI
[FAW] cpu caching operations (#1520)
parent
481aecb05a
commit
9a9ef65313
|
@ -30,6 +30,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
`EvictionStrategy.LFU`: use the least frequently used cache.
|
`EvictionStrategy.LFU`: use the least frequently used cache.
|
||||||
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
|
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
|
||||||
Defaults to EvictionStrategy.DATASET.
|
Defaults to EvictionStrategy.DATASET.
|
||||||
|
use_cpu_caching (bool, optional): use the CPU to execute cache indexing. It is slower than using the GPU. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -39,6 +40,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
buffer_size: int = 50_000,
|
buffer_size: int = 50_000,
|
||||||
pin_weight: bool = False,
|
pin_weight: bool = False,
|
||||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||||
|
use_cpu_caching=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(CachedParamMgr, self).__init__()
|
super(CachedParamMgr, self).__init__()
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
|
@ -48,6 +50,13 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
self.pin_weight = pin_weight
|
self.pin_weight = pin_weight
|
||||||
self.elem_size_in_byte = weight.element_size()
|
self.elem_size_in_byte = weight.element_size()
|
||||||
|
|
||||||
|
self._cpu_caching = use_cpu_caching
|
||||||
|
|
||||||
|
if self._cpu_caching:
|
||||||
|
self._cache_dev = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
self._cache_dev = torch.cuda.current_device()
|
||||||
|
|
||||||
# weight configure
|
# weight configure
|
||||||
self._init_weight(weight)
|
self._init_weight(weight)
|
||||||
|
|
||||||
|
@ -62,8 +71,13 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
# cache_row_idx -> frequency, freq of the cache rows.
|
# cache_row_idx -> frequency, freq of the cache rows.
|
||||||
# classic LFU cache: evict the row with the minimal freq value from the cuda cache.
|
# classic LFU cache: evict the row with the minimal freq value from the cuda cache.
|
||||||
|
if self._cpu_caching:
|
||||||
|
self.freq_cnter = torch.empty(self.cuda_row_num, device=self._cache_dev,
|
||||||
|
dtype=torch.long).fill_(sys.maxsize)
|
||||||
|
|
||||||
|
else:
|
||||||
self.register_buffer("freq_cnter",
|
self.register_buffer("freq_cnter",
|
||||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
torch.empty(self.cuda_row_num, device=self._cache_dev,
|
||||||
dtype=torch.long).fill_(sys.maxsize),
|
dtype=torch.long).fill_(sys.maxsize),
|
||||||
persistent=False)
|
persistent=False)
|
||||||
|
|
||||||
|
@ -105,26 +119,32 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
self.weight = weight.pin_memory() if self.pin_weight else weight
|
self.weight = weight.pin_memory() if self.pin_weight else weight
|
||||||
# map original id to new id with respect to frequency
|
# map original id to new id with respect to frequency
|
||||||
# id -> cpu_row_idx
|
# id -> cpu_row_idx
|
||||||
|
|
||||||
|
if self._cpu_caching:
|
||||||
|
self.idx_map = torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev)
|
||||||
|
self.cached_idx_map = torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1)
|
||||||
|
self.inverted_cached_idx = torch.zeros(self.num_embeddings, device=self._cache_dev,
|
||||||
|
dtype=torch.long).fill_(-1)
|
||||||
|
else:
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"idx_map",
|
"idx_map",
|
||||||
torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
|
torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev),
|
||||||
persistent=False,
|
persistent=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# cached_idx_map: gpu_row_idx -> cpu_row_idx
|
# cached_idx_map: gpu_row_idx -> cpu_row_idx
|
||||||
self.register_buffer("cached_idx_map",
|
self.register_buffer("cached_idx_map",
|
||||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1),
|
||||||
dtype=torch.long).fill_(-1),
|
|
||||||
persistent=False)
|
persistent=False)
|
||||||
|
|
||||||
# cpu_row_id -> gpu_row_idx.
|
# cpu_row_id -> gpu_row_idx.
|
||||||
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
|
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
|
||||||
self.register_buffer("inverted_cached_idx",
|
self.register_buffer("inverted_cached_idx",
|
||||||
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
|
torch.zeros(self.num_embeddings, device=self._cache_dev,
|
||||||
dtype=torch.long).fill_(-1),
|
dtype=torch.long).fill_(-1),
|
||||||
persistent=False)
|
persistent=False)
|
||||||
|
|
||||||
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
|
self.evict_backlist = torch.tensor([], device=self._cache_dev)
|
||||||
|
|
||||||
# index copy buffer size should be less than 10% of cuda weight.
|
# index copy buffer size should be less than 10% of cuda weight.
|
||||||
if self.buffer_size > 0:
|
if self.buffer_size > 0:
|
||||||
|
@ -191,24 +211,24 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# extract rows from cpu weight
|
# extract rows from cpu weight
|
||||||
if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
|
if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
|
||||||
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
|
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
|
||||||
preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
|
preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev)
|
||||||
else:
|
else:
|
||||||
preload_cpu_ids = torch.arange(preload_row_num)
|
preload_cpu_ids = torch.arange(preload_row_num)
|
||||||
preload_cuda_row_idxs = preload_cpu_ids.cuda()
|
preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev)
|
||||||
|
|
||||||
if self.buffer_size > 0:
|
if self.buffer_size > 0:
|
||||||
self.limit_buff_index_copyer.index_copy(0,
|
self.limit_buff_index_copyer.index_copy(0,
|
||||||
src_index=preload_cpu_ids,
|
src_index=preload_cpu_ids,
|
||||||
tgt_index=preload_cuda_row_idxs,
|
tgt_index=preload_cuda_row_idxs.cuda(),
|
||||||
src=self.weight.view(self.num_embeddings, -1),
|
src=self.weight.view(self.num_embeddings, -1),
|
||||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||||
else:
|
else:
|
||||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
|
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
|
||||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs.cuda(),
|
||||||
preload_rows)
|
preload_rows)
|
||||||
|
|
||||||
# update auxiliary info
|
# update auxiliary info
|
||||||
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
|
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.to(self._cache_dev)
|
||||||
self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
|
self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
|
||||||
self._cuda_available_row_num -= preload_row_num
|
self._cuda_available_row_num -= preload_row_num
|
||||||
|
|
||||||
|
@ -217,7 +237,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if ids_freq_mapping is None:
|
if ids_freq_mapping is None:
|
||||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
|
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
|
||||||
else:
|
else:
|
||||||
self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
|
self.freq_cnter[preload_cuda_row_idxs] = freq_value.to(self._cache_dev)
|
||||||
|
|
||||||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||||
|
|
||||||
|
@ -227,7 +247,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
||||||
row_ids = self.cached_idx_map[slots]
|
row_ids = self.cached_idx_map[slots]
|
||||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots.cuda()).cpu()
|
||||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
|
||||||
self.cached_idx_map.index_fill_(0, slots, -1)
|
self.cached_idx_map.index_fill_(0, slots, -1)
|
||||||
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
||||||
|
@ -276,6 +296,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
torch.Tensor: indices on the cuda_cached_weight.
|
torch.Tensor: indices on the cuda_cached_weight.
|
||||||
"""
|
"""
|
||||||
with record_function("(zhg) get unique indices"):
|
with record_function("(zhg) get unique indices"):
|
||||||
|
ids = ids.to(self._cache_dev)
|
||||||
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
|
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
|
||||||
|
|
||||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||||
|
@ -353,7 +374,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
tgt=self.weight.view(self.num_embeddings, -1))
|
tgt=self.weight.view(self.num_embeddings, -1))
|
||||||
else:
|
else:
|
||||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
|
rows = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
|
-1).index_select(0, evict_gpu_row_idxs.cuda()).cpu()
|
||||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
|
||||||
|
|
||||||
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
||||||
|
@ -372,12 +394,12 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if self.buffer_size > 0:
|
if self.buffer_size > 0:
|
||||||
self.limit_buff_index_copyer.index_copy(0,
|
self.limit_buff_index_copyer.index_copy(0,
|
||||||
src_index=cpu_row_idxs.cpu(),
|
src_index=cpu_row_idxs.cpu(),
|
||||||
tgt_index=slots,
|
tgt_index=slots.cuda(),
|
||||||
src=self.weight.view(self.num_embeddings, -1),
|
src=self.weight.view(self.num_embeddings, -1),
|
||||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||||
else:
|
else:
|
||||||
rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
|
rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
|
||||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots.cuda(), rows)
|
||||||
slot_offsets = slots
|
slot_offsets = slots
|
||||||
self.cached_idx_map[slots] = cpu_row_idxs
|
self.cached_idx_map[slots] = cpu_row_idxs
|
||||||
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
||||||
|
|
|
@ -74,8 +74,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
||||||
|
|
||||||
embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
|
||||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||||
if shape_hook is not None:
|
if shape_hook is not None:
|
||||||
embeddings = shape_hook(embeddings)
|
embeddings = shape_hook(embeddings)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from colossalai.nn._ops._utils import dual_all_to_all
|
||||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
||||||
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
||||||
|
|
||||||
|
|
||||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
return 0, embedding_dim, True
|
return 0, embedding_dim, True
|
||||||
|
@ -29,8 +30,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||||
|
|
||||||
class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
padding_idx=None,
|
padding_idx=None,
|
||||||
|
@ -48,8 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
warmup_ratio=0.7,
|
warmup_ratio=0.7,
|
||||||
buffer_size=50_000,
|
buffer_size=50_000,
|
||||||
pin_weight=False,
|
pin_weight=False,
|
||||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
|
||||||
):
|
|
||||||
self.rank = torch.distributed.get_rank()
|
self.rank = torch.distributed.get_rank()
|
||||||
self.world_size = torch.distributed.get_world_size()
|
self.world_size = torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
@ -60,7 +59,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
super(ParallelFreqAwareEmbeddingBag,
|
super(ParallelFreqAwareEmbeddingBag,
|
||||||
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||||
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
|
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
|
||||||
warmup_ratio, buffer_size, pin_weight,evict_strategy)
|
warmup_ratio, buffer_size, pin_weight, evict_strategy)
|
||||||
|
|
||||||
def _weight_alloc(self, dtype, device):
|
def _weight_alloc(self, dtype, device):
|
||||||
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
||||||
|
@ -77,8 +76,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
||||||
|
|
||||||
output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
|
||||||
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||||
|
|
||||||
if shape_hook is not None:
|
if shape_hook is not None:
|
||||||
|
|
|
@ -83,15 +83,16 @@ def test_reorder_with_freq():
|
||||||
chunkid.append(idx // chunk_size)
|
chunkid.append(idx // chunk_size)
|
||||||
offset_in_chunk.append(idx % chunk_size)
|
offset_in_chunk.append(idx % chunk_size)
|
||||||
|
|
||||||
chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device())
|
dev = torch.device('cuda')
|
||||||
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device())
|
chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
|
||||||
|
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
|
||||||
|
|
||||||
weight = torch.rand(num_embed, 2)
|
weight = torch.rand(num_embed, 2)
|
||||||
mgr = CachedParamMgr(weight, num_chunk)
|
mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu')
|
||||||
|
|
||||||
mgr.reorder(idx_map)
|
mgr.reorder(idx_map)
|
||||||
|
|
||||||
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device()))
|
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
|
||||||
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
|
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
|
||||||
mgr_offsets = torch.remainder(indices, chunk_size)
|
mgr_offsets = torch.remainder(indices, chunk_size)
|
||||||
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
|
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
|
||||||
|
@ -280,6 +281,6 @@ def test_parallel_freq_aware_embed(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_freq_aware_embed(True)
|
test_freq_aware_embed(True)
|
||||||
# test_parallel_freq_aware_embed(2)
|
# test_parallel_freq_aware_embed(2)
|
||||||
test_lfu_strategy(False)
|
# test_lfu_strategy(False)
|
||||||
|
|
Loading…
Reference in New Issue