mirror of https://github.com/hpcaitech/ColossalAI
[hotfix]: Remove math.prod dependency (#2837)
* Remove math.prod dependency * Fix style * Fix style --------- Co-authored-by: Jiatong Han <jiatong.han@u.nus.edu>pull/2912/head
parent
819e25d8b1
commit
8c8a39be95
|
@ -1,9 +1,12 @@
|
|||
import math
|
||||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
|
||||
|
||||
def _all_int(my_iter):
|
||||
return all(isinstance(i, int) for i in my_iter)
|
||||
|
@ -37,8 +40,8 @@ def _shape_infer(org_sp, tgt_sp):
|
|||
if cnt > 1:
|
||||
raise RuntimeError("only one dimension can be inferred")
|
||||
|
||||
org_prod = math.prod(org_sp)
|
||||
tgt_prod = math.prod(tgt_sp)
|
||||
org_prod = reduce(operator.mul, org_sp, 1)
|
||||
tgt_prod = reduce(operator.mul, tgt_sp, 1)
|
||||
|
||||
if cnt == 0:
|
||||
if org_prod != tgt_prod:
|
||||
|
@ -49,7 +52,7 @@ def _shape_infer(org_sp, tgt_sp):
|
|||
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
|
||||
|
||||
infer_dim = -(org_prod // tgt_prod)
|
||||
return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
|
||||
return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
|
||||
|
||||
|
||||
@colo_op_impl(torch.Tensor.view)
|
||||
|
@ -77,15 +80,11 @@ def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
|
|||
res = self.view(*new_shape)
|
||||
else:
|
||||
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
||||
return ColoTensor.from_torch_tensor(
|
||||
tensor=replicated_t.view(*shape),
|
||||
return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
|
||||
spec=ColoTensorSpec(self.get_process_group()))
|
||||
|
||||
return ColoTensor.from_torch_tensor(
|
||||
tensor=res,
|
||||
spec=ColoTensorSpec(
|
||||
pg=self.get_process_group(),
|
||||
dist_attr=self.dist_spec))
|
||||
return ColoTensor.from_torch_tensor(tensor=res,
|
||||
spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
|
||||
|
||||
|
||||
@colo_op_impl(torch.Tensor.size)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
import operator
|
||||
from copy import copy
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, reduce
|
||||
from typing import Callable, Optional, Set
|
||||
|
||||
import torch
|
||||
|
@ -312,7 +312,7 @@ class ColoTensor(torch.Tensor):
|
|||
def numel_global(self):
|
||||
"""Returns the number of elements in the tensor when it's replicated.
|
||||
"""
|
||||
return math.prod(self.size_global())
|
||||
return reduce(operator.mul, self.size_global(), 1)
|
||||
|
||||
# Some API for dist spec check
|
||||
|
||||
|
|
Loading…
Reference in New Issue