diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index f40034dc1..eb6b8b128 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -20,21 +20,28 @@ class ColoTensor(object):
             dtype=None,
             requires_grad=False,
             pin_memory=False,
+            device=None,
             torch_tensor=torch.empty(0),
     ):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
+        self._device = device
         self._torch_tensor = torch_tensor
 
+    def numel(self):
+        """Return the number of elements implied by the recorded size (product of its dims)."""
+        return torch.Size(self._size).numel()
+
     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor):
+    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
-                            pin_memory=tensor.pin_memory,
-                            torch_tensor=tensor)
+                            pin_memory=tensor.is_pinned(),
+                            device=tensor.device,
+                            torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t
 
     def del_torch_tensor(self) -> None:
@@ -42,12 +49,12 @@ class ColoTensor(object):
         self._torch_tensor = torch.empty(self._size)
 
     def torch_tensor(self) -> torch.Tensor:
-        if self._torch_tensor == None or self._torch_tensor.numel() == 0:
-            print(self._size, type(self._size))
+        if self._torch_tensor.numel() == 0:
             self._torch_tensor = torch.empty(*self._size,
                                              dtype=self._dtype,
+                                             pin_memory=self._pin_memory,
                                              requires_grad=self._requires_grad,
-                                             pin_memory=self._pin_memory)
+                                             device=self._device)
         return self._torch_tensor
 
     @classmethod
@@ -67,7 +74,5 @@ class ColoTensor(object):
         if kwargs is None:
             kwargs = {}
 
-        kwargs = {
-            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
-        }
+        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)