[tensor] add zero_like colo op, important for Optimizer (#1236)

pull/1237/head
Jiarui Fang 2022-07-08 14:55:27 +08:00 committed by GitHub
parent 3b500984b1
commit 4a76084dc9
4 changed files with 16 additions and 6 deletions
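
The motivation in the title is that optimizers allocate their state buffers from the parameters they optimize, typically with torch.zeros_like(param). Below is a minimal sketch of that pattern; init_adam_state is a hypothetical helper written for illustration, not code from this commit. Once zeros_like is registered as a colo elementwise op, buffers created this way stay ColoTensors and inherit the parameter's distributed placement instead of degrading to plain dense tensors.

import torch

def init_adam_state(param: torch.Tensor) -> dict:
    # With torch.zeros_like registered as an elementwise colo op, a sharded
    # ColoTensor parameter yields state buffers carrying the same dist spec;
    # a plain tensor still yields plain zero tensors, so the helper is generic.
    return {
        'exp_avg': torch.zeros_like(param),
        'exp_avg_sq': torch.zeros_like(param),
    }

# plain-tensor example (a sharded ColoTensor parameter would work the same way
# inside a distributed launch)
state = init_adam_state(torch.randn(4, 5))
assert state['exp_avg'].shape == (4, 5)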

View File

@@ -195,6 +195,7 @@ register_elementwise_op(torch.tan)
 register_elementwise_op(torch.tanh)
 register_elementwise_op(torch.atanh)
 register_elementwise_op(torch.arctanh)
+register_elementwise_op(torch.zeros_like)
 # nn.functional OP
 register_elementwise_op(F.threshold)
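
The hunk above only adds torch.zeros_like to the elementwise-op registry. Below is a standalone sketch of the mechanism this relies on; MetaTensor and this registry are simplified stand-ins written for illustration, not ColossalAI's actual ColoTensor or register_elementwise_op. The idea: a torch.Tensor subclass intercepts registered ops in __torch_function__ and re-wraps the result, so per-tensor metadata (for ColoTensor, the process group and dist spec) survives calls like torch.zeros_like.

import torch

_ELEMENTWISE_OPS = set()

def register_elementwise_op(op):
    # stand-in registry: the real one installs a colo-aware wrapper for `op`
    _ELEMENTWISE_OPS.add(op)

register_elementwise_op(torch.zeros_like)

class MetaTensor(torch.Tensor):
    # stand-in for ColoTensor: carries a `spec` attribute across elementwise ops
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = super().__torch_function__(func, types, args, kwargs)
        if func in _ELEMENTWISE_OPS and isinstance(out, torch.Tensor):
            out = out.as_subclass(MetaTensor)
            out.spec = getattr(args[0], 'spec', None)    # propagate metadata
        return out

t = torch.randn(4, 5).as_subclass(MetaTensor)
t.spec = 'shard_dim0'
z = torch.zeros_like(t)
assert isinstance(z, MetaTensor) and z.spec == 'shard_dim0'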

View File

@@ -55,7 +55,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor
 
     def __repr__(self):
-        return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+        return f'ColoParameter: {ColoTensor.__repr__(self)}'
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):

View File

@@ -271,3 +271,6 @@ class ColoTensor(torch.Tensor):
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+
+    def is_sharded(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD
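
A hypothetical consumer of the new is_sharded() predicate, for illustration only: full_tensor_for_checkpoint is not part of this commit, and to_replicate() is assumed to be available on ColoTensor in this version. Callers can now ask whether a tensor holds only a local shard before deciding to gather it.

from colossalai.tensor import ColoTensor
import torch

def full_tensor_for_checkpoint(t: torch.Tensor) -> torch.Tensor:
    # Sharded ColoTensors hold only a per-rank slice, so gather them first;
    # replicated ColoTensors and plain tensors can be saved as-is.
    if isinstance(t, ColoTensor) and t.is_sharded():
        return t.to_replicate()    # assumed API for gathering to a full tensor
    return t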

View File

@@ -42,7 +42,7 @@ def _run_wrapped_tensor_func():
     assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
 
 
-def _run_operand():
+def _run_operand(world_size):
     pg = ProcessGroup()
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
@@ -53,6 +53,13 @@ def _run_operand():
     assert isinstance(t_res, ColoTensor)
     assert torch.allclose(t_ref_res, t_res)
+
+    pg = ProcessGroup(tp_degree=world_size)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
+    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t_new = torch.zeros_like(t)
+    assert isinstance(t_new, ColoTensor)
+    assert t_new.is_sharded()
 
 
 #### Test Distributed init a Colotensor
@@ -105,9 +112,8 @@ def run_dist_tests(rank, world_size, port):
     _run_view(world_size)
     _run_process_group(world_size)
     _run_tensor_indexing()
-    _run_operand()
-    # TODO not passed
-    # _run_wrapped_tensor_func()
+    _run_operand(world_size)
+    _run_wrapped_tensor_func()
 
 
 @pytest.mark.dist
@@ -119,4 +125,4 @@
 
 
 if __name__ == '__main__':
-    test_dist_cases(2)
+    test_dist_cases(1)