mirror of https://github.com/hpcaitech/ColossalAI
[embedding] updates some default parameters
parent
cd5cf2bcc9
commit
a19eb80998
|
@ -1 +0,0 @@
|
|||
Subproject commit 9ab77e0ecc8e4ff480704dac2535b9c8f44f47b2
|
|
@ -35,7 +35,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self,
|
||||
weight: torch.Tensor,
|
||||
cuda_row_num: int = 0,
|
||||
buffer_size: int = 50_000,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||
use_cpu_caching=False,
|
||||
|
@ -211,7 +211,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
|
||||
preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev)
|
||||
else:
|
||||
preload_cpu_ids = torch.arange(preload_row_num)
|
||||
preload_cpu_ids = torch.arange(preload_row_num, device=self.weight.device)
|
||||
preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev)
|
||||
|
||||
if self.buffer_size > 0:
|
||||
|
@ -304,8 +304,10 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self.evict_backlist = cpu_row_idxs
|
||||
|
||||
with record_function("(pre-id) get cpu row idxs"):
|
||||
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
|
||||
cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]
|
||||
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs,
|
||||
self.cached_idx_map,
|
||||
assume_unique=True,
|
||||
invert=True)]
|
||||
|
||||
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
|
||||
self.num_miss_history.append(len(comm_cpu_row_idxs))
|
||||
|
|
|
@ -30,7 +30,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||
cuda_row_num (int, optional): the max number of embedding vector in cuda cache. Defaults to 0.
|
||||
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None.
|
||||
warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
|
||||
buffer_size (int, optional): the max number of vectors in transmitter buffer. Defaults to 50_000.
|
||||
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, means do not use the buffer. Defaults to 0.
|
||||
pin_weight (bool, optional): pin the cpu weight. Defaults to False.
|
||||
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
|
||||
"""
|
||||
|
@ -51,9 +51,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||
cuda_row_num: int = 0,
|
||||
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
|
||||
warmup_ratio: float = 0.7,
|
||||
buffer_size: int = 50_000,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = False,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
|
||||
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
|
||||
scale_grad_by_freq, sparse, mode, include_last_offset)
|
||||
|
||||
|
@ -96,10 +96,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||
evict_strategy=self.evict_strategy)
|
||||
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
||||
|
||||
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
|
||||
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
|
||||
with torch.no_grad():
|
||||
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
|
||||
|
||||
reorder_ids = self.cache_weight_mgr.prepare_ids(input)
|
||||
|
||||
embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
|
||||
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||
per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||
|
|
|
@ -123,7 +123,6 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|||
local_per_sample_weights_list: List(torch.Tensor) = []
|
||||
|
||||
offset_pre_end = 0 # local_offsets trick
|
||||
|
||||
for i, handle_table in enumerate(self.assigned_table_list):
|
||||
indices_start_position = offsets[batch_size * handle_table]
|
||||
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
|
||||
|
@ -162,15 +161,15 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|||
# till-the-end special case
|
||||
if not self.include_last_offset:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table,
|
||||
batch_size).add(offset_pre_end - offsets[batch_size
|
||||
* (handle_table)])
|
||||
batch_size).add(offset_pre_end - offsets[batch_size *
|
||||
(handle_table)])
|
||||
else:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
|
||||
+ 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
|
||||
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
local_offsets_list.append(local_offsets)
|
||||
else:
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
|
||||
+ 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
|
||||
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
|
||||
offset_pre_end = local_offsets[-1]
|
||||
local_offsets_list.append(local_offsets[:-1])
|
||||
# 3. local_per_sample_weights_list:
|
||||
|
|
1
examples
1
examples
|
@ -1 +0,0 @@
|
|||
Subproject commit 757514d2b1501d3530777cdf567f0a18063acf2d
|
Loading…
Reference in New Issue