Browse Source

colo init context add device attr. (#866)

pull/868/head
Jiarui Fang 3 years ago committed by GitHub
parent
commit
d01d3b8cb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      colossalai/tensor/colo_tensor.py
  2. 12
      colossalai/utils/model/colo_init_context.py
  3. 22
      tests/test_tensor/test_context.py

14
colossalai/tensor/colo_tensor.py

@ -58,6 +58,10 @@ class ColoTensor(object):
def shape(self):
return torch.Size(self._size)
@property
def device(self):
return self._torch_tensor.device
def size(self, dim=None):
if dim is None:
return self.shape
@ -105,14 +109,14 @@ class ColoTensor(object):
device=self._device)
return self._torch_tensor
def set_spec(self, spec: str, lazy_shard: bool=False) -> None:
def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
self._shard_spec = spec
if lazy_shard == False:
self._shard()
def _shard(self):
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
dim = -1
@ -121,11 +125,11 @@ class ColoTensor(object):
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self._torch_tensor = self._torch_tensor.narrow(dim,
local_rank * chunk_size, chunk_size).detach().contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()
self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):

12
colossalai/utils/model/colo_init_context.py

@ -1,3 +1,4 @@
from colossalai.utils.cuda import get_current_device
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
# from colossalai.logging import get_dist_logger
@ -8,9 +9,15 @@ from colossalai.tensor import ColoTensor
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self, lazy_memory_allocate=False):
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
"""
Args:
lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
"""
super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate
self._device = device
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
@ -26,4 +33,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
save_torch_payload = True if not self._lazy_memory_allocate else False
for name, param in name_list:
delattr(module, name)
setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
setattr(module, name,
ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))

22
tests/test_tensor/test_context.py

@ -5,17 +5,16 @@ import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils.cuda import get_current_device
def test_linear():
def test_lazy_init():
in_dim = 4
out_dim = 5
with ColoInitContext(lazy_memory_allocate=True) as ctx:
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
print(fc.weight.numel())
print(fc.bias.numel())
# lazy_memory_allocate=True, no payload is maintained
assert fc.weight._torch_tensor.numel() == 0
@ -23,5 +22,18 @@ def test_linear():
assert fc.weight._torch_tensor.numel() == in_dim * out_dim
def test_device():
in_dim = 4
out_dim = 5
with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
# eval an lazy parameter
fc.weight.torch_tensor()
assert fc.weight.device == get_current_device()
if __name__ == '__main__':
test_linear()
test_lazy_init()
test_device()

Loading…
Cancel
Save