mirror of https://github.com/hpcaitech/ColossalAI
[fx] patched torch.max and data movement operator (#1391)
* [fx] patched torch.max and data movement operator
* polish code
parent db89600cf2
commit 7d6293927f
@@ -138,3 +138,36 @@ def torch_roll(input, shifts, dims=None):
 def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
     assert out is None, 'assigning result to out is not supported yet'
     return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
+
+
+@meta_patched_function.register(torch.max)
+def torch_max(input, dim=None, keepdim=False, *, out=None):
+    assert out is None, 'assigning value to out is not supported yet'
+    if dim is not None:
+        if isinstance(dim, int):
+            shape = list(input.shape)
+            shape.pop(dim)
+            if keepdim:
+                shape.insert(dim, 1)
+            return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
+                                                                                     device='meta',
+                                                                                     dtype=input.dtype)
+        elif isinstance(dim, torch.Tensor):
+            # when dim is a 0D or 1D tensor, the result maintains the same shape
+            num_dims = dim.dim()
+            if num_dims in [0, 1]:
+                return torch.empty_like(input, device='meta')
+            else:
+                raise ValueError(f"Expected dim to be a 0D or 1D tensor but got {num_dims} dimensions")
+    else:
+        return torch.empty([], device='meta', dtype=input.dtype)
+
+
+@meta_patched_function.register(torch.Tensor.cpu)
+def torch_tensor_cpu(input):
+    return input.clone()
+
+
+@meta_patched_function.register(torch.Tensor.cuda)
+def torch_tensor_cuda(input, *args, **kwargs):
+    return input.clone()
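For context on the decorator used throughout this hunk: `meta_patched_function` is a registry that maps a `torch` callable to a shape-only replacement, which the FX tracer substitutes when tensors live on the meta device. Below is a minimal sketch of how such a registry could look; it is an illustration only, and the real class in `colossalai.fx` may differ in its details.

# Minimal sketch of a patch registry like `meta_patched_function`.
# Illustration only; the actual implementation lives in colossalai.fx.
import torch


class PatchRegistry:

    def __init__(self):
        # maps a torch callable (e.g. torch.max) to its meta replacement
        self.store = {}

    def register(self, target):
        # used as a decorator factory: @registry.register(torch.max)
        def wrapper(fn):
            self.store[target] = fn
            return fn

        return wrapper


meta_patched_function = PatchRegistry()


@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
    # shape-only stand-in (trimmed copy of the patch added above,
    # handling only the int-dim case)
    if dim is None:
        return torch.empty([], device='meta', dtype=input.dtype)
    shape = list(input.shape)
    shape.pop(dim)
    if keepdim:
        shape.insert(dim, 1)
    return (torch.empty(shape, device='meta', dtype=input.dtype),
            torch.empty(shape, device='meta', dtype=input.dtype))


# shapes propagate without touching real memory:
x = torch.empty(4, 3, 2, device='meta')
values, indices = meta_patched_function.store[torch.max](x, dim=1, keepdim=True)
assert values.shape == torch.Size([4, 1, 2])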
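A note on the two data-movement patches: on the meta device there is no storage to transfer, so `.cpu()` and `.cuda()` are replaced by a plain clone that keeps the tensor on meta while preserving its shape and dtype. A small demonstration, assuming only stock PyTorch:

# Why clone() suffices for .cpu()/.cuda() during meta tracing:
# there is no real data to move, only metadata to preserve.
import torch

def torch_tensor_cuda(input, *args, **kwargs):
    # same body as the patch above
    return input.clone()

x = torch.empty(8, 16, device='meta')
y = torch_tensor_cuda(x)

# the "moved" tensor stays on the meta device with identical metadata
assert y.device.type == 'meta'
assert y.shape == x.shape and y.dtype == x.dtype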
@@ -61,3 +61,22 @@ def test_repeat_interleave():
                      patch_fn=repeat_interleave,
                      expect_exception=True,
                      output_shape=materialized_output.shape)
+
+
+def test_torch_max():
+    data = torch.rand(4, 3)
+    out = torch.max(data)
+    patched_out = patched_function.torch_max(data)
+    assert out.shape == patched_out.shape
+
+    data = torch.rand(4, 3, 2)
+    out, idx = torch.max(data, dim=1)
+    patched_out, patched_idx = patched_function.torch_max(data, dim=1)
+    assert out.shape == patched_out.shape
+    assert idx.shape == patched_idx.shape
+
+    data = torch.rand(4, 3, 2)
+    out, idx = torch.max(data, dim=1, keepdim=True)
+    patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True)
+    assert out.shape == patched_out.shape
+    assert idx.shape == patched_idx.shape
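As a quick sanity check outside the test suite, the same shape comparison can be run by hand. The import path below is an assumption about the repository layout at this commit and may need adjusting:

# Hand-run version of the shape check exercised by test_torch_max.
# The import path is an assumption; adjust to match the repository layout.
import torch
from colossalai.fx.tracer.meta_patch import patched_function

data = torch.rand(4, 3, 2)
out, idx = torch.max(data, dim=1, keepdim=True)
patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True)

# the patched op only reproduces metadata, so we compare shapes, not values
assert out.shape == patched_out.shape == torch.Size([4, 1, 2])
assert idx.shape == patched_idx.shape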