Browse Source

[FAW] remove code related to chunk (#1501)

pull/1504/head
Jiarui Fang 2 years ago committed by GitHub
parent
commit
ba61109b6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 53
      colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
  2. 6
      colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py

53
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py

@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
self.num_hits_history = [] self.num_hits_history = []
self.num_miss_history = [] self.num_miss_history = []
self.num_write_back_history = [] self.num_write_back_history = []
self.input_id_percent_in_load_chunk = []
self._reset_comm_stats() self._reset_comm_stats()
self._evict_strategy = evict_strategy self._evict_strategy = evict_strategy
@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
# self.cuda_cached_weight = self.weight # self.cuda_cached_weight = self.weight
raise NotImplementedError() raise NotImplementedError()
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor: def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
""" """
access a chunk of CPU weight. access a row of CPU weight.
Args: Args:
chunk_id (int): chunk id row_idx (int): the idx of rows
Returns: Returns:
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D. torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
""" """
return self.weight.data.view(-1).narrow(0, return self.weight.data.view(-1).narrow(0,
int(chunk_id) * self.embedding_dim, int(row_idx) * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim) self.embedding_dim).view(1, self.embedding_dim)
@property @property
def cuda_available_chunk_num(self): def cuda_available_row_num(self):
return self._cuda_available_row_num return self._cuda_available_row_num
@torch.no_grad() @torch.no_grad()
@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0: if preload_row_num > 0:
with Timer() as timer: with Timer() as timer:
# extract chunks from cpu weight # extract rows from cpu weight
preload_row_ids = torch.arange(preload_row_num) preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.cuda() preload_slot_ids = preload_row_ids.cuda()
@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
src=self.weight.view(self.num_embeddings, -1), src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else: else:
preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks) self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
# update auxiliary info # update auxiliary info
slot_offsets = preload_slot_ids slot_offsets = preload_slot_ids
@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
print(f'Cache warmup finished cost {timer.elapsed} sec.') print(f'Cache warmup finished cost {timer.elapsed} sec.')
def flush(self): def flush(self):
"""flush all CUDA chunks to CPU. """flush all CUDA rows to CPU.
The function is usually called after training finished. The function is usually called after training finished.
""" """
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1) slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
chunk_ids = self.cached_idx_map[slots] row_ids = self.cached_idx_map[slots]
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu() rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks) self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
self.cached_idx_map.index_fill_(0, slots, -1) self.cached_idx_map.index_fill_(0, slots, -1)
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1) self.inverted_cached_idx.index_fill_(0, row_ids, -1)
self._cuda_available_row_num += slots.numel() self._cuda_available_row_num += slots.numel()
assert self._cuda_available_row_num == self.cuda_row_num assert self._cuda_available_row_num == self.cuda_row_num
@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs = torch.unique(cpu_row_idxs_original) cpu_row_idxs = torch.unique(cpu_row_idxs_original)
assert len(cpu_row_idxs) <= self.cuda_row_num, \ assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \ f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"which is larger than the presented {self.cuda_row_num}, " \ f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
f"please increase cuda_row_num shrink batch size" f"Please increase cuda_row_num or decrease the training batch size."
self.evict_backlist = cpu_row_idxs self.evict_backlist = cpu_row_idxs
with record_function("(zhg) get cpu chunk indices"): with record_function("(zhg) get cpu row idxs"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)] comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs)) self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs)) self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0) self.num_write_back_history.append(0)
# move sure the cuda chunk will not be evicted! # move sure the cuda rows will not be evicted!
with record_function("(zhg) cache update"): with record_function("(zhg) cache update"):
self._prepare_rows_on_cuda(comm_cpu_row_idxs) self._prepare_rows_on_cuda(comm_cpu_row_idxs)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype) self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
# new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"): with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids) gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU. # update for LFU.
@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
self._cuda_to_cpu_elapse = 0 self._cuda_to_cpu_elapse = 0
self._cuda_to_cpu_numel = 0 self._cuda_to_cpu_numel = 0
def _chunk_in_cuda(self, chunk_id: int) -> bool: def _row_in_cuda(self, row_id: int) -> bool:
return self.inverted_cached_idx[chunk_id] != -1 return self.inverted_cached_idx[row_id] != -1
@torch.no_grad() @torch.no_grad()
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
"""prepare rows in cpu_row_idxs on CUDA memory """prepare rows in cpu_row_idxs on CUDA memory
Args: Args:
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
""" """
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
if evict_num > 0: if evict_num > 0:
with Timer() as timer: with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
""" """
deprecated deprecated
evict one chunk from cuda to cpu. evict one row from cuda to cpu.
Returns: Returns:
(int) : the slot id be evicted. (int) : the slot id be evicted.
""" """

6
colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py

@@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
if self.cache_weight_mgr._cuda_to_cpu_numel > 0: if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cuda_to_cpu_elapse self.cache_weight_mgr._cuda_to_cpu_elapse
return 0 return 0
@property
def input_id_percent_in_load_chunk(self):
return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
Loading…
Cancel
Save