diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 1e392c04d..3f7787694 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -52,7 +52,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    num_embeddings_per_partition = weight.size_base(0)
+    num_embeddings_per_partition = weight.size_local(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
 
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 960c26474..0414c7d07 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -101,13 +101,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
-
-    #### the ColoParameter should use the torch.Tensor's builtin methodes ###
-
-    def view(self, *args) -> 'ColoTensor':
-        return super().view_base(*args)
-
-    def size(self, *args, **kwargs) -> torch.Size:
-        # import inspect
-        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
-        return super().size_base(*args, **kwargs)
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index bd7859dee..f36c18313 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
 
     ##### override builtin functions which must use tensor in replicate placement ####
 
-    def view_base(self, *args) -> 'ColoTensor':
+    def view_local(self, *args) -> 'ColoTensor':
         return super().view(*args)
 
-    def size_base(self, *args, **kwargs) -> torch.Size:
+    def size_local(self, *args, **kwargs) -> torch.Size:
         return super().size(*args, **kwargs)
 
-    def view(self, *args) -> 'ColoTensor':
+    def view_global(self, *args) -> 'ColoTensor':
         """override the torch buildin view()
         the args passed in must be in a replicate placement.
         Returns:
@@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
         self._tensor_spec.dist_spec = distspec.replicate()
         return super().view(*args)
 
-    def size(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index 836d7c16e..b04f14f57 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -67,14 +67,14 @@ def _run_view(world_size):
         TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                                   dims=[0],
                                   num_partitions=[world_size])))
 
-    assert t.size()[0] == 4 * world_size
-    assert t.size(1) == 5
-    assert t.size() == torch.Size([4 * world_size, 5])
+    assert t.size_global()[0] == 4 * world_size
+    assert t.size_global(1) == 5
+    assert t.size_global() == torch.Size([4 * world_size, 5])
 
-    t.view_base(4 * 5)
+    t.view_local(4 * 5)
     assert t.tensor_spec.dist_spec.placement.value == 's'
-    t = t.view(4 * 5 * world_size)
+    t = t.view_global(4 * 5 * world_size)
     assert t.tensor_spec.dist_spec.placement.value == 'r'
     assert t.shape == torch.Size([4 * 5 * world_size])
 
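
For reference, a minimal sketch of how the split local/global API reads after this rename, modeled on the updated _run_view test above. Only the size_local/size_global/view_local/view_global names come from this diff; the import paths, the ColoTensor.from_torch_tensor constructor, and the launch setup are assumptions about the surrounding codebase.

import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ColoTensor, TensorSpec, distspec


def sharded_view_sketch(world_size: int) -> ColoTensor:
    # Assumes colossalai.launch(...) has already initialized the DATA process
    # group on world_size ranks. Each rank holds a (4, 5) local shard of a
    # logical (4 * world_size, 5) tensor, sharded along dim 0.
    t = ColoTensor.from_torch_tensor(
        torch.randn(4, 5),
        TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                                  dims=[0],
                                  num_partitions=[world_size])))

    assert t.size_local(0) == 4                   # size of this rank's shard
    assert t.size_global()[0] == 4 * world_size   # size of the logical tensor
    assert t.size_global(1) == 5                  # per-dim query sees the global shape

    t.view_local(4 * 5)                      # reshapes the local shard; spec stays sharded ('s')
    t = t.view_global(4 * 5 * world_size)    # converts to replicate ('r') before reshaping
    return t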