[hotfix] ColoTensor pin_memory (#840)

pull/842/head^2
Jiarui Fang 2022-04-22 17:07:46 +08:00 committed by GitHub
parent 9f6f656952
commit 4575a3298b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 9 deletions

View File

@ -20,21 +20,27 @@ class ColoTensor(object):
dtype=None, dtype=None,
requires_grad=False, requires_grad=False,
pin_memory=False, pin_memory=False,
device=None,
torch_tensor=torch.empty(0), torch_tensor=torch.empty(0),
): ):
self._size = size self._size = size
self._dtype = dtype self._dtype = dtype
self._requires_grad = requires_grad self._requires_grad = requires_grad
self._pin_memory = pin_memory self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor self._torch_tensor = torch_tensor
def numel(self):
return sum(self._size)
@staticmethod @staticmethod
def init_from_torch_tensor(tensor: torch.Tensor): def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(), colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype, dtype=tensor.dtype,
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
pin_memory=tensor.pin_memory, pin_memory=tensor.is_pinned(),
torch_tensor=tensor) device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
return colo_t return colo_t
def del_torch_tensor(self) -> None: def del_torch_tensor(self) -> None:
@ -42,12 +48,12 @@ class ColoTensor(object):
self._torch_tensor = torch.empty(self._size) self._torch_tensor = torch.empty(self._size)
def torch_tensor(self) -> torch.Tensor: def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor == None or self._torch_tensor.numel() == 0: if self._torch_tensor.numel() == 0:
print(self._size, type(self._size))
self._torch_tensor = torch.empty(*self._size, self._torch_tensor = torch.empty(*self._size,
dtype=self._dtype, dtype=self._dtype,
pin_memory=self._pin_memory,
requires_grad=self._requires_grad, requires_grad=self._requires_grad,
pin_memory=self._pin_memory) device=self._device)
return self._torch_tensor return self._torch_tensor
@classmethod @classmethod
@ -67,7 +73,5 @@ class ColoTensor(object):
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
kwargs = { kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
return func(*args, **kwargs) return func(*args, **kwargs)