mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
2.0 KiB
50 lines
2.0 KiB
import torch
|
|
from torch import LongTensor
|
|
|
|
|
|
class LimitBuffIndexCopyer(object):
|
|
"""LimitBuffIndexCopyer
|
|
Index Copy using limited temp buffer on CUDA.
|
|
|
|
Args:
|
|
size (int): buffer size
|
|
"""
|
|
|
|
def __init__(self, size: int) -> None:
|
|
self._buff_size = size
|
|
|
|
@torch.no_grad()
|
|
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
|
|
"""copy
|
|
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
|
|
The valid rows in the src tensor are continous, while rows in tgt tensor is scattered.
|
|
|
|
Args:
|
|
dim (int): dimension along which to index
|
|
src_index (int): indices of src tensor to select from
|
|
tgt_index (int): indices of tgt tensor to select from
|
|
src (torch.Tensor): the tensor containing values to copy
|
|
tgt (torch.Tensor): the tensor to be copied
|
|
"""
|
|
# tgt.index_copy_(dim, index, src)
|
|
assert dim == 0, "only support index_copy on dim 0"
|
|
assert tgt.dim() == 2
|
|
assert src.dim() == 2
|
|
tgt_device = tgt.device
|
|
src_device = src.device
|
|
|
|
assert src_index.numel() == tgt_index.numel()
|
|
dim_size = src_index.numel()
|
|
src_index = src_index.to(src_device)
|
|
for begin_pos in range(0, dim_size, self._buff_size):
|
|
cur_len = min(self._buff_size, dim_size - begin_pos)
|
|
src_idx_piece = src_index.narrow(0, begin_pos, cur_len)
|
|
if src_device.type == 'cpu' and tgt_device.type == 'cuda':
|
|
cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory()
|
|
tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device)
|
|
tmp_buffer.copy_(cpu_tmp_buffer)
|
|
else:
|
|
tmp_buffer = src.index_select(dim, src_idx_piece).to(tgt_device)
|
|
tgt_idx_piece = tgt_index.narrow(0, begin_pos, cur_len)
|
|
tgt.index_copy_(dim, tgt_idx_piece, tmp_buffer)
|