diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 206388b2a..8d67d6f69 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -58,6 +58,10 @@ class ColoTensor(object):
     def shape(self):
         return torch.Size(self._size)
 
+    @property
+    def device(self):
+        return self._torch_tensor.device
+
     def size(self, dim=None):
         if dim is None:
             return self.shape
@@ -105,14 +109,14 @@ class ColoTensor(object):
                                             device=self._device)
         return self._torch_tensor
 
-    def set_spec(self, spec: str, lazy_shard: bool=False) -> None:
+    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
         self._shard_spec = spec
         if lazy_shard == False:
             self._shard()
 
     def _shard(self):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
-        if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
+        if self._shard_spec == "1Drow":    # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
             num_partition = gpc.get_world_size(ParallelMode.TENSOR)
             local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
             dim = -1
@@ -121,11 +125,11 @@ class ColoTensor(object):
             # Reshape to get shard for this rank and we don't want autograd
             # recording here for the narrow op and 'local_shard' should be a
             # leaf variable in the autograd graph.
-            self._torch_tensor = self._torch_tensor.narrow(dim,
-                local_rank * chunk_size, chunk_size).detach().contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
+            self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
+            ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
             self._torch_tensor.requires_grad = self._requires_grad
             self._size = self._torch_tensor.size()
-            self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
+            self._device = device    # TODO A `fake` device now because torch_tensor.device always = cpu
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index d6cb197eb..8c911b801 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -1,3 +1,4 @@
+from colossalai.utils.cuda import get_current_device
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
 # from colossalai.logging import get_dist_logger
@@ -8,9 +9,15 @@ from colossalai.tensor import ColoTensor
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self, lazy_memory_allocate=False):
+    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
+        """
+        Args:
+            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
+            device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
+        """
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
+        self._device = device
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
@@ -26,4 +33,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
+            setattr(module, name,
+                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py
index 3c998fa66..84c0fff39 100644
--- a/tests/test_tensor/test_context.py
+++ b/tests/test_tensor/test_context.py
@@ -5,17 +5,16 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils.cuda import get_current_device
 
 
-def test_linear():
+
+def test_lazy_init():
     in_dim = 4
     out_dim = 5
 
     with ColoInitContext(lazy_memory_allocate=True) as ctx:
         fc = torch.nn.Linear(in_dim, out_dim, bias=True)
 
-    print(fc.weight.numel())
-    print(fc.bias.numel())
-
     # lazy_memory_allocate=True, no payload is maintained
     assert fc.weight._torch_tensor.numel() == 0
 
@@ -23,5 +22,18 @@
     assert fc.weight._torch_tensor.numel() == in_dim * out_dim
 
 
+def test_device():
+    in_dim = 4
+    out_dim = 5
+
+    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
+        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+
+    # eval an lazy parameter
+    fc.weight.torch_tensor()
+    assert fc.weight.device == get_current_device()
+
+
 if __name__ == '__main__':
-    test_linear()
+    test_lazy_init()
+    test_device()