mirror of https://github.com/hpcaitech/ColossalAI
[hotfix]: Remove math.prod dependency (#2837)
* Remove math.prod dependency
* Fix style

Co-authored-by: Jiatong Han <jiatong.han@u.nus.edu>

pull/2912/head
parent 819e25d8b1
commit 8c8a39be95
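For context, `math.prod` only exists on Python 3.8 and later, so replacing it keeps the code importable on older interpreters. A minimal sketch of the equivalence the patch relies on (the sample shape below is illustrative, not taken from the patch):

import math
import operator
from functools import reduce

sizes = (2, 3, 4)

# math.prod is available only on Python >= 3.8
assert math.prod(sizes) == 24

# The portable spelling the patch switches to; the initializer 1
# keeps the empty product well defined (it equals 1, like math.prod(())).
assert reduce(operator.mul, sizes, 1) == 24
assert reduce(operator.mul, (), 1) == 1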
@@ -1,97 +1,96 @@
-import math
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from typing import Optional, Union
+import operator
+from functools import reduce
+from typing import Optional, Union
+
+import torch
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def _all_int(my_iter):
     return all(isinstance(i, int) for i in my_iter)
 
 
 def _get_valid_shape(shape):
     if isinstance(shape, list):
         if _all_int(shape):
             return tuple(shape)
         else:
             raise RuntimeError("expects type(int) but finds an other type")
     elif isinstance(shape, tuple):
         if _all_int(shape):
             return shape
         else:
             return _get_valid_shape(shape[0])
     else:
         raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
 
 
 def _shape_infer(org_sp, tgt_sp):
     cnt = 0
     pos = 0
     for idx, dim in enumerate(tgt_sp):
         if dim < -1:
             raise RuntimeError("invalid shape dimension {}".format(dim))
         elif dim == -1:
             cnt += 1
             pos = idx
 
     if cnt > 1:
         raise RuntimeError("only one dimension can be inferred")
 
-    org_prod = math.prod(org_sp)
-    tgt_prod = math.prod(tgt_sp)
+    org_prod = reduce(operator.mul, org_sp, 1)
+    tgt_prod = reduce(operator.mul, tgt_sp, 1)
 
     if cnt == 0:
         if org_prod != tgt_prod:
             raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
         else:
             return tgt_sp
     elif org_prod % tgt_prod != 0:
         raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
 
     infer_dim = -(org_prod // tgt_prod)
-    return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
+    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
 
 
 @colo_op_impl(torch.Tensor.view)
 def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
     Changes the shape of the current tensor.
     """
     assert isinstance(self, ColoTensor)
     # apply original `view` function for replicated colo tensors
     if self.is_replicate():
         return self.view(*shape)
 
     cur_sp = self.size()
     org_sp = self.size_global()
     # parse the passed arguments
     tgt_sp = _get_valid_shape(shape)
     # get the correct shape from inference
     inf_sp = _shape_infer(org_sp, tgt_sp)
 
     if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
         new_shape = (cur_sp[0],) + tgt_sp[1:]
         res = self.view(*new_shape)
     elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
         new_shape = tgt_sp[:-1] + (cur_sp[-1],)
         res = self.view(*new_shape)
     else:
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return ColoTensor.from_torch_tensor(
-            tensor=replicated_t.view(*shape),
-            spec=ColoTensorSpec(self.get_process_group()))
+        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
+                                            spec=ColoTensorSpec(self.get_process_group()))
 
-    return ColoTensor.from_torch_tensor(
-        tensor=res,
-        spec=ColoTensorSpec(
-            pg=self.get_process_group(),
-            dist_attr=self.dist_spec))
+    return ColoTensor.from_torch_tensor(tensor=res,
+                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
 
 
 @colo_op_impl(torch.Tensor.size)
 def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
     size = self.size_global()
     if dim is None:
         return size
     else:
         return size[dim]
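The `_shape_infer` change above preserves a subtle invariant: when the target shape contains exactly one `-1`, `tgt_prod` is negative, so Python's floor division `org_prod // tgt_prod` yields the negated size of the inferred dimension. A standalone re-run of that arithmetic, with example shapes of my choosing rather than anything from the patch:

import operator
from functools import reduce

org_sp = (4, 6)        # original global shape, 24 elements
tgt_sp = (3, -1, 2)    # one dimension left for inference

org_prod = reduce(operator.mul, org_sp, 1)    # 24
tgt_prod = reduce(operator.mul, tgt_sp, 1)    # -6 (the -1 flips the sign)

# org_prod % tgt_prod == 0 here, so the shape is valid;
# 24 // -6 == -4 under floor division, and negating gives 4.
infer_dim = -(org_prod // tgt_prod)
assert infer_dim == 4

pos = tgt_sp.index(-1)
assert tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:] == (3, 4, 2)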
@@ -1,6 +1,6 @@
-import math
+import operator
 from copy import copy
-from functools import lru_cache
+from functools import lru_cache, reduce
 from typing import Callable, Optional, Set
 
 import torch
@@ -312,7 +312,7 @@ class ColoTensor(torch.Tensor):
     def numel_global(self):
         """Returns the number of elements in the tensor when it's replicated.
         """
-        return math.prod(self.size_global())
+        return reduce(operator.mul, self.size_global(), 1)
 
     # Some API for dist spec check
 
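One reason the `reduce` call in `numel_global` passes an explicit initializer of `1`: a zero-dimensional tensor has an empty `size()`, and the empty product must stay `1` to match both `math.prod` and `numel()`. A quick hedged check (assumes a local PyTorch install; `scalar` is an illustrative name):

import operator
from functools import reduce

import torch

scalar = torch.tensor(3.0)    # zero-dimensional tensor
assert scalar.size() == torch.Size([])

# Without the initializer, reduce() over an empty size would raise
# TypeError; with it, the empty product is 1, matching numel().
assert reduce(operator.mul, scalar.size(), 1) == scalar.numel() == 1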