mirror of https://github.com/hpcaitech/ColossalAI
[tensor] add zeros_like colo op, important for Optimizer (#1236)
parent
3b500984b1
commit
4a76084dc9
|
@ -195,6 +195,7 @@ register_elementwise_op(torch.tan)
|
||||||
register_elementwise_op(torch.tanh)
|
register_elementwise_op(torch.tanh)
|
||||||
register_elementwise_op(torch.atanh)
|
register_elementwise_op(torch.atanh)
|
||||||
register_elementwise_op(torch.arctanh)
|
register_elementwise_op(torch.arctanh)
|
||||||
|
register_elementwise_op(torch.zeros_like)
|
||||||
|
|
||||||
# nn.functional OP
|
# nn.functional OP
|
||||||
register_elementwise_op(F.threshold)
|
register_elementwise_op(F.threshold)
|
||||||
|
|
|
@ -55,7 +55,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
|
return f'ColoParameter: {ColoTensor.__repr__(self)}'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_function__(cls, func, types, args=..., kwargs=None):
|
def __torch_function__(cls, func, types, args=..., kwargs=None):
|
||||||
|
|
|
@ -271,3 +271,6 @@ class ColoTensor(torch.Tensor):
|
||||||
def is_shard_1drow(self):
|
def is_shard_1drow(self):
|
||||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||||
|
|
||||||
|
def is_sharded(self):
|
||||||
|
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
|
@ -42,7 +42,7 @@ def _run_wrapped_tensor_func():
|
||||||
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
|
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"
|
||||||
|
|
||||||
|
|
||||||
def _run_operand():
|
def _run_operand(world_size):
|
||||||
pg = ProcessGroup()
|
pg = ProcessGroup()
|
||||||
t_ref = torch.randn(4, 5)
|
t_ref = torch.randn(4, 5)
|
||||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||||
|
@ -53,6 +53,13 @@ def _run_operand():
|
||||||
assert isinstance(t_res, ColoTensor)
|
assert isinstance(t_res, ColoTensor)
|
||||||
assert torch.allclose(t_ref_res, t_res)
|
assert torch.allclose(t_ref_res, t_res)
|
||||||
|
|
||||||
|
pg = ProcessGroup(tp_degree=world_size)
|
||||||
|
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
|
||||||
|
t.set_dist_spec(distspec.shard([0], [world_size]))
|
||||||
|
t_new = torch.zeros_like(t)
|
||||||
|
assert isinstance(t_new, ColoTensor)
|
||||||
|
assert t_new.is_sharded()
|
||||||
|
|
||||||
|
|
||||||
#### Test Distributed init a Colotensor
|
#### Test Distributed init a Colotensor
|
||||||
|
|
||||||
|
@ -105,9 +112,8 @@ def run_dist_tests(rank, world_size, port):
|
||||||
_run_view(world_size)
|
_run_view(world_size)
|
||||||
_run_process_group(world_size)
|
_run_process_group(world_size)
|
||||||
_run_tensor_indexing()
|
_run_tensor_indexing()
|
||||||
_run_operand()
|
_run_operand(world_size)
|
||||||
# TODO not passed
|
_run_wrapped_tensor_func()
|
||||||
# _run_wrapped_tensor_func()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@ -119,4 +125,4 @@ def test_dist_cases(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_dist_cases(2)
|
test_dist_cases(1)
|
||||||
|
|
Loading…
Reference in New Issue