[fx] patched torch.full for huggingface opt (#1386)

pull/1388/head
Frank Lee 2022-07-29 17:56:28 +08:00 committed by GitHub
parent 527758b2ae
commit ad678921db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 0 deletions

View File

@ -132,3 +132,9 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
return torch.empty(input.shape, device='meta')
@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
assert out is None, 'assigning result to out is not supported yet'
return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)