mirror of https://github.com/hpcaitech/ColossalAI

[embeddings] use cache_ratio instead of cuda_row_num (#1611)

parent 6a8f8cc05e
commit 504ff1d101
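This commit swaps the public `cuda_row_num` argument (an absolute number of rows kept in the CUDA cache) for `cache_ratio` (a fraction of the whole embedding table); the absolute row count is still derived internally. A minimal sketch of that derivation, based only on the assert and the `int(...)` line added in the hunks below; the helper name is illustrative and not part of the library:

    def rows_to_cache(num_embeddings: int, cache_ratio: float) -> int:
        # cache_ratio = #cuda_weight_rows / #cpu_weight_rows, so it cannot exceed 1.0
        assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must be <= 1.0"
        return int(num_embeddings * cache_ratio)

    # with the new default cache_ratio=0.01, a 1_000_000-row table keeps 10_000 rows on CUDA
    assert rows_to_cache(1_000_000, 0.01) == 10_000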
@@ -293,7 +293,7 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(pre-id) get unique indices"):
+        with record_function(f"(pre-id) get unique indices. cache ratio {self.cuda_row_num / self.num_embeddings}"):
             ids = ids.to(self._cache_dev)
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)

@@ -27,10 +27,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False.
         dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
         device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
-        cuda_row_num (int, optional): the max number of embedding vector in cuda cache. Defaults to 0.
+        cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
         ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None.
         warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
-        buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, means do not use the buffer. Defaults to 0.
+        buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
         pin_weight (bool, optional): pin the cpu weight. Defaults to False.
         evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
     """
@@ -48,7 +48,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  include_last_offset: bool = False,
                  dtype: Optional[torch.dtype] = None,
                  device: Optional[torch.device] = None,
-                 cuda_row_num: int = 0,
+                 cache_ratio: float = 0.01,
                  ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
                  warmup_ratio: float = 0.7,
                  buffer_size: int = 0,
@@ -57,10 +57,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)

+        assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
         self.evict_strategy = evict_strategy
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
-
+        cuda_row_num = int(num_embeddings * cache_ratio)
         # configure weight & cache
         self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)

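For existing call sites the migration is mechanical: divide the old absolute row count by the table size, exactly as the updated tests further down do. A hedged sketch with placeholder names:

    num_embeddings = 5             # rows in the embedding table
    old_cuda_row_num = 3           # value previously passed as cuda_row_num
    cache_ratio = old_cuda_row_num / num_embeddings   # 0.6, passed as cache_ratio instead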
@@ -43,7 +43,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  include_last_offset=False,
                  dtype=None,
                  device=None,
-                 cuda_row_num=0,
+                 cache_ratio=0.01,
                  ids_freq_mapping=None,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
@@ -58,7 +58,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):

         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
-                             sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
+                             sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)

     def _weight_alloc(self, dtype, device):
@@ -31,7 +31,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                  include_last_offset=False,
                  dtype=None,
                  device=None,
-                 cuda_row_num=0,
+                 cache_ratio=0.01,
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
@@ -59,11 +59,12 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             else:
                 ids_freq_mapping = None
                 break
-
+        self.cache_ratio = cache_ratio
         # table-associate cache
+        cuda_row_num = int(cache_ratio * self.num_embeddings)
         super(ParallelFreqAwareEmbeddingBagTablewise,
               self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
-                             sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
+                             sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)

         # for assigned tables reconnection:
@@ -110,7 +110,7 @@ def test_freq_aware_embed(use_LFU: bool):
         EMBED_DIM,
         mode='mean',
         include_last_offset=True,
-        cuda_row_num=BATCH_SIZE * 2,
+        cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
         ids_freq_mapping=None,
         evict_strategy=evict_strategy).to(device)

@@ -153,7 +153,7 @@ def test_lfu_strategy(init_freq: bool):
     # minimal test to check behavior
     Bag = FreqAwareEmbeddingBag(5,
                                 5,
-                                cuda_row_num=3,
+                                cache_ratio=3 / 5,
                                 buffer_size=0,
                                 pin_weight=True,
                                 ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
@@ -238,7 +238,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
         embedding_dim=5,
         _weight=_weight,
         include_last_offset=True,
-        cuda_row_num=8,
+        cache_ratio=0.5,
         buffer_size=0,
         evict_strategy=EvictionStrategy.LFU,
     )
@@ -304,7 +304,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
         coloweight,
         include_last_offset=True,
         freeze=False,
-        cuda_row_num=batch_size * 2,
+        cache_ratio=batch_size * 2 / num_embed,
     )

     assert model.cache_weight_mgr.weight.device.type == 'cpu'
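A minimal construction sketch of the new argument in use, mirroring the updated test_lfu_strategy case above. It assumes a CUDA device is available; the import path is a guess at where these classes live in the repo and may need adjusting:

    import torch
    from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag, EvictionStrategy  # path is an assumption

    # 5 embeddings of dim 5; cache_ratio=3/5 keeps int(5 * 3 / 5) == 3 rows resident in the CUDA cache
    bag = FreqAwareEmbeddingBag(5,
                                5,
                                cache_ratio=3 / 5,
                                buffer_size=0,
                                pin_weight=True,
                                evict_strategy=EvictionStrategy.LFU)
    # attribute names taken from the hunks above: the wrapped CachedParamMgr still works in absolute rows
    print(bag.cache_weight_mgr.cuda_row_num)   # expected: 3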