mirror of https://github.com/hpcaitech/ColossalAI
[lazyinit] fix clone and deepcopy (#3553)
parent
1c7734bc94
commit
4341f5e8e6
|
@ -14,8 +14,8 @@ from colossalai.tensor.d_tensor.layout import Layout
|
|||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_NORMAL_FACTORY = [
|
||||
"arange",
|
||||
"empty",
|
||||
"full",
|
||||
"empty",
|
||||
"linspace",
|
||||
"logspace",
|
||||
"ones",
|
||||
|
@ -324,7 +324,9 @@ class LazyTensor(torch.Tensor):
|
|||
def clone(self) -> "LazyTensor":
|
||||
|
||||
def factory_fn():
|
||||
return self.materialize().clone()
|
||||
# if self is materialized, return self
|
||||
new_tensor = self.materialize() if type(self) is LazyTensor else self
|
||||
return new_tensor.clone()
|
||||
|
||||
target = LazyTensor(factory_fn, meta_data=self._meta_data)
|
||||
|
||||
|
@ -333,6 +335,26 @@ class LazyTensor(torch.Tensor):
|
|||
def detach(self) -> Tensor:
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
if not self.is_leaf:
|
||||
raise RuntimeError("Only Tensors created explicitly by the user "
|
||||
"(graph leaves) support the deepcopy protocol at the moment")
|
||||
if id(self) in memo:
|
||||
return memo[id(self)]
|
||||
|
||||
def factory_fn():
|
||||
# if self is materialized, return self
|
||||
new_tensor = self.materialize() if type(self) is LazyTensor else self
|
||||
copied = new_tensor.detach().clone()
|
||||
if new_tensor.requires_grad:
|
||||
copied.requires_grad_()
|
||||
return copied
|
||||
|
||||
target = LazyTensor(factory_fn, meta_data=self._meta_data)
|
||||
|
||||
memo[id(self)] = target
|
||||
return target
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue