From 4a76084dc97fcba4c576970656c07493e38ddb1b Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 8 Jul 2022 14:55:27 +0800
Subject: [PATCH] [tensor] add zero_like colo op, important for Optimizer
 (#1236)

---
 colossalai/nn/_ops/element_wise.py  |  1 +
 colossalai/tensor/colo_parameter.py |  2 +-
 colossalai/tensor/colo_tensor.py    |  3 +++
 tests/test_tensor/test_tensor.py    | 16 +++++++++++-----
 4 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py
index b7b6b7c9c..c3c1421e7 100644
--- a/colossalai/nn/_ops/element_wise.py
+++ b/colossalai/nn/_ops/element_wise.py
@@ -195,6 +195,7 @@ register_elementwise_op(torch.tan)
 register_elementwise_op(torch.tanh)
 register_elementwise_op(torch.atanh)
 register_elementwise_op(torch.arctanh)
+register_elementwise_op(torch.zeros_like)
 
 # nn.functional OP
 register_elementwise_op(F.threshold)
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 65eb77d4b..8963d2194 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -55,7 +55,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor
 
     def __repr__(self):
-        return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+        return f'ColoParameter: {ColoTensor.__repr__(self)}'
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 699f56e53..17c30ad34 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -271,3 +271,6 @@ class ColoTensor(torch.Tensor):
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+
+    def is_sharded(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD
\ No newline at end of file
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index c77ba9d59..3c763562f 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -42,7 +42,7 @@ def _run_wrapped_tensor_func():
     assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
 
 
-def _run_operand():
+def _run_operand(world_size):
     pg = ProcessGroup()
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
@@ -53,6 +53,13 @@ def _run_operand():
     assert isinstance(t_res, ColoTensor)
     assert torch.allclose(t_ref_res, t_res)
 
+    pg = ProcessGroup(tp_degree=world_size)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
+    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t_new = torch.zeros_like(t)
+    assert isinstance(t_new, ColoTensor)
+    assert t_new.is_sharded()
+
 
 #### Test Distributed init a Colotensor
 
@@ -105,9 +112,8 @@ def run_dist_tests(rank, world_size, port):
     _run_view(world_size)
     _run_process_group(world_size)
     _run_tensor_indexing()
-    _run_operand()
-    # TODO not passed
-    # _run_wrapped_tensor_func()
+    _run_operand(world_size)
+    _run_wrapped_tensor_func()
 
 
 @pytest.mark.dist
@@ -119,4 +125,4 @@ def test_dist_cases(world_size):
 
 
 if __name__ == '__main__':
-    test_dist_cases(2)
+    test_dist_cases(1)
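
For context, a minimal, self-contained sketch of the mechanism this patch exercises: once torch.zeros_like is routed through the tensor's __torch_function__ dispatch (which is what registering it as an element-wise colo op enables), the result can keep the input's distribution metadata instead of degrading to a plain replicated tensor, so an optimizer that allocates state buffers with zeros_like gets sharded buffers from sharded parameters. The MetaTensor class and its string-valued dist_spec below are toy stand-ins in plain PyTorch, not the real ColoTensor or register_elementwise_op implementation.

import torch


class MetaTensor(torch.Tensor):
    """Toy tensor subclass carrying a 'dist_spec' string as stand-in metadata."""

    @staticmethod
    def __new__(cls, data, dist_spec='REPLICATE'):
        t = torch.Tensor._make_subclass(cls, data)
        t.dist_spec = dist_spec
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Default dispatch runs the op and re-wraps the result as MetaTensor.
        out = super().__torch_function__(func, types, args, kwargs)
        # An element-wise-style op such as zeros_like inherits the input's
        # metadata, mirroring what registering torch.zeros_like as a colo op
        # achieves for ColoTensor.
        if func is torch.zeros_like and isinstance(out, MetaTensor):
            out.dist_spec = args[0].dist_spec
        return out


param = MetaTensor(torch.randn(4, 5), dist_spec='SHARD')
state = torch.zeros_like(param)   # e.g. an optimizer allocating a momentum buffer
assert isinstance(state, MetaTensor)
assert state.dist_spec == 'SHARD'

The patch asserts the same behavior through ColoTensor itself: _run_operand builds a row-sharded tensor, calls torch.zeros_like on it, and checks the result with the newly added is_sharded().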