[hotfix]: Remove math.prod dependency (#2837)

* Remove math.prod dependency

* Fix style

* Fix style

---------

Co-authored-by: Jiatong Han <jiatong.han@u.nus.edu>
Author: Jiatong (Julius) Han, committed by GitHub
commit 8c8a39be95 (parent 819e25d8b1)
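Note: math.prod was only added in Python 3.8, so dropping it keeps these modules importable on older interpreters; reduce(operator.mul, ..., 1) is the standard pre-3.8 equivalent. A minimal sketch of the equivalence (the prod helper below is illustrative, not part of the diff):

import operator
from functools import reduce

def prod(iterable):
    # Same result as math.prod on Python 3.8+: multiply all elements,
    # returning 1 for an empty iterable.
    return reduce(operator.mul, iterable, 1)

assert prod([2, 3, 4]) == 24
assert prod(()) == 1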

@@ -1,8 +1,11 @@
-import math
+import operator
+from functools import reduce
+from typing import Optional, Union
+
 import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from typing import Optional, Union
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def _all_int(my_iter):
@@ -37,8 +40,8 @@ def _shape_infer(org_sp, tgt_sp):
     if cnt > 1:
         raise RuntimeError("only one dimension can be inferred")
 
-    org_prod = math.prod(org_sp)
-    tgt_prod = math.prod(tgt_sp)
+    org_prod = reduce(operator.mul, org_sp, 1)
+    tgt_prod = reduce(operator.mul, tgt_sp, 1)
 
     if cnt == 0:
         if org_prod != tgt_prod:
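For reference, a standalone sketch of how these two products drive the -1 inference in _shape_infer (infer_shape here is a hypothetical helper, not the actual function body): when exactly one target dimension is -1, the target product is negative, and the missing size is org_prod divided by the product of the known target dimensions.

import operator
from functools import reduce

def infer_shape(org_sp, tgt_sp):
    # Resolve a single -1 in tgt_sp so the element counts of the
    # two shapes match.
    org_prod = reduce(operator.mul, org_sp, 1)
    tgt_prod = reduce(operator.mul, tgt_sp, 1)
    if -1 in tgt_sp:
        # tgt_prod carries the -1 factor, so -tgt_prod is the product
        # of the known target dimensions.
        inferred = org_prod // -tgt_prod
        return tuple(inferred if d == -1 else d for d in tgt_sp)
    if org_prod != tgt_prod:
        raise RuntimeError("shapes have different element counts")
    return tuple(tgt_sp)

assert infer_shape((2, 3, 4), (6, -1)) == (6, 4)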
@@ -77,15 +80,11 @@ def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
             res = self.view(*new_shape)
     else:
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return ColoTensor.from_torch_tensor(
-            tensor=replicated_t.view(*shape),
-            spec=ColoTensorSpec(self.get_process_group()))
+        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
+                                            spec=ColoTensorSpec(self.get_process_group()))
 
-    return ColoTensor.from_torch_tensor(
-        tensor=res,
-        spec=ColoTensorSpec(
-            pg=self.get_process_group(),
-            dist_attr=self.dist_spec))
+    return ColoTensor.from_torch_tensor(tensor=res,
+                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
 
 
 @colo_op_impl(torch.Tensor.size)

@@ -1,6 +1,6 @@
-import math
+import operator
 from copy import copy
-from functools import lru_cache
+from functools import lru_cache, reduce
 from typing import Callable, Optional, Set
 
 import torch
@@ -312,7 +312,7 @@ class ColoTensor(torch.Tensor):
     def numel_global(self):
         """Returns the number of elements in the tensor when it's replicated.
         """
-        return math.prod(self.size_global())
+        return reduce(operator.mul, self.size_global(), 1)
 
     # Some API for dist spec check
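A quick sanity check of the new numel_global body against torch's own numel on a plain tensor (a sketch: a replicated ColoTensor's size_global() would play the role of t.size() here):

import operator
from functools import reduce

import torch

t = torch.empty(4, 8, 16)
assert reduce(operator.mul, t.size(), 1) == t.numel() == 512

# The initial value 1 matters for 0-d tensors, whose size() is an
# empty tuple: the reduction then returns 1, matching numel().
assert reduce(operator.mul, torch.tensor(3.0).size(), 1) == 1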
