Browse Source

[fx] patched torch.max and data movement operator (#1391)

* [fx] patched torch.max and data movement operator

* polish code
pull/1395/head
Frank Lee 2 years ago committed by GitHub
parent
commit
7d6293927f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 33
      colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
  2. 19
      tests/test_fx/test_tracer/test_patched_op.py

33
colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py

@ -138,3 +138,36 @@ def torch_roll(input, shifts, dims=None):
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
    """Stand-in for ``torch.full`` that allocates an empty tensor on the ``meta`` device.

    Only the metadata of the result is produced: ``fill_value`` is never
    written and the requested ``device`` is ignored in favour of ``'meta'``.

    Args:
        size: shape of the tensor to create.
        fill_value: accepted for signature compatibility; unused.
        out: must be ``None``; writing into an existing tensor is unsupported.
        dtype: forwarded to ``torch.empty``.
        layout: forwarded to ``torch.empty``.
        device: accepted for signature compatibility; unused.
        requires_grad: forwarded to ``torch.empty``.

    Returns:
        An uninitialised tensor of shape ``size`` on the ``meta`` device.
    """
    assert out is None, 'assigning result to out is not supported yet'
    meta_options = dict(device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
    return torch.empty(size, **meta_options)
@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
    """Stand-in for ``torch.max`` that returns empty tensors on the ``meta`` device.

    The three call forms of ``torch.max`` are mirrored; only the shape and
    dtype of the result are modelled, no maximum is actually computed:

    * ``torch_max(input)`` -> a 0-D tensor with ``input``'s dtype.
    * ``torch_max(input, dim, keepdim)`` with an integer ``dim`` -> a
      ``(values, indices)`` pair; ``values`` has ``input``'s dtype and
      ``indices`` is ``torch.long``, as ``torch.max`` returns index tensors.
    * ``torch_max(input, other)`` with a tensor in ``dim``'s position ->
      a single tensor whose shape is the broadcast of both operands' shapes.
      The dtype is taken from ``input``; type promotion is not modelled.

    Args:
        input: tensor to reduce (or the first operand of the elementwise form).
        dim: ``None``, an integer dimension (negative values count from the
            end), or a tensor acting as the second operand.
        keepdim: when reducing over an integer ``dim``, keep that dimension
            with size 1 instead of dropping it.
        out: must be ``None``; writing into an existing tensor is unsupported.

    Returns:
        A meta tensor, or a ``(values, indices)`` tuple of meta tensors when
        ``dim`` is an integer.

    Raises:
        AssertionError: if ``out`` is given.
        IndexError: if an integer ``dim`` is out of range for ``input``.
        RuntimeError: if ``input`` and a tensor ``dim`` cannot be broadcast.
        TypeError: if ``dim`` is neither ``None``, an ``int`` nor a tensor.
    """
    assert out is None, 'assigning value to out is not supported yet'

    # Full reduction: a single scalar tensor.
    if dim is None:
        return torch.empty([], device='meta', dtype=input.dtype)

    # Elementwise maximum of two tensors: the result follows broadcasting
    # rules, so it is not necessarily shaped like ``input``.
    if isinstance(dim, torch.Tensor):
        other = dim
        result_shape = torch.broadcast_shapes(input.shape, other.shape)
        return torch.empty(result_shape, device='meta', dtype=input.dtype)

    if isinstance(dim, int):
        shape = list(input.shape)
        # A 0-D tensor is treated as having one dimension for range checking,
        # so dim=0 and dim=-1 are accepted for it.
        rank = max(len(shape), 1)
        if not -rank <= dim < rank:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {dim})")
        if shape:
            # Normalise negative dims first; inserting at a negative index
            # after popping would otherwise place the kept dim one slot early.
            axis = dim % rank
            if keepdim:
                shape[axis] = 1
            else:
                del shape[axis]
        values = torch.empty(shape, device='meta', dtype=input.dtype)
        indices = torch.empty(shape, device='meta', dtype=torch.long)
        return values, indices

    raise TypeError(f"Expected dim to be None, an int or a torch.Tensor but got {type(dim).__name__}")
@meta_patched_function.register(torch.Tensor.cpu)
def torch_tensor_cpu(input, *args, **kwargs):
    """Stand-in for ``torch.Tensor.cpu`` that performs no device transfer.

    Extra positional and keyword arguments (for example ``memory_format``,
    which ``Tensor.cpu`` accepts) are accepted and ignored so that a call
    valid for the real method does not fail under the patch. This matches
    the signature of the ``torch.Tensor.cuda`` patch.

    Args:
        input: the tensor ``.cpu()`` was called on.
        *args: ignored.
        **kwargs: ignored.

    Returns:
        ``input.clone()``, i.e. a copy that stays on ``input``'s device.
    """
    return input.clone()
@meta_patched_function.register(torch.Tensor.cuda)
def torch_tensor_cuda(input, *args, **kwargs):
    """Stand-in for ``torch.Tensor.cuda`` that performs no device transfer.

    Args:
        input: the tensor ``.cuda()`` was called on.
        *args: ignored.
        **kwargs: ignored.

    Returns:
        ``input.clone()``, i.e. a copy that stays on ``input``'s device.
    """
    duplicate = input.clone()
    return duplicate

19
tests/test_fx/test_tracer/test_patched_op.py

@ -61,3 +61,22 @@ def test_repeat_interleave():
patch_fn=repeat_interleave,
expect_exception=True,
output_shape=materialized_output.shape)
def test_torch_max():
    """Check that the patched ``torch.max`` reports the same shapes as ``torch.max``."""
    # Full reduction: no dim given, a single tensor comes back.
    sample = torch.rand(4, 3)
    expected = torch.max(sample)
    actual = patched_function.torch_max(sample)
    assert expected.shape == actual.shape

    # Reduction along dim 1: first dropping the reduced dim, then keeping it.
    for extra_kwargs in ({}, {'keepdim': True}):
        sample = torch.rand(4, 3, 2)
        expected_values, expected_indices = torch.max(sample, dim=1, **extra_kwargs)
        actual_values, actual_indices = patched_function.torch_max(sample, dim=1, **extra_kwargs)
        assert expected_values.shape == actual_values.shape
        assert expected_indices.shape == actual_indices.shape

Loading…
Cancel
Save