mirror of https://github.com/hpcaitech/ColossalAI
[FAW] remove code related to chunk (#1501)
parent
d5085bb317
commit
ba61109b6c
|
@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self.num_hits_history = []
|
||||
self.num_miss_history = []
|
||||
self.num_write_back_history = []
|
||||
self.input_id_percent_in_load_chunk = []
|
||||
self._reset_comm_stats()
|
||||
|
||||
self._evict_strategy = evict_strategy
|
||||
|
@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
|
|||
# self.cuda_cached_weight = self.weight
|
||||
raise NotImplementedError()
|
||||
|
||||
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
|
||||
def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
|
||||
"""
|
||||
access a chunk of CPU weight.
|
||||
access a row of CPU weight.
|
||||
|
||||
Args:
|
||||
chunk_id (int): chunk id
|
||||
row_idx (int): the idx of rows
|
||||
|
||||
Returns:
|
||||
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
|
||||
torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor has shape (1, embedding_dim).
|
||||
"""
|
||||
|
||||
return self.weight.data.view(-1).narrow(0,
|
||||
int(chunk_id) * self.embedding_dim,
|
||||
int(row_idx) * self.embedding_dim,
|
||||
self.embedding_dim).view(1, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def cuda_available_chunk_num(self):
|
||||
def cuda_available_row_num(self):
|
||||
return self._cuda_available_row_num
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
||||
if preload_row_num > 0:
|
||||
with Timer() as timer:
|
||||
# extract chunks from cpu weight
|
||||
# extract rows from cpu weight
|
||||
preload_row_ids = torch.arange(preload_row_num)
|
||||
preload_slot_ids = preload_row_ids.cuda()
|
||||
|
||||
|
@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
|
|||
src=self.weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
|
||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
|
||||
|
||||
# update auxiliary info
|
||||
slot_offsets = preload_slot_ids
|
||||
|
@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
|
|||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||
|
||||
def flush(self):
|
||||
"""flush all CUDA chunks to CPU.
|
||||
"""flush all CUDA rows to CPU.
|
||||
The function is usually called after training finished.
|
||||
"""
|
||||
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
||||
chunk_ids = self.cached_idx_map[slots]
|
||||
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
|
||||
row_ids = self.cached_idx_map[slots]
|
||||
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
||||
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
|
||||
self.cached_idx_map.index_fill_(0, slots, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
|
||||
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
||||
self._cuda_available_row_num += slots.numel()
|
||||
|
||||
assert self._cuda_available_row_num == self.cuda_row_num
|
||||
|
@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
|
|||
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
||||
|
||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
||||
f"which is larger than the presented {self.cuda_row_num}, " \
|
||||
f"please increase cuda_row_num shrink batch size"
|
||||
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
|
||||
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
|
||||
f"Please increase cuda_row_num or decrease the training batch size."
|
||||
self.evict_backlist = cpu_row_idxs
|
||||
|
||||
with record_function("(zhg) get cpu chunk indices"):
|
||||
with record_function("(zhg) get cpu row idxs"):
|
||||
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
|
||||
|
||||
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
|
||||
self.num_miss_history.append(len(comm_cpu_row_idxs))
|
||||
self.num_write_back_history.append(0)
|
||||
|
||||
# move sure the cuda chunk will not be evicted!
|
||||
# make sure the cuda rows will not be evicted!
|
||||
with record_function("(zhg) cache update"):
|
||||
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
|
||||
|
||||
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
|
||||
# new ids chunk_offset + offset_in_chunk
|
||||
with record_function("(zhg) embed idx -> cache chunk id"):
|
||||
|
||||
with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
|
||||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||
|
||||
# update for LFU.
|
||||
|
@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self._cuda_to_cpu_elapse = 0
|
||||
self._cuda_to_cpu_numel = 0
|
||||
|
||||
def _chunk_in_cuda(self, chunk_id: int) -> bool:
|
||||
return self.inverted_cached_idx[chunk_id] != -1
|
||||
def _row_in_cuda(self, row_id: int) -> bool:
|
||||
return self.inverted_cached_idx[row_id] != -1
|
||||
|
||||
@torch.no_grad()
|
||||
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
|
||||
"""prepare rows in cpu_row_idxs on CUDA memory
|
||||
|
||||
Args:
|
||||
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
|
||||
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
|
||||
"""
|
||||
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
|
||||
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
|
||||
if evict_num > 0:
|
||||
with Timer() as timer:
|
||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||
|
@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
"""
|
||||
deprecated
|
||||
|
||||
evict one chunk from cuda to cpu.
|
||||
evict one row from cuda to cpu.
|
||||
Returns:
|
||||
(int) : the id of the slot that was evicted.
|
||||
"""
|
||||
|
|
|
@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
|
||||
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
|
||||
self.cache_weight_mgr._cuda_to_cpu_elapse
|
||||
return 0
|
||||
|
||||
@property
|
||||
def input_id_percent_in_load_chunk(self):
|
||||
return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
|
||||
return 0
|
Loading…
Reference in New Issue