ColossalAI/colossalai/tensor/padded_tensor/api.py


import torch


def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor so that the padding metadata
    (_padding_dim, _origin_length, _current_length) is carried over to the result.

    Args:
        ptensor (torch.Tensor): The tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    ptensor._unpad_detach = ptensor.detach
    ptensor._unpad_clone = ptensor.clone

    def new_detach(self):
        t_ = self._unpad_detach()
        t_._padding_dim = self._padding_dim
        t_._origin_length = self._origin_length
        t_._current_length = self._current_length
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._unpad_clone(*args, **kwargs)
        t_._padding_dim = self._padding_dim
        t_._origin_length = self._origin_length
        t_._current_length = self._current_length
        return t_

    # bind the new methods to this tensor instance
    ptensor.detach = new_detach.__get__(ptensor)
    ptensor.clone = new_clone.__get__(ptensor)
    return ptensor
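
# Illustrative sketch: once the padding metadata is set and the hijack above is
# applied, detach/clone keep the metadata that plain torch.Tensor.detach/clone
# would drop, e.g.
#
#   t = torch.zeros(8, 4)
#   t._padding_dim, t._origin_length, t._current_length = 0, 5, 8
#   _hijack_detach_and_clone(t)
#   assert t.detach()._padding_dim == 0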


def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Restore the original detach and clone methods of a previously hijacked tensor.

    Args:
        ptensor (torch.Tensor): The tensor to be restored.

    Returns:
        torch.Tensor: The restored tensor.
    """
    ptensor.detach = ptensor._unpad_detach
    ptensor.clone = ptensor._unpad_clone
    delattr(ptensor, "_unpad_detach")
    delattr(ptensor, "_unpad_clone")
    return ptensor


def is_padded_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a padded tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a padded tensor.
    """
    return hasattr(tensor, "_padding_dim")


def to_padded_tensor(
    tensor: torch.Tensor,
    current_length: int,
    padding_dim: int,
) -> torch.Tensor:
    """
    Pad the tensor with zeros along padding_dim up to current_length (in place)
    and attach the padding metadata used by the other helpers in this module.

    Args:
        tensor (torch.Tensor): The tensor to be padded.
        current_length (int): The target length along padding_dim.
        padding_dim (int): The dimension along which to pad.

    Returns:
        torch.Tensor: The padded tensor.
    """
    assert (
        padding_dim < tensor.dim()
    ), f"Please pass a valid padding_dim. The dimension of the tensor is {tensor.dim()}"

    if is_padded_tensor(tensor):
        return tensor

    origin_length = tensor.shape[padding_dim]
    padding_num = current_length - origin_length
    padding_data = torch.zeros(
        *tensor.shape[:padding_dim],
        padding_num,
        *tensor.shape[padding_dim + 1 :],
        device=tensor.device,
        dtype=tensor.dtype,
    )
    tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

    tensor._padding_dim = padding_dim
    tensor._origin_length = origin_length
    tensor._current_length = current_length

    _hijack_detach_and_clone(tensor)
    return tensor
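
# Illustrative usage sketch: pad a (5, 4) tensor to length 8 along dim 0.
#
#   t = torch.randn(5, 4)
#   pt = to_padded_tensor(t, current_length=8, padding_dim=0)
#   assert pt.shape == (8, 4)
#   assert pt._origin_length == 5 and pt._current_length == 8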


def to_unpadded_tensor(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Strip the zero padding from a padded tensor (in place), restoring its original
    length along the padding dimension and removing the padding metadata.
    """
    if not is_padded_tensor(ptensor):
        return ptensor

    unpad_slices = [slice(None)] * ptensor.dim()
    unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
    ptensor.data = ptensor.data[tuple(unpad_slices)]

    delattr(ptensor, "_padding_dim")
    delattr(ptensor, "_origin_length")
    delattr(ptensor, "_current_length")

    _hijack_back_detach_and_clone(ptensor)
    return ptensor
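
# Illustrative usage sketch: unpadding restores the original length and removes
# the padding metadata.
#
#   pt = to_padded_tensor(torch.randn(5, 4), current_length=8, padding_dim=0)
#   ut = to_unpadded_tensor(pt)
#   assert ut.shape == (5, 4) and not is_padded_tensor(ut)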


def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int) -> torch.Tensor:
    """Attach padding metadata to an already-padded tensor without modifying its data."""
    if is_padded_tensor(tensor):
        return tensor

    tensor._padding_dim = padding_dim
    tensor._origin_length = origin_length
    tensor._current_length = current_length

    _hijack_detach_and_clone(tensor)
    return tensor
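

if __name__ == "__main__":
    # Illustrative self-check sketch exercising the helpers above on a small tensor;
    # the shapes and lengths here are arbitrary examples, not part of the library API.
    x = torch.randn(5, 4)
    px = to_padded_tensor(x, current_length=8, padding_dim=0)
    assert px.shape == (8, 4) and is_padded_tensor(px)
    assert px.clone()._origin_length == 5  # clone keeps the padding metadata

    # A tensor padded elsewhere (e.g. restored from a checkpoint) can be tagged
    # with the same metadata without re-padding.
    y = torch.zeros(8, 4)
    init_as_padded_tensor(y, current_length=8, origin_length=5, padding_dim=0)
    assert is_padded_tensor(y) and y._origin_length == 5

    assert to_unpadded_tensor(px).shape == (5, 4)
    assert not is_padded_tensor(px)
    print("padded_tensor api sketch: all checks passed")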