@@ -138,3 +138,36 @@ def torch_roll(input, shifts, dims=None):


@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
    # only shape, dtype and layout matter on the meta device, so fill_value is never materialized
    assert out is None, 'assigning result to out is not supported yet'
    return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
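
# Usage sketch (illustrative only, not part of the patch): calling the helper directly
# shows the inferred metadata, e.g.
#   >>> torch_full((2, 3), 7.0, dtype=torch.float16).shape
#   torch.Size([2, 3])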


@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
    assert out is None, 'assigning value to out is not supported yet'
    if dim is not None:
        if isinstance(dim, int):
            # reduction over an integer dim: the reduced dimension is dropped, or kept as
            # size 1 when keepdim=True; note that real torch.max returns int64 indices,
            # which are approximated here with input.dtype
            shape = list(input.shape)
            shape.pop(dim)
            if keepdim:
                shape.insert(dim, 1)
            return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
                                                                                     device='meta',
                                                                                     dtype=input.dtype)
        elif isinstance(dim, torch.Tensor):
            # when dim is a 0D or 1D tensor, the call is an element-wise maximum and the
            # result keeps the input's shape
            num_dims = dim.dim()
            if num_dims in [0, 1]:
                return torch.empty_like(input, device='meta')
            else:
                raise ValueError(f"Expected dim to be a 0D or 1D tensor but got {num_dims} dimensions")
    else:
        return torch.empty([], device='meta', dtype=input.dtype)
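
# Usage sketch (illustrative only): shape inference with no real computation, e.g. for
# x = torch.empty(4, 8, device='meta'):
#   torch_max(x, dim=1) returns two meta tensors of shape torch.Size([4])
#   torch_max(x)        returns a 0-dim meta tensor of shape torch.Size([])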


@meta_patched_function.register(torch.Tensor.cpu)
def torch_tensor_cpu(input):
    # during meta tracing the tensor should stay on the meta device; clone it so that
    # shape and dtype are preserved
    return input.clone()


@meta_patched_function.register(torch.Tensor.cuda)
def torch_tensor_cuda(input, *args, **kwargs):
    # likewise for .cuda(): ignore the target device and any extra arguments
    return input.clone()
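
# Usage sketch (illustrative only): both device patches are effectively no-ops on meta
# tensors, e.g.
#   >>> torch_tensor_cuda(torch.empty(2, 2, device='meta')).device
#   device(type='meta')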