|
|
@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
|
|
|
|
|
|
|
|
|
|
|
|
##### override builtin functions which must use tensor in replicate placement ####
|
|
|
|
##### override builtin functions which must use tensor in replicate placement ####
|
|
|
|
|
|
|
|
|
|
|
|
def view_base(self, *args) -> 'ColoTensor':
|
|
|
|
def view_local(self, *args) -> 'ColoTensor':
|
|
|
|
return super().view(*args)
|
|
|
|
return super().view(*args)
|
|
|
|
|
|
|
|
|
|
|
|
def size_base(self, *args, **kwargs) -> torch.Size:
|
|
|
|
def size_local(self, *args, **kwargs) -> torch.Size:
|
|
|
|
return super().size(*args, **kwargs)
|
|
|
|
return super().size(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def view(self, *args) -> 'ColoTensor':
|
|
|
|
def view_global(self, *args) -> 'ColoTensor':
|
|
|
|
"""override the torch buildin view()
|
|
|
|
"""override the torch buildin view()
|
|
|
|
the args passed in must be in a replicate placement.
|
|
|
|
the args passed in must be in a replicate placement.
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
|
|
|
|
self._tensor_spec.dist_spec = distspec.replicate()
|
|
|
|
self._tensor_spec.dist_spec = distspec.replicate()
|
|
|
|
return super().view(*args)
|
|
|
|
return super().view(*args)
|
|
|
|
|
|
|
|
|
|
|
|
def size(self, args: Optional[int] = None):
|
|
|
|
def size_global(self, args: Optional[int] = None):
|
|
|
|
"""override the torch buildin size()
|
|
|
|
"""override the torch buildin size()
|
|
|
|
the shape passed in must be in a replicate placement.
|
|
|
|
the shape passed in must be in a replicate placement.
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|