From 4341f5e8e65b532d49a6a4dfc64367a417091fb3 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 17 Apr 2023 11:25:13 +0800 Subject: [PATCH] [lazyinit] fix clone and deepcopy (#3553) --- colossalai/utils/model/experimental.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index c91751f1c..bf3e3d05b 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -14,8 +14,8 @@ from colossalai.tensor.d_tensor.layout import Layout # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ "arange", - "empty", "full", + "empty", "linspace", "logspace", "ones", @@ -324,7 +324,9 @@ class LazyTensor(torch.Tensor): def clone(self) -> "LazyTensor": def factory_fn(): - return self.materialize().clone() + # if self is materialized, return self + new_tensor = self.materialize() if type(self) is LazyTensor else self + return new_tensor.clone() target = LazyTensor(factory_fn, meta_data=self._meta_data) @@ -333,6 +335,26 @@ class LazyTensor(torch.Tensor): def detach(self) -> Tensor: return self + def __deepcopy__(self, memo): + if not self.is_leaf: + raise RuntimeError("Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment") + if id(self) in memo: + return memo[id(self)] + + def factory_fn(): + # if self is materialized, return self + new_tensor = self.materialize() if type(self) is LazyTensor else self + copied = new_tensor.detach().clone() + if new_tensor.requires_grad: + copied.requires_grad_() + return copied + + target = LazyTensor(factory_fn, meta_data=self._meta_data) + + memo[id(self)] = target + return target + @property def data(self): return self