From c6930d8ddfc5e288375da8d0bb3f6a97d7f3f708 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Sun, 24 Apr 2022 18:31:22 +0800
Subject: [PATCH] [pipelinable]use ColoTensor to replace dummy tensor. (#853)

---
 colossalai/tensor/colo_tensor.py      | 16 ++++++++++++++++
 colossalai/utils/model/pipelinable.py | 12 ++++++++++--
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 9a99d4c7d..908de7afb 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -53,6 +53,22 @@ class ColoTensor(object):
     def size(self):
         return self._size
 
+    @property
+    def shape(self):
+        return torch.Size(self._size)
+
+    def size(self, dim=None):
+        if dim is None:
+            return self.shape
+        return self._size[dim]
+
+    def dim(self):
+        return len(self._size)
+
+    def normal_(self, mean=0., std=1.):
+        torch_tensor = self.torch_tensor()
+        return torch_tensor.normal_(mean=mean, std=std)
+
     def numel(self):
         return product(self._size)
 
diff --git a/colossalai/utils/model/pipelinable.py b/colossalai/utils/model/pipelinable.py
index ba5bbddb3..379430366 100644
--- a/colossalai/utils/model/pipelinable.py
+++ b/colossalai/utils/model/pipelinable.py
@@ -3,6 +3,7 @@ import functools
 from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
 from colossalai.builder.pipeline import partition_uniform, partition_balanced
 from colossalai.core import global_context as gpc
+from colossalai.tensor import ColoTensor
 
 
 class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
@@ -64,8 +65,15 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs)
         layer_spec.set_children(module.children())
         self._layer_spec_dict[module_id] = layer_spec
-        for param in module.parameters(recurse=False):
-            param.data = torch.rand(1, 1)
+        name_list = []
+        for name, param in module.named_parameters():
+            if isinstance(param, ColoTensor):
+                continue
+            name_list.append((name, param))
+
+        for name, param in name_list:
+            delattr(module, name)
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=False))
 
     def to_layer_list(self, exec_seq=None):
         """
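
To illustrate the technique this patch applies, here is a minimal self-contained
sketch of the same idea outside of Colossal-AI: swap a module's real parameters
for shape-only placeholders, so that building layer specs for pipeline
partitioning does not pin full weight storage. MetaParam and strip_payloads are
hypothetical names invented for this sketch; MetaParam stands in for
ColoTensor.init_from_torch_tensor(..., save_payload=False) and mimics only the
metadata methods the first hunk adds (shape, size(dim), dim()).

    import torch
    import torch.nn as nn


    # Hypothetical stand-in for a payload-free ColoTensor: keeps a
    # parameter's metadata (its size) but drops the weight storage.
    class MetaParam:
        def __init__(self, tensor):
            self._size = tuple(tensor.size())

        @property
        def shape(self):
            return torch.Size(self._size)

        def size(self, dim=None):
            # Mirror torch.Tensor.size(): whole shape, or one dimension.
            if dim is None:
                return self.shape
            return self._size[dim]

        def dim(self):
            return len(self._size)


    def strip_payloads(module):
        # Collect first, then mutate: deleting attributes while iterating
        # over named_parameters() would invalidate the iterator, which is
        # why the patch builds name_list before calling delattr/setattr.
        collected = [(name, p) for name, p in module.named_parameters(recurse=False)
                     if not isinstance(p, MetaParam)]
        for name, p in collected:
            delattr(module, name)                 # unregister the real nn.Parameter
            setattr(module, name, MetaParam(p))   # keep only its metadata


    layer = nn.Linear(4, 8)
    strip_payloads(layer)
    print(layer.weight.shape)        # torch.Size([8, 4])
    print(layer.weight.size(0))      # 8
    print(layer.weight.dim())        # 2
    print(list(layer.parameters()))  # [] -- real storage has been released

The two-pass collect-then-replace shape here mirrors the patch's name_list loop;
the isinstance guard likewise mirrors the patch's check that skips parameters
already converted to ColoTensor on a repeated pass.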