[hotfix]: Remove math.prod dependency (#2837)

* Remove math.prod dependency

* Fix style

* Fix style

---------

Co-authored-by: Jiatong Han <jiatong.han@u.nus.edu>
Jiatong (Julius) Han, 2023-02-23 23:56:15 +08:00, committed by GitHub
parent 819e25d8b1
commit 8c8a39be95
2 changed files with 99 additions and 100 deletions
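For context: math.prod was only added in Python 3.8, so the commit swaps it for the long-standing functools.reduce / operator.mul idiom, which computes the same product while keeping the modules importable on older interpreters. A minimal sketch of the equivalence (the shape tuple here is illustrative):

import operator
from functools import reduce

shape = (4, 8, 16)    # illustrative tensor shape

# Same result as math.prod(shape), but without the Python >= 3.8 requirement.
# The initializer 1 means an empty shape () reduces to 1, matching
# math.prod's behavior for zero-dimensional tensors.
assert reduce(operator.mul, shape, 1) == 512
assert reduce(operator.mul, (), 1) == 1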


@@ -1,97 +1,96 @@
-import math
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from typing import Optional, Union
+import operator
+from functools import reduce
+from typing import Optional, Union
+
+import torch
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def _all_int(my_iter):
     return all(isinstance(i, int) for i in my_iter)
 
 
 def _get_valid_shape(shape):
     if isinstance(shape, list):
         if _all_int(shape):
             return tuple(shape)
         else:
             raise RuntimeError("expects type(int) but finds an other type")
     elif isinstance(shape, tuple):
         if _all_int(shape):
             return shape
         else:
             return _get_valid_shape(shape[0])
     else:
         raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
 
 
 def _shape_infer(org_sp, tgt_sp):
     cnt = 0
     pos = 0
     for idx, dim in enumerate(tgt_sp):
         if dim < -1:
             raise RuntimeError("invalid shape dimension {}".format(dim))
         elif dim == -1:
             cnt += 1
             pos = idx
 
     if cnt > 1:
         raise RuntimeError("only one dimension can be inferred")
-    org_prod = math.prod(org_sp)
-    tgt_prod = math.prod(tgt_sp)
+    org_prod = reduce(operator.mul, org_sp, 1)
+    tgt_prod = reduce(operator.mul, tgt_sp, 1)
     if cnt == 0:
         if org_prod != tgt_prod:
             raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
         else:
             return tgt_sp
     elif org_prod % tgt_prod != 0:
         raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
 
     infer_dim = -(org_prod // tgt_prod)
-    return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
+    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
 
 
 @colo_op_impl(torch.Tensor.view)
 def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
     Changes the shape of the current tensor.
     """
     assert isinstance(self, ColoTensor)
     # apply original `view` function for replicated colo tensors
     if self.is_replicate():
         return self.view(*shape)
 
     cur_sp = self.size()
     org_sp = self.size_global()
     # parse the passed arguments
     tgt_sp = _get_valid_shape(shape)
     # get the correct shape from inference
     inf_sp = _shape_infer(org_sp, tgt_sp)
 
     if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
         new_shape = (cur_sp[0],) + tgt_sp[1:]
         res = self.view(*new_shape)
     elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
         new_shape = tgt_sp[:-1] + (cur_sp[-1],)
         res = self.view(*new_shape)
     else:
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return ColoTensor.from_torch_tensor(
-            tensor=replicated_t.view(*shape),
-            spec=ColoTensorSpec(self.get_process_group()))
+        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
+                                            spec=ColoTensorSpec(self.get_process_group()))
 
-    return ColoTensor.from_torch_tensor(
-        tensor=res,
-        spec=ColoTensorSpec(
-            pg=self.get_process_group(),
-            dist_attr=self.dist_spec))
+    return ColoTensor.from_torch_tensor(tensor=res,
+                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
 
 
 @colo_op_impl(torch.Tensor.size)
 def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
     size = self.size_global()
     if dim is None:
         return size
     else:
         return size[dim]
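For intuition about _shape_infer above: when exactly one -1 placeholder is present, tgt_prod comes out negative, so the inferred dimension is recovered by negating the floor division. A standalone sketch with illustrative shapes:

import operator
from functools import reduce

org_sp = (4, 8)     # illustrative global shape: 32 elements
tgt_sp = (2, -1)    # one dimension left to infer

org_prod = reduce(operator.mul, org_sp, 1)    # 32
tgt_prod = reduce(operator.mul, tgt_sp, 1)    # -2 (the -1 flips the sign)

# 32 // -2 floor-divides to -16; negating yields the inferred size 16.
infer_dim = -(org_prod // tgt_prod)
assert (tgt_sp[:1] + (infer_dim,) + tgt_sp[2:]) == (2, 16)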


@@ -1,6 +1,6 @@
-import math
+import operator
 from copy import copy
-from functools import lru_cache
+from functools import lru_cache, reduce
 from typing import Callable, Optional, Set
 
 import torch
@@ -312,7 +312,7 @@ class ColoTensor(torch.Tensor):
 
     def numel_global(self):
         """Returns the number of elements in the tensor when it's replicated.
         """
-        return math.prod(self.size_global())
+        return reduce(operator.mul, self.size_global(), 1)
 
     # Some API for dist spec check
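The same reduce idiom backs numel_global. Since torch.Size is a tuple subclass, the behavior is easy to check without any distributed setup (a minimal sketch; the sizes are illustrative):

import operator
from functools import reduce

import torch

size = torch.Size((2, 3, 5))    # stand-in for self.size_global()
assert reduce(operator.mul, size, 1) == 30

# A 0-d tensor has size () but one element; the initializer 1 keeps
# the reduction consistent with Tensor.numel() in that edge case.
assert reduce(operator.mul, torch.Size(()), 1) == torch.tensor(3.14).numel() == 1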