From a19eb8099858458e8b8875cb42e0479c1f3f6d5e Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 15 Sep 2022 15:45:17 +0800 Subject: [PATCH] [embedding] updates some default parameters --- benchmark | 1 - .../nn/parallel/layers/cache_embedding/cache_mgr.py | 10 ++++++---- .../layers/cache_embedding/freq_aware_embedding.py | 12 ++++++------ .../parallel_freq_aware_embedding_tablewise.py | 13 ++++++------- examples | 1 - 5 files changed, 18 insertions(+), 19 deletions(-) delete mode 160000 benchmark delete mode 160000 examples diff --git a/benchmark b/benchmark deleted file mode 160000 index 9ab77e0ec..000000000 --- a/benchmark +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9ab77e0ecc8e4ff480704dac2535b9c8f44f47b2 diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index e7daf5355..42f3e0e4b 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -35,7 +35,7 @@ class CachedParamMgr(torch.nn.Module): self, weight: torch.Tensor, cuda_row_num: int = 0, - buffer_size: int = 50_000, + buffer_size: int = 0, pin_weight: bool = False, evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, use_cpu_caching=False, @@ -211,7 +211,7 @@ class CachedParamMgr(torch.nn.Module): freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True) preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev) else: - preload_cpu_ids = torch.arange(preload_row_num) + preload_cpu_ids = torch.arange(preload_row_num, device=self.weight.device) preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev) if self.buffer_size > 0: @@ -304,8 +304,10 @@ class CachedParamMgr(torch.nn.Module): self.evict_backlist = cpu_row_idxs with record_function("(pre-id) get cpu row idxs"): - comm_cpu_row_idxs = cpu_row_idxs[torch.isin( - cpu_row_idxs, self.cached_idx_map, assume_unique=True, 
invert=True)] + comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, + self.cached_idx_map, + assume_unique=True, + invert=True)] self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs)) diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py index ad6e39b7d..282f6d0c4 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py @@ -30,7 +30,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): cuda_row_num (int, optional): the max number of embedding vector in cuda cache. Defaults to 0. ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None. warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. - buffer_size (int, optional): the max number of vectors in transmitter buffer. Defaults to 50_000. + buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. pin_weight (bool, optional): pin the cpu weight. Defaults to False. evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. 
""" @@ -51,9 +51,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): cuda_row_num: int = 0, ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, warmup_ratio: float = 0.7, - buffer_size: int = 50_000, + buffer_size: int = 0, pin_weight: bool = False, - evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): + evict_strategy: EvictionStrategy = EvictionStrategy.LFU): super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, mode, include_last_offset) @@ -96,10 +96,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag): evict_strategy=self.evict_strategy) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) - def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None): + def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): with torch.no_grad(): - reorder_ids = self.cache_weight_mgr.prepare_ids(indices) - + reorder_ids = self.cache_weight_mgr.prepare_ids(input) + embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, per_sample_weights, self.include_last_offset, self.padding_idx) diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py index 267e88cfb..731115d3c 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py @@ -123,7 +123,6 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): local_per_sample_weights_list: List(torch.Tensor) = [] offset_pre_end = 0 # local_offsets trick - for i, handle_table in enumerate(self.assigned_table_list): indices_start_position = offsets[batch_size * handle_table] if 
(not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): @@ -162,15 +161,15 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): # till-the-end special case if not self.include_last_offset: local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).add(offset_pre_end - offsets[batch_size - * (handle_table)]) + batch_size).add(offset_pre_end - offsets[batch_size * + (handle_table)]) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size - + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) local_offsets_list.append(local_offsets) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size - + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + + 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) offset_pre_end = local_offsets[-1] local_offsets_list.append(local_offsets[:-1]) # 3. local_per_sample_weights_list: diff --git a/examples b/examples deleted file mode 160000 index 757514d2b..000000000 --- a/examples +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 757514d2b1501d3530777cdf567f0a18063acf2d