mirror of https://github.com/hpcaitech/ColossalAI
[embedding] cache_embedding small improvement (#1564)
parent 10dd8226b1
commit a389ac4ec9

@@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
        """reorder
        reorder the weight according to ids' frequency in dataset before training.
        Execute only once before training, also known as warmup phase.

        Note:
            If you would like to use the DATASET as the eviction strategy, you must call this function.
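
For reference, the warmup described in this docstring amounts to sorting the embedding rows by how often their ids occur in the dataset, so that the hottest rows land at the front of the table (the region kept in the CUDA cache). A minimal sketch of that idea in plain PyTorch, with made-up names (weight, ids_freq_mapping) that only mirror the surrounding code and are not the ColossalAI API:

    import torch

    # toy inputs: a (num_rows, dim) embedding weight and one frequency count per row
    weight = torch.randn(6, 4)
    ids_freq_mapping = torch.tensor([3, 50, 1, 20, 0, 7])

    # sort rows so the most frequent ids come first; remember old_id -> new_row
    freq_order = torch.argsort(ids_freq_mapping, descending=True)   # new_row -> old_id
    reordered_weight = weight[freq_order]
    id_to_row = torch.empty_like(freq_order)
    id_to_row[freq_order] = torch.arange(len(freq_order))           # old_id -> new_row

    # a lookup for original id 1 (the hottest id) now hits row 0 of the reordered table
    assert torch.equal(reordered_weight[id_to_row[1]], weight[1])
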
@@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
        self.evict_backlist = cpu_row_idxs

        with record_function("(pre-id) get cpu row idxs"):
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
+                cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]

        self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
        self.num_miss_history.append(len(comm_cpu_row_idxs))
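
The change above passes assume_unique=True to torch.isin; both id tensors are already deduplicated here, so torch.isin can skip its internal unique step. A small self-contained sketch of the hit/miss split this code performs (tensor values are made up):

    import torch

    cpu_row_idxs = torch.tensor([2, 5, 9, 11])         # unique ids requested by this batch
    cached_idx_map = torch.tensor([5, 9, 30, 42])      # unique ids currently resident on CUDA

    # rows that are requested but not cached must be copied in
    miss_mask = torch.isin(cpu_row_idxs, cached_idx_map, assume_unique=True, invert=True)
    comm_cpu_row_idxs = cpu_row_idxs[miss_mask]

    print(comm_cpu_row_idxs)                           # tensor([ 2, 11])
    print(len(cpu_row_idxs) - len(comm_cpu_row_idxs))  # 2 hits
    print(len(comm_cpu_row_idxs))                      # 2 misses
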
@@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
        evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
        if evict_num > 0:
            with Timer() as timer:
-                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
+                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
                invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                if self._evict_strategy == EvictionStrategy.DATASET:
                    # mask method.
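
Here torch.isin marks cached rows that appear in the eviction backlist (rows the current batch still needs) so they are protected from eviction; again both inputs are unique, which makes assume_unique=True safe. A toy sketch of the masking step (values made up):

    import torch

    cached_idx_map = torch.tensor([5, 9, 30, 42])   # ids resident in the CUDA cache
    evict_backlist = torch.tensor([9, 42])          # ids the current batch still needs

    # cache slots that must NOT be evicted right now
    mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist, assume_unique=True)
    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
    print(invalid_idxs)                             # tensor([1, 3]) -> slots holding ids 9 and 42
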
@@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
        with torch.no_grad():
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)
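
F.embedding_bag is called here with every argument positional, following torch.nn.functional.embedding_bag(input, weight, offsets, max_norm, norm_type, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx). A standalone call with the same argument shape, on a toy weight rather than the cached weight from the diff:

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4)                       # toy embedding table: 10 rows, dim 4
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])                    # two bags: indices[0:4] and indices[4:8]

    # mode="sum" pools each bag by summation; the diff passes self.mode, whose value
    # depends on how the module was configured
    out = F.embedding_bag(indices, weight, offsets, None, 2., False, "sum", False,
                          None, False, None)
    print(out.shape)                                  # torch.Size([2, 4])
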
@@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):

    def print_comm_stats_(self):
        self.cache_weight_mgr.print_comm_stats()

    def element_size(self):
        return self.weight.element_size()
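
element_size() reports the storage size in bytes of one weight element, which callers can use to estimate memory footprints, for example:

    import torch

    w = torch.zeros(3, 4, dtype=torch.float32)
    print(w.element_size())               # 4 bytes per fp32 element
    print(w.element_size() * w.numel())   # 48 bytes for the whole tensor
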
@@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
        local_per_sample_weights_list: List(torch.Tensor) = []

        offset_pre_end = 0    # local_offsets trick

        for i, handle_table in enumerate(self.assigned_table_list):
            indices_start_position = offsets[batch_size * handle_table]
            if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):

@@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                indices_end_position = indices.shape[0]
            else:
                indices_end_position = offsets[batch_size * (handle_table + 1)]
+            # alternative approach: reduce malloc
+            '''
+            # 1. local_indices_list:
+            local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
+            torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
+            local_indices_list.append(local_indices)
+            # 2. local_offsets_list:
+            if i + 1 == len(self.assigned_table_list):
+                # till-the-end special case
+                if not self.include_last_offset:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                local_offsets_list.append(local_offsets)
+            else:
+                temp_holder = offsets[batch_size * handle_table].item()
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
+                torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
+                offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
+                local_offsets_list.append(local_offsets)
+            '''
            # 1. local_indices_list:
            local_indices_list.append(
                indices.narrow(0, indices_start_position,
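
The commented-out "reduce malloc" variant above relies on Tensor.narrow returning a view (no copy) and on out= writes being in-place. The trade-off, shown below with toy tensors, is that an in-place write through such a view also mutates the caller's original indices/offsets buffers:

    import torch

    indices = torch.tensor([100, 101, 102, 7, 8, 9])

    # narrow() returns a view: no new memory is allocated for local_indices
    local_indices = indices.narrow(0, 0, 3)
    print(local_indices.data_ptr() == indices.data_ptr())   # True: same storage

    # an out= write through the view is in-place, so it also changes `indices`
    torch.sub(local_indices, 100, out=local_indices)
    print(local_indices)   # tensor([0, 1, 2])
    print(indices)         # tensor([0, 1, 2, 7, 8, 9]) -- the caller's tensor was modified
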
@@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                # till-the-end special case
                if not self.include_last_offset:
                    local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size).add(offset_pre_end - offsets[batch_size *
-                                                                                            (handle_table)])
+                                                   batch_size).add(offset_pre_end - offsets[batch_size
+                                                                                            * (handle_table)])
                else:
-                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                                   1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                                   + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                local_offsets_list.append(local_offsets)
            else:
-                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
-                                               1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+                                               + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                offset_pre_end = local_offsets[-1]
                local_offsets_list.append(local_offsets[:-1])
            # 3. local_per_sample_weights_list:
            if per_sample_weights != None:
                local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])

        local_indices = torch.cat(local_indices_list, 0)
        local_offsets = torch.cat(local_offsets_list, 0)
        local_per_sample_weights = None
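
To make the offset rebasing concrete: each rank handles only its assigned tables, so the global offsets must be shifted onto the concatenation of just those tables' indices. A self-contained toy run of the same idea (made-up data, include_last_offset assumed True, not the ColossalAI code itself):

    import torch

    batch_size = 2
    # global layout for 3 tables, 2 samples each:
    # table0 bags [3, 4], [4] | table1 bags [8], [9] | table2 bags [20], [21, 22]
    indices = torch.tensor([3, 4, 4, 8, 9, 20, 21, 22])
    offsets = torch.tensor([0, 2, 3, 4, 5, 6, 8])
    assigned_table_list = [0, 2]              # this rank skips table 1

    local_indices_list, local_offsets_list = [], []
    offset_pre_end = 0
    for i, handle_table in enumerate(assigned_table_list):
        start = offsets[batch_size * handle_table]
        end = offsets[batch_size * (handle_table + 1)]
        local_indices_list.append(indices[start:end])
        # shift this table's offsets so they continue after what was already emitted
        local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
        local_offsets = local_offsets.add(offset_pre_end - offsets[batch_size * handle_table])
        if i + 1 == len(assigned_table_list):
            local_offsets_list.append(local_offsets)        # keep the trailing offset
        else:
            offset_pre_end = local_offsets[-1]
            local_offsets_list.append(local_offsets[:-1])   # next table supplies it

    print(torch.cat(local_indices_list, 0))   # tensor([ 3,  4,  4, 20, 21, 22])
    print(torch.cat(local_offsets_list, 0))   # tensor([0, 2, 3, 4, 6])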