From a389ac4ec93aa9596417fe0b8ec74e7f88c8882c Mon Sep 17 00:00:00 2001
From: CsRic <59389055+CsRic@users.noreply.github.com>
Date: Thu, 8 Sep 2022 16:41:19 +0800
Subject: [PATCH] [embedding] cache_embedding small improvement (#1564)

---
 .../layers/cache_embedding/cache_mgr.py       |  7 ++--
 .../parallel_freq_aware_embedding.py          |  5 ++-
 ...parallel_freq_aware_embedding_tablewise.py | 36 +++++++++++++++----
 3 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
index fdb120134..e7daf5355 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
         """reorder
         reorder the weight according to ids' frequency in dataset before training.
         Execute only once before training, also known as warmup phase.
-        
+
         Note:
             If you would like to use the DATASET as the eviction strategy, you must call this function.
 
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
             self.evict_backlist = cpu_row_idxs
 
         with record_function("(pre-id) get cpu row idxs"):
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
+                cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]
 
         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
         if evict_num > 0:
             with Timer() as timer:
-                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
+                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index e53b126b7..61d870fad 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-
         output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
@@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
 
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
-        
+
     def element_size(self):
-        return self.weight.element_size()
\ No newline at end of file
+        return self.weight.element_size()
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index c0d72fbfc..d2f6b7c53 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights_list: List(torch.Tensor) = []
 
         offset_pre_end = 0    # local_offsets trick
+
         for i, handle_table in enumerate(self.assigned_table_list):
             indices_start_position = offsets[batch_size * handle_table]
             if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 indices_end_position = indices.shape[0]
             else:
                 indices_end_position = offsets[batch_size * (handle_table + 1)]
+            # alternative approach (disabled): reduce malloc by updating indices/offsets in place via out=
+            '''
+            # 1. local_indices_list:
+            local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
+            torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
+            local_indices_list.append(local_indices)
+            # 2. local_offsets_list:
+            if i + 1 == len(self.assigned_table_list):
+                # till-the-end special case
+                if not self.include_last_offset:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                local_offsets_list.append(local_offsets)
+            else:
+                temp_holder = offsets[batch_size * handle_table].item()
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
+                local_offsets_list.append(local_offsets)
+            '''
             # 1. local_indices_list:
             local_indices_list.append(
                 indices.narrow(0, indices_start_position,
@@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 # till-the-end special case
                 if not self.include_last_offset:
                     local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size).add(offset_pre_end - offsets[batch_size *
-                                                                                            (handle_table)])
+                                                   batch_size).add(offset_pre_end - offsets[batch_size
+                                                                                            * (handle_table)])
                 else:
-                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                                   1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                                   + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 local_offsets_list.append(local_offsets)
             else:
-                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                               1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                               + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 offset_pre_end = local_offsets[-1]
                 local_offsets_list.append(local_offsets[:-1])
             # 3. local_per_sample_weights_list:
             if per_sample_weights != None:
                 local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
-
         local_indices = torch.cat(local_indices_list, 0)
         local_offsets = torch.cat(local_offsets_list, 0)
         local_per_sample_weights = None