You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/nn/parallel/layers/cache_embedding/copyer.py

50 lines
2.0 KiB

import torch
from torch import LongTensor
class LimitBuffIndexCopyer:
    """Row-wise index copy staged through a size-limited temporary buffer.

    Rows are transferred in chunks of at most ``size`` rows, so the temporary
    buffer allocated on the target device (CUDA in the intended use case)
    never holds more than ``size`` rows at a time.

    Args:
        size (int): maximum number of rows held in the temporary buffer.
            Must be positive by the time :meth:`index_copy` is called.
    """

    def __init__(self, size: int) -> None:
        self._buff_size = size

    @torch.no_grad()
    def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor,
                   tgt: torch.Tensor) -> None:
        """Copy ``src[src_index]`` into ``tgt[tgt_index]`` chunk by chunk.

        Data flow per chunk:
        src[src_index] -(index_select)-> tmp buffer -(index_copy_)-> tgt[tgt_index]

        The result is the same as
        ``tgt.index_copy_(dim, tgt_index, src.index_select(dim, src_index))``,
        but only ``size`` rows are materialised at once. ``tgt`` is modified
        in place.

        Args:
            dim (int): dimension along which to index; only ``0`` is supported.
            src_index (LongTensor): 1-D indices of the rows to read from ``src``.
            tgt_index (LongTensor): 1-D indices of the rows to write in ``tgt``;
                must have as many elements as ``src_index``.
            src (torch.Tensor): 2-D tensor containing the values to copy.
            tgt (torch.Tensor): 2-D tensor that receives the values.

        Raises:
            AssertionError: if ``dim`` is not 0, if ``src`` or ``tgt`` is not
                2-D, or if the two index tensors differ in length.
            ValueError: if the configured buffer size is not positive.
        """
        assert dim == 0, "only support index_copy on dim 0"
        assert tgt.dim() == 2
        assert src.dim() == 2
        assert src_index.numel() == tgt_index.numel()

        buff_size = self._buff_size
        if buff_size <= 0:
            # A negative step would make the chunk loop empty and silently
            # copy nothing; a zero step makes range() fail with an opaque error.
            raise ValueError(f"buffer size must be a positive integer, got {buff_size}")

        tgt_device = tgt.device
        src_device = src.device
        dim_size = src_index.numel()

        # Each index tensor must live on the device of the tensor it indexes.
        src_index = src_index.to(src_device)
        tgt_index = tgt_index.to(tgt_device)

        # Loop-invariant: CPU -> CUDA copies are staged through pinned memory.
        stage_in_pinned_memory = src_device.type == 'cpu' and tgt_device.type == 'cuda'

        for begin_pos in range(0, dim_size, buff_size):
            cur_len = min(buff_size, dim_size - begin_pos)
            src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
            selected_rows = src.index_select(dim, src_idx_piece)

            if stage_in_pinned_memory:
                # Gather the chunk into pinned host memory, then copy it into
                # a freshly allocated buffer on the target device.
                cpu_tmp_buffer = selected_rows.pin_memory()
                tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
                tmp_buffer.copy_(cpu_tmp_buffer)
            else:
                tmp_buffer = selected_rows.to(tgt_device)

            tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
            tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)