[tensor] revert local view back (#1178)

pull/1181/head
Jiarui Fang 2 years ago committed by GitHub
parent 0dd4e2bbfb
commit 1b657f9ce1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -52,7 +52,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group())) input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
num_embeddings_per_partition = weight.size_base(0) num_embeddings_per_partition = weight.size_local(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition vocab_end_index = vocab_start_index + num_embeddings_per_partition

@ -101,13 +101,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
# TODO(jzy) we don't support object reflection now. # TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`. # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError raise NotImplementedError
#### the ColoParameter should use the torch.Tensor's builtin methodes ###
def view(self, *args) -> 'ColoTensor':
return super().view_base(*args)
def size(self, *args, **kwargs) -> torch.Size:
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
return super().size_base(*args, **kwargs)

@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
##### override builtin functions which must use tensor in replicate placement #### ##### override builtin functions which must use tensor in replicate placement ####
def view_base(self, *args) -> 'ColoTensor': def view_local(self, *args) -> 'ColoTensor':
return super().view(*args) return super().view(*args)
def size_base(self, *args, **kwargs) -> torch.Size: def size_local(self, *args, **kwargs) -> torch.Size:
return super().size(*args, **kwargs) return super().size(*args, **kwargs)
def view(self, *args) -> 'ColoTensor': def view_global(self, *args) -> 'ColoTensor':
"""override the torch buildin view() """override the torch buildin view()
the args passed in must be in a replicate placement. the args passed in must be in a replicate placement.
Returns: Returns:
@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
self._tensor_spec.dist_spec = distspec.replicate() self._tensor_spec.dist_spec = distspec.replicate()
return super().view(*args) return super().view(*args)
def size(self, args: Optional[int] = None): def size_global(self, args: Optional[int] = None):
"""override the torch buildin size() """override the torch buildin size()
the shape passed in must be in a replicate placement. the shape passed in must be in a replicate placement.
Returns: Returns:

@ -67,14 +67,14 @@ def _run_view(world_size):
TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
num_partitions=[world_size]))) num_partitions=[world_size])))
assert t.size()[0] == 4 * world_size assert t.size_global()[0] == 4 * world_size
assert t.size(1) == 5 assert t.size_global(1) == 5
assert t.size() == torch.Size([4 * world_size, 5]) assert t.size_global() == torch.Size([4 * world_size, 5])
t.view_base(4 * 5) t.view_local(4 * 5)
assert t.tensor_spec.dist_spec.placement.value == 's' assert t.tensor_spec.dist_spec.placement.value == 's'
t = t.view(4 * 5 * world_size) t = t.view_global(4 * 5 * world_size)
assert t.tensor_spec.dist_spec.placement.value == 'r' assert t.tensor_spec.dist_spec.placement.value == 'r'
assert t.shape == torch.Size([4 * 5 * world_size]) assert t.shape == torch.Size([4 * 5 * world_size])

Loading…
Cancel
Save