diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index 2ee5cb112..2abf3b7e1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -132,3 +132,9 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None) @meta_patched_function.register(torch.roll) def torch_roll(input, shifts, dims=None): return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.full) +def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + assert out is None, 'assigning result to out is not supported yet' + return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)