[FAW] add cache manager for the cached embedding (#1419)

Jiarui Fang 2022-08-09 15:17:17 +08:00 committed by GitHub
parent 44fd3c83ab
commit 504419d261
7 changed files with 514 additions and 0 deletions
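
At a high level, CachedParamMgr keeps the full embedding weight in pinned CPU memory and mirrors only the hot rows on CUDA; prepare_ids admits whatever a batch needs and returns each id's position inside the CUDA cache. A minimal lifecycle sketch, assuming a CUDA device and an illustrative toy workload (the sizes and batch tensors below are made up for illustration, not part of this PR):

import torch
import torch.nn.functional as F
from colossalai.nn._ops.cache_embedding import CachedParamMgr

weight = torch.randn(10000, 128)                      # full embedding table, kept on CPU
mgr = CachedParamMgr(weight, cuda_row_num=1024)       # cache at most 1024 rows on CUDA
mgr.reorder(ids_freq_mapping=None, warmup_ratio=0.7)  # optionally pass per-id frequencies to put hot rows first

ids = torch.randint(0, 10000, (8, 16)).cuda()         # a toy batch of embedding ids
cache_ids = mgr.prepare_ids(ids.view(-1))             # admit missing rows, map ids to cache positions
out = F.embedding_bag(cache_ids.view(8, 16), mgr.cuda_cached_weight, mode='mean')

mgr.flush()                                           # write every cached row back to CPU after training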


@@ -0,0 +1,4 @@
from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer']


@@ -0,0 +1,36 @@
import abc
import torch.nn as nn
class BaseEmbeddingBag(abc.ABC, nn.Module):
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
mode='mean',
include_last_offset=False,
):
super(BaseEmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Specific to embedding bag
self.mode = mode
self.include_last_offset = include_last_offset
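
BaseEmbeddingBag above only records the usual EmbeddingBag hyper-parameters; a concrete subclass is expected to own the weight and the forward pass. A hypothetical wiring with the cache manager could look like the sketch below (the subclass name and its forward logic are illustrative assumptions, not part of this diff):

import torch
import torch.nn.functional as F
from colossalai.nn._ops.cache_embedding import CachedParamMgr

class CachedEmbeddingBagSketch(BaseEmbeddingBag):    # BaseEmbeddingBag is the class defined above
    def __init__(self, num_embeddings, embedding_dim, cuda_row_num, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        weight = torch.randn(num_embeddings, embedding_dim) * 0.02   # toy initialization
        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num)

    def forward(self, indices, offsets=None):
        # translate dataset ids into row positions inside the CUDA cache
        cache_ids = self.cache_weight_mgr.prepare_ids(indices.view(-1)).view_as(indices)
        return F.embedding_bag(cache_ids, self.cache_weight_mgr.cuda_cached_weight, offsets,
                               mode=self.mode, sparse=self.sparse,
                               include_last_offset=self.include_last_offset)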


@@ -0,0 +1,348 @@
import numpy as np
import torch
from torch.profiler import record_function
from typing import List, Optional
from contexttimer import Timer
from .copyer import LimitBuffIndexCopyer
class CachedParamMgr(torch.nn.Module):
"""
Manage Embedding Weights in Cache on CPU and CUDA memory.
CPU maintains entire original weight.
CUDA maintains a fraction of weights used in the upcomming computation.
During training, GPU needs to transmit rows between CPU and GPU.
"""
def __init__(self, weight: torch.Tensor, cuda_row_num: int = 0, buffer_size: int = 50_000) -> None:
super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num
self.elem_size_in_byte = weight.element_size()
self.cuda_cached_weight = torch.nn.Parameter(
torch.zeros(self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype))
if weight.device.type == 'cuda':
weight = weight.cpu()
# pin memory cpu for higher CPU-GPU copy bandwidth
self.cpu_weight = weight.contiguous().pin_memory()
# map original id to new id with respect to frequency
# id -> cpu_row_idx
self.register_buffer(
"idx_map",
torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
persistent=False,
)
# cached_idx_map: gpu_row_idx -> cpu_row_idx
self.register_buffer("cached_idx_map",
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
# cpu_row_id -> gpu_row_idx.
# gpu_row_idx as -1 means cpu_row_id not in CUDA.
self.register_buffer("inverted_cached_idx",
torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
dtype=torch.long).fill_(-1),
persistent=False)
self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
        # The index-copy buffer size should be less than 10% of the CUDA weight.
if self.buffer_size > 0:
self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
self.num_hits_history = []
self.num_miss_history = []
self.num_write_back_history = []
self.input_id_percent_in_load_chunk = []
self._reset_comm_stats()
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
"""
access a chunk of CPU weight.
Args:
chunk_id (int): chunk id
Returns:
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
"""
return self.cpu_weight.data.view(-1).narrow(0,
int(chunk_id) * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
@property
def cuda_available_chunk_num(self):
return self._cuda_available_row_num
@torch.no_grad()
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder the cpu_weight according to ids' frequency in dataset before training.
Also Build the IndexMappingTable, aka index_mapping_table.
Execute only once before training.
Args:
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
if ids_freq_mapping is not None:
            tmp_idx = torch.argsort(torch.as_tensor(ids_freq_mapping, device=torch.cuda.current_device()), descending=True)
sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx)
        # TODO: the following code allocates extra CUDA memory (preload_row_num rows).
        # Since cuda_cached_weight can be very large, that much free memory may not be available.
        # Warm up the CUDA cache by moving the high-frequency rows (lowest row ids after reorder) to CUDA.
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0:
with Timer() as timer:
# extract chunks from cpu weight
preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.cuda()
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=preload_row_ids,
tgt_index=preload_slot_ids,
src=self.cpu_weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else:
preload_chunks = self.cpu_weight.view(self.num_embeddings, -1).index_select(0,
preload_row_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
# update auxiliary info
slot_offsets = preload_slot_ids
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
self._cuda_available_row_num -= preload_row_num
            print(f'Cache warmup finished in {timer.elapsed} sec.')
def flush(self):
"""flush all CUDA chunks to CPU.
The function is usually called after training finished.
"""
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
chunk_ids = self.cached_idx_map[slots]
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
self.cached_idx_map.index_fill_(0, slots, -1)
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
self._cuda_available_row_num += slots.numel()
assert self._cuda_available_row_num == self.cuda_row_num
assert torch.all(self.inverted_cached_idx == -1).item()
assert torch.all(self.cached_idx_map == -1).item()
def print_comm_stats(self):
if self._cuda_to_cpu_numel > 0:
print(
f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
)
if self._cpu_to_cuda_numel > 0:
print(
f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
)
@torch.no_grad()
def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
"""
convert ids to indices in self.cuda_cached_weight.
Implemented with parallel operations on GPU.
Args:
ids (torch.Tensor): ids from the dataset
Returns:
torch.Tensor: contains indices in self.cuda_cached_weight
"""
ids = self.idx_map.index_select(0, ids.view(-1))
ret = self.inverted_cached_idx.index_select(0, ids)
return ret
@torch.no_grad()
def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
"""
move the cpu embedding rows w.r.t. ids into CUDA memory
Args:
ids (torch.Tensor): the ids to be computed
Returns:
torch.Tensor: indices on the cuda_cached_weight.
"""
with record_function("(zhg) get unique indices"):
cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
f"which is larger than the presented {self.cuda_row_num}, " \
f"please increase cuda_row_num shrink batch size"
self.evict_backlist = cpu_row_idxs
with record_function("(zhg) get cpu chunk indices"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs))
self.num_write_back_history.append(0)
        # make sure the CUDA rows needed by this batch will not be evicted
with record_function("(zhg) cache update"):
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
        # map the original ids to their row indices in the CUDA cache
with record_function("(zhg) embed idx -> cache chunk id"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
return gpu_row_idxs
def _reset_comm_stats(self):
self._cpu_to_cuda_numel = 0
self._cpu_to_cuda_elpase = 0
self._cuda_to_cpu_elapse = 0
self._cuda_to_cpu_numel = 0
def _chunk_in_cuda(self, chunk_id: int) -> bool:
return self.inverted_cached_idx[chunk_id] != -1
@torch.no_grad()
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
"""prepare rows in cpu_row_idxs on CUDA memory
Args:
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
"""
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=evict_gpu_row_idxs,
tgt_index=evict_info.cpu(),
src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
tgt=self.cpu_weight.view(self.num_embeddings, -1))
else:
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
self._cuda_available_row_num += evict_num
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
self._cuda_to_cpu_elapse += timer.elapsed
self._cuda_to_cpu_numel += weight_size
# print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
with Timer() as timer:
slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
            # This also allocates extra CUDA memory (cpu_row_idxs.numel() rows) when no copy buffer is used.
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=cpu_row_idxs.cpu(),
tgt_index=slots,
src=self.cpu_weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else:
rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
slot_offsets = slots
self.cached_idx_map[slots] = cpu_row_idxs
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
self._cuda_available_row_num -= cpu_row_idxs.numel()
self._cpu_to_cuda_elpase += timer.elapsed
weight_size = cpu_row_idxs.numel() * self.embedding_dim
self._cpu_to_cuda_numel += weight_size
# print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
def _evict(self) -> int:
"""
evict one chunk from cuda to cpu.
Returns:
(int) : the slot id be evicted.
"""
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
buf = self.cached_idx_map[mask].clone()
idx = torch.nonzero(mask).squeeze(1)
self.cached_idx_map.index_fill_(0, idx, -1)
max_row, max_cpu_row_idx = torch.max(self.cached_idx_map, dim=0)
max_gpu_row_idx = self.cached_idx_map[max_cpu_row_idx]
if max_gpu_row_idx == -1:
raise RuntimeError("Can not evict a row")
max_gpu_row_idx = max_gpu_row_idx.item()
max_offset = self.inverted_cached_idx[max_gpu_row_idx]
# recover
self.cached_idx_map.index_copy_(0, idx, buf)
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor)
            # update the mapping tables: the evicted row is no longer resident on CUDA
self.cached_idx_map[max_cpu_row_idx] = -1
self.inverted_cached_idx[max_gpu_row_idx] = -1
self._cuda_available_row_num += 1
self._cuda_to_cpu_numel += self.embedding_dim
self._cuda_to_cpu_elapse += timer.elapsed
# self.num_write_back_history[-1] += 1
return max_cpu_row_idx
def _find_free_cuda_row(self) -> int:
if self._cuda_available_row_num == 0:
return -1
candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
return candidates[0].item()
@torch.no_grad()
def _admit(self, row_id: int):
"""
move in row_id to CUDA
Args:
row_id (int): the id of row to be moved in
"""
# find a free slot in partial cuda weight
slot_id = self._find_free_cuda_row()
if slot_id == -1:
# evict one row
slot_id = self._evict()
slot_offset = slot_id
# copy payload from cpu to cuda
with Timer() as timer:
cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim,
self.embedding_dim).view(1, self.embedding_dim)
cuda_tensor.data.copy_(self.cpu_weight_data(row_id))
# update the inverted_cached_idx
self.cached_idx_map[slot_id] = row_id
self.inverted_cached_idx[row_id] = slot_offset
self._cuda_available_row_num -= 1
self._cpu_to_cuda_numel += self.embedding_dim
self._cpu_to_cuda_elpase += timer.elapsed
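
Since the manager records per-step hits, misses, and transfer volume, cache behaviour can be inspected from the outside. A small sketch, assuming a CUDA device and a toy table (sizes are illustrative):

import torch
from colossalai.nn._ops.cache_embedding import CachedParamMgr

weight = torch.randn(100, 8)
mgr = CachedParamMgr(weight, cuda_row_num=16, buffer_size=0)    # buffer_size=0 copies without the staging buffer
for _ in range(4):
    ids = torch.randint(0, 100, (12,), device=torch.cuda.current_device())
    mgr.prepare_ids(ids)
print('hits per step:  ', mgr.num_hits_history)
print('misses per step:', mgr.num_miss_history)
mgr.print_comm_stats()    # CPU<->CUDA transfer volume and bandwidth
mgr.flush()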


@@ -0,0 +1,48 @@
import torch
from torch import LongTensor
class LimitBuffIndexCopyer(object):
"""LimitBuffIndexCopyer
Index Copy using limited temp buffer on CUDA.
Args:
size (int): buffer size
"""
def __init__(self, size: int) -> None:
self._buff_size = size
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
"""copy
src tensor[src_index] -(index_select)-> tmp -()-> tgt tensor [tgt_index]
The valid part in src is continous, while in tgt is scatter.
Args:
dim (int): dimension along which to index
src_index (int): indices of src tensor to select from
tgt_index (int): indices of tgt tensor to select from
src (torch.Tensor): the tensor containing values to copy
tgt (torch.Tensor): the tensor to be copied
"""
# tgt.index_copy_(dim, index, src)
assert dim == 0, "only support index_copy on dim 0"
assert tgt.dim() == 2
assert src.dim() == 2
tgt_device = tgt.device
src_device = src.device
assert src_index.numel() == tgt_index.numel()
dim_size = src_index.numel()
src_index = src_index.to(src_device)
for begin_pos in range(0, dim_size, self._buff_size):
cur_len = min(self._buff_size, dim_size - begin_pos)
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
tmp_buffer.copy_(cpu_tmp_buffer)
else:
tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
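
A minimal sketch of the copyer on its own, staging a CPU-to-CUDA row copy through a small temporary buffer (the tensor shapes and indices below are illustrative assumptions):

import torch
from colossalai.nn._ops.cache_embedding import LimitBuffIndexCopyer

copyer = LimitBuffIndexCopyer(size=4)                        # at most 4 rows per staging copy
src = torch.randn(100, 8).pin_memory()                       # pinned CPU source rows
tgt = torch.zeros(10, 8, device='cuda')                      # CUDA destination
src_index = torch.tensor([0, 7, 42, 43, 44, 99, 1, 2, 3, 4]) # rows to gather from src
tgt_index = torch.arange(10, device='cuda')                  # index for index_copy_ lives on tgt's device
copyer.index_copy(0, src_index=src_index, tgt_index=tgt_index, src=src, tgt=tgt)
assert torch.allclose(tgt.cpu(), src[src_index])             # rows land in the requested slots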


@@ -5,3 +5,4 @@ timm
titans
torchaudio
torchrec
contexttimer


@@ -7,3 +7,4 @@ pre-commit
rich
click
fabric
contexttimer


@@ -0,0 +1,76 @@
import pytest
from functools import partial
import torch
import torch.multiprocessing as mp
import numpy as np
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.nn._ops.cache_embedding import CachedParamMgr
NUM_EMBED, EMBED_DIM = 100, 8
BATCH_SIZE = 8
def test_cachemgr():
model = torch.nn.EmbeddingBag(10000, 128)
    # 10000 rows in total, 5 rows cached on CUDA
mgr = CachedParamMgr(model.weight, 5)
assert mgr.cuda_row_num == 5
mgr._admit(1)
assert not mgr._chunk_in_cuda(2)
assert mgr._chunk_in_cuda(1)
# print(mgr.cached_chunk_table)
mgr._admit(8)
    # now 3 CUDA slots are available
assert mgr.cuda_available_chunk_num == 3
mgr._evict()
assert mgr.cuda_available_chunk_num == 4
mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))
# print(mgr.cached_chunk_table)
# mgr.print_comm_stats()
mgr.flush()
assert mgr.cuda_available_chunk_num == 5
def test_reorder_with_freq():
num_embed = 100
chunk_size = 1
num_chunk = 5
idx_map = np.random.randint(10000, size=(num_embed,))
sorted_idx = np.flipud(np.argsort(idx_map)).tolist()
chunkid, offset_in_chunk = [], []
for i in range(num_embed):
idx = sorted_idx.index(i)
chunkid.append(idx // chunk_size)
offset_in_chunk.append(idx % chunk_size)
chunkid = torch.tensor(chunkid, dtype=torch.long, device=torch.cuda.current_device())
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=torch.cuda.current_device())
weight = torch.rand(num_embed, 2)
mgr = CachedParamMgr(weight, num_chunk)
mgr.reorder(idx_map)
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=torch.cuda.current_device()))
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
mgr_offsets = torch.remainder(indices, chunk_size)
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
assert torch.allclose(offset_in_chunk, mgr_offsets), \
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
if __name__ == '__main__':
    test_cachemgr()
    test_reorder_with_freq()
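
The tests above hand a synthetic frequency array to reorder(); in practice such a table would be counted from the training data. A hedged sketch of one way to build it, reusing the imports and constants of the test module above (the bincount-based counting and the toy dataset are assumptions, not part of this PR):

def build_freq_mapping_sketch():
    # pretend these are all embedding ids observed while scanning the training set
    all_ids = np.random.randint(0, NUM_EMBED, size=(100_000,))
    # frequency table indexed by id, as expected by CachedParamMgr.reorder()
    ids_freq_mapping = np.bincount(all_ids, minlength=NUM_EMBED)

    weight = torch.randn(NUM_EMBED, EMBED_DIM)
    mgr = CachedParamMgr(weight, cuda_row_num=20)     # requires a CUDA device
    mgr.reorder(ids_freq_mapping, warmup_ratio=0.7)
    # after reorder, frequent ids map to low row indices, which are the ones preloaded onto CUDA
    top_id = int(np.argmax(ids_freq_mapping))
    print('rank of the most frequent id after reorder:', mgr.idx_map[top_id].item())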