[embeddings] more detailed timer (#1692)

pull/1688/head^2
Jiarui Fang, 2 years ago, committed by GitHub
parent 4973157ad7
commit 363fc2861a

@@ -91,6 +91,7 @@ class CachedParamMgr(torch.nn.Module):
dtype=torch.long).fill_(sys.maxsize),
persistent=False)
self._elapsed_dict = {}
self._show_cache_miss = True
self._reset_comm_stats()
def _reset_comm_stats(self):
@@ -99,6 +100,9 @@ class CachedParamMgr(torch.nn.Module):
self._cpu_to_cuda_numel = 0
self._cuda_to_cpu_numel = 0
if self._show_cache_miss:
self._cache_miss = 0
self._total_cache = 0
@contextmanager
def timer(self, name):
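
The body of timer(self, name) is not shown in this hunk. As a point of reference only, here is a minimal self-contained sketch of a name-keyed accumulating timer that would be consistent with how print_comm_stats reads self._elapsed_dict below; the class name _ElapsedTimer and the use of time.perf_counter are assumptions, not the actual implementation.

from contextlib import contextmanager
import time

class _ElapsedTimer:
    """Hypothetical stand-in: accumulate wall-clock seconds per stage name."""

    def __init__(self):
        self._elapsed_dict = {}

    @contextmanager
    def timer(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # accumulate, so repeated entries under the same name sum across iterations
            self._elapsed_dict[name] = self._elapsed_dict.get(name, 0.0) + time.perf_counter() - start

Used as `with self.timer("1_identify_cpu_row_idxs"): ...`, the per-name accumulation is what makes the per-stage totals printed later meaningful.
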
@@ -268,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
self._cuda_available_row_num += slots.numel()
if self._show_cache_miss:
self._cache_miss = 0
self._total_cache = 0
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter.fill_(sys.maxsize)
assert self._cuda_available_row_num == self.cuda_row_num
@@ -275,14 +283,14 @@ class CachedParamMgr(torch.nn.Module):
assert torch.all(self.cached_idx_map == -1).item()
def print_comm_stats(self):
if self._cuda_to_cpu_numel > 0 and "3_2_2_evict_out_gpu_to_cpu_copy" in self._elapsed_dict:
elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"]
if self._cuda_to_cpu_numel > 0 and "3_evict_out" in self._elapsed_dict:
elapsed = self._elapsed_dict["3_evict_out"]
print(
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
)
print(f'cuda_to_cpu_elapse {elapsed} sec')
if self._cpu_to_cuda_numel > 0 and "3_4_2_evict_in_gpu_to_cpu_copy" in self._elapsed_dict:
elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"]
if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict:
elapsed = self._elapsed_dict["5_evict_in"]
print(
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
)
@@ -291,6 +299,8 @@ class CachedParamMgr(torch.nn.Module):
for k, v in self._elapsed_dict.items():
print(f'{k}: {v}')
print(f'cache miss ratio {self._cache_miss / self._total_cache}')
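
For readers skimming the stats output, a hedged worked example of the bandwidth figure printed above, with made-up numbers (the real values come from the counters and the "3_evict_out" timer):

cuda_to_cpu_numel = 4_000_000   # hypothetical number of evicted elements
elem_size_in_byte = 4           # fp32
elapsed = 0.05                  # seconds recorded under "3_evict_out"
print(f"CUDA->CPU BWD {cuda_to_cpu_numel * elem_size_in_byte / 1e6 / elapsed} MB/s "
      f"{cuda_to_cpu_numel / 1e6} M elem")   # 320.0 MB/s, 4.0 M elem
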
@torch.no_grad()
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
"""
@@ -315,41 +325,45 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
"""
torch.cuda.synchronize()
with self.timer("1_unique_indices") as timer:
with record_function("(cache) get unique indices"):
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
f"Please increase cuda_row_num or decrease the training batch size."
self.evict_backlist = cpu_row_idxs
torch.cuda.synchronize()
# O(cache ratio)
with self.timer("2_cpu_row_idx") as timer:
with record_function("(cache) get cpu row idxs"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0)
# make sure the cuda rows will not be evicted!
with self.timer("3_prepare_rows_on_cuda") as timer:
with self.timer("cache_op") as gtimer:
# identify cpu rows to cache
with self.timer("1_identify_cpu_row_idxs") as timer:
with record_function("(cache) get unique indices"):
if self._evict_strategy == EvictionStrategy.LFU:
cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)
else:
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
f"Please increase cuda_row_num or decrease the training batch size."
self.evict_backlist = cpu_row_idxs
tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)
comm_cpu_row_idxs = cpu_row_idxs[tmp]
if self._show_cache_miss:
self._cache_miss += torch.sum(repeat_times[tmp])
self._total_cache += ids.numel()
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0)
# make sure the cuda rows will not be evicted!
with record_function("(cache) prepare_rows_on_cuda"):
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
with self.timer("4_cpu_to_gpu_row_idxs") as timer:
with record_function("(cache) embed cpu rows idx -> cache gpu row idxs"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
with self.timer("6_update_cache") as timer:
with record_function("6_update_cache"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
if self._evict_strategy == EvictionStrategy.LFU:
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
# update for LFU.
if self._evict_strategy == EvictionStrategy.LFU:
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
return gpu_row_idxs
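
Outside the diff context, a small standalone illustration of the hit/miss bookkeeping added under "1_identify_cpu_row_idxs", using toy ids in place of the real idx_map / cached_idx_map (the tensor values below are made up):

import torch

ids = torch.tensor([3, 5, 3, 7, 5, 5])          # ids requested by one batch
cached_idx_map = torch.tensor([5, 9, -1, -1])   # embedding rows currently resident on CUDA

cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)   # [3, 5, 7], [2, 3, 1]
miss_mask = torch.isin(cpu_row_idxs, cached_idx_map, invert=True)    # [True, False, True]
comm_cpu_row_idxs = cpu_row_idxs[miss_mask]                          # rows to fetch from CPU: [3, 7]

cache_miss = torch.sum(repeat_times[miss_mask])  # 3 missed lookups (two for id 3, one for id 7)
total_cache = ids.numel()                        # 6 lookups in this batch
print(f"cache miss ratio {cache_miss / total_cache}")                # 0.5
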
@@ -377,8 +391,7 @@ class CachedParamMgr(torch.nn.Module):
raise NotImplementedError
if evict_num > 0:
torch.cuda.synchronize()
with self.timer("3_1_evict_prepare") as timer:
with self.timer("2_identify_cuda_row_idxs") as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET:
@@ -388,7 +401,7 @@ class CachedParamMgr(torch.nn.Module):
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer:
with self.timer("2_1_find_evict_gpu_idxs") as timer:
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
# move evict out rows to cpu
@@ -401,11 +414,11 @@ class CachedParamMgr(torch.nn.Module):
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
with self.timer("3_1_0_backup_freqs") as timer:
with self.timer("2_1_backup_freqs") as timer:
backup_freqs = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer:
with self.timer("2_2_find_evict_gpu_idxs") as timer:
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
if self._async_copy:
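
_find_evict_gpu_idxs is not shown in this diff; assuming it is essentially a smallest-k selection over freq_cnter, the masking idiom above (back up the protected counters, blind them with sys.maxsize, pick victims, restore) can be illustrated on toy data like this:

import sys
import torch

freq_cnter = torch.tensor([4, 1, 9, 2, 7])       # access counts per GPU cache slot
protected = torch.tensor([1, 3])                 # slots holding rows of the current batch
backup_freqs = freq_cnter[protected].clone()
freq_cnter.index_fill_(0, protected, sys.maxsize)                         # make them un-evictable
evict_gpu_row_idxs = torch.topk(freq_cnter, k=2, largest=False).indices   # slots 0 and 4
freq_cnter.index_copy_(0, protected, backup_freqs)                        # restore real counts
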
@@ -414,12 +427,13 @@ class CachedParamMgr(torch.nn.Module):
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
with self.timer("3_1_2_find_evict_index_copy") as timer:
with self.timer("2_3_revert_freqs") as timer:
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
with self.timer("3_2_evict_out_elapse") as timer:
with self.timer("3_evict_out") as timer:
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=evict_gpu_row_idxs,
@@ -432,13 +446,13 @@ class CachedParamMgr(torch.nn.Module):
if self._async_copy:
_wait_for_data(evict_out_rows_cpu, None)
else:
with self.timer("3_2_1_evict_out_index_select") as timer:
with self.timer("3_1_evict_out_index_select") as timer:
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs)
with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer:
with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer:
evict_out_rows_cpu = evict_out_rows_cpu.cpu()
with self.timer("3_2_2_evict_out_index_select") as timer:
with self.timer("3_2_evict_out_cpu_copy") as timer:
self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
@@ -447,15 +461,15 @@ class CachedParamMgr(torch.nn.Module):
self._cuda_available_row_num += evict_num
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
self._cuda_to_cpu_numel += weight_size
self._cuda_to_cpu_numel += weight_size
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
# slots of cuda weight to evict in
with self.timer("3_3_non_zero") as timer:
with self.timer("4_identify_cuda_slot") as timer:
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
# TODO: to be optimized
with self.timer("3_4_evict_in_elapse") as timer:
with self.timer("5_evict_in") as timer:
# Note: this also allocates extra CUDA memory for len(cpu_row_idxs) rows
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
@@ -467,20 +481,20 @@ class CachedParamMgr(torch.nn.Module):
if self._async_copy:
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
else:
with self.timer("3_4_1_evict_in_index_select") as timer:
with self.timer("5_1_evict_in_index_select") as timer:
# narrow index select to a subset of self.weight
# tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
# evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer:
with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer:
evict_in_rows_gpu = evict_in_rows_gpu.cuda()
with self.timer("3_4_3_evict_in_index_copy") as timer:
with self.timer("5_3_evict_in_index_copy") as timer:
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
with self.timer("3_5_evict_in_elapse_final") as timer:
with self.timer("6_update_cache") as timer:
self.cached_idx_map[slots] = cpu_row_idxs
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)
if self._evict_strategy == EvictionStrategy.LFU:
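
To close the loop, a hedged CPU-only sketch of the evict-in path timed above ("5_evict_in" followed by "6_update_cache"); the shapes and index values are made up, and the real code pins and moves evict_in_rows_gpu to CUDA where this sketch keeps everything on the host:

import torch

num_embeddings, cuda_row_num, dim = 8, 2, 4
weight = torch.arange(num_embeddings * dim, dtype=torch.float32).view(num_embeddings, dim)
cuda_cached_weight = torch.zeros(cuda_row_num, dim)
cached_idx_map = torch.full((cuda_row_num,), -1, dtype=torch.long)
inverted_cached_idx = torch.full((num_embeddings,), -1, dtype=torch.long)

cpu_row_idxs = torch.tensor([3, 6])   # embedding rows missing from the cache
slots = torch.nonzero(cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]  # free cache slots
evict_in_rows = weight.index_select(0, cpu_row_idxs)       # real code: .pin_memory() then .cuda()
cuda_cached_weight.index_copy_(0, slots, evict_in_rows)    # "5_3_evict_in_index_copy"
cached_idx_map[slots] = cpu_row_idxs                       # "6_update_cache": slot -> row
inverted_cached_idx.index_copy_(0, cpu_row_idxs, slots)    # row -> slot
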
