diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index 2abf3b7e1..229443ed9 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -138,3 +138,36 @@ def torch_roll(input, shifts, dims=None):
 def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
     assert out is None, 'assigning result to out is not supported yet'
     return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
+
+
+@meta_patched_function.register(torch.max)
+def torch_max(input, dim=None, keepdim=False, *, out=None):
+    assert out is None, 'assigning value to out is not supported yet'
+    if dim is not None:
+        if isinstance(dim, int):
+            shape = list(input.shape)
+            shape.pop(dim)
+            if keepdim:
+                shape.insert(dim, 1)
+            return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
+                                                                                     device='meta',
+                                                                                     dtype=input.dtype)
+        elif isinstance(dim, torch.Tensor):
+            # when dim is a 0D or 1D tensor, the output keeps the same shape as the input
+            num_dims = dim.dim()
+            if num_dims in [0, 1]:
+                return torch.empty_like(input, device='meta')
+            else:
+                raise ValueError(f"Expected dim to be a 0D or 1D tensor but got {num_dims} dimensions")
+        else:
+            return torch.empty([], device='meta', dtype=input.dtype)
+
+
+@meta_patched_function.register(torch.Tensor.cpu)
+def torch_tensor_cpu(input):
+    return input.clone()
+
+
+@meta_patched_function.register(torch.Tensor.cuda)
+def torch_tensor_cuda(input, *args, **kwargs):
+    return input.clone()
diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py
index 05c29b824..4406f02db 100644
--- a/tests/test_fx/test_tracer/test_patched_op.py
+++ b/tests/test_fx/test_tracer/test_patched_op.py
@@ -61,3 +61,22 @@ def test_repeat_interleave():
                          patch_fn=repeat_interleave,
                          expect_exception=True,
                          output_shape=materialized_output.shape)
+
+
+def test_torch_max():
+    data = torch.rand(4, 3)
+    out = torch.max(data)
+    patched_out = patched_function.torch_max(data)
+    assert out.shape == patched_out.shape
+
+    data = torch.rand(4, 3, 2)
+    out, idx = torch.max(data, dim=1)
+    patched_out, patched_idx = patched_function.torch_max(data, dim=1)
+    assert out.shape == patched_out.shape
+    assert idx.shape == patched_idx.shape
+
+    data = torch.rand(4, 3, 2)
+    out, idx = torch.max(data, dim=1, keepdim=True)
+    patched_out, patched_idx = patched_function.torch_max(data, dim=1, keepdim=True)
+    assert out.shape == patched_out.shape
+    assert idx.shape == patched_idx.shape
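
For context, a minimal usage sketch (not part of the diff) of how the patched torch.max mirrors the real op's output shapes once inputs live on the meta device. It calls only the function added above; the import path for patched_function and the dispatch through the tracer's meta_patched_function registry are assumed, not shown by this PR.

import torch
from colossalai.fx.tracer.meta_patch import patched_function  # assumed import path, mirroring the test module

meta_input = torch.empty(4, 3, 2, device='meta')

# global reduction: a 0-dim meta tensor, matching torch.max(real_input).shape
assert patched_function.torch_max(meta_input).shape == torch.Size([])

# reduction over dim=1: (values, indices) with that dimension removed
values, indices = patched_function.torch_max(meta_input, dim=1)
assert values.shape == indices.shape == torch.Size([4, 2])

# keepdim=True keeps the reduced dimension with size 1
values, indices = patched_function.torch_max(meta_input, dim=1, keepdim=True)
assert values.shape == indices.shape == torch.Size([4, 1, 2])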