mirror of https://github.com/hpcaitech/ColossalAI
* Remove math.prod dependency
* Fix style
* Fix style

Co-authored-by: Jiatong Han <jiatong.han@u.nus.edu>

pull/2912/head
Jiatong (Julius) Han
2 years ago
committed by GitHub
2 changed files with 99 additions and 100 deletions
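
The diff below swaps math.prod for functools.reduce: math.prod was only added in Python 3.8, so the reduce(operator.mul, ..., 1) form presumably keeps the module importable on older interpreters. A minimal sketch of the equivalence (the example shape is illustrative, not from the commit):

import math
import operator
from functools import reduce

shape = (4, 8, 16)

# reduce multiplies the elements left to right, starting from the initial
# value 1 -- exactly what math.prod computes (math.prod needs Python >= 3.8).
assert reduce(operator.mul, shape, 1) == math.prod(shape) == 512

# The initial value 1 also keeps the empty-sequence (scalar shape) case correct:
assert reduce(operator.mul, (), 1) == math.prod(()) == 1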
@@ -1,97 +1,96 @@
-import math
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from typing import Optional, Union
-
-
-def _all_int(my_iter):
-    return all(isinstance(i, int) for i in my_iter)
-
-
-def _get_valid_shape(shape):
-    if isinstance(shape, list):
-        if _all_int(shape):
-            return tuple(shape)
-        else:
-            raise RuntimeError("expects type(int) but finds an other type")
-    elif isinstance(shape, tuple):
-        if _all_int(shape):
-            return shape
-        else:
-            return _get_valid_shape(shape[0])
-    else:
-        raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
-
-
-def _shape_infer(org_sp, tgt_sp):
-    cnt = 0
-    pos = 0
-    for idx, dim in enumerate(tgt_sp):
-        if dim < -1:
-            raise RuntimeError("invalid shape dimension {}".format(dim))
-        elif dim == -1:
-            cnt += 1
-            pos = idx
-
-    if cnt > 1:
-        raise RuntimeError("only one dimension can be inferred")
-
-    org_prod = math.prod(org_sp)
-    tgt_prod = math.prod(tgt_sp)
-
-    if cnt == 0:
-        if org_prod != tgt_prod:
-            raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
-        else:
-            return tgt_sp
-    elif org_prod % tgt_prod != 0:
-        raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
-
-    infer_dim = -(org_prod // tgt_prod)
-    return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
-
-
-@colo_op_impl(torch.Tensor.view)
-def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
-    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
-    Changes the shape of the current tensor.
-    """
-    assert isinstance(self, ColoTensor)
-    # apply original `view` function for replicated colo tensors
-    if self.is_replicate():
-        return self.view(*shape)
-
-    cur_sp = self.size()
-    org_sp = self.size_global()
-    # parse the passed arguments
-    tgt_sp = _get_valid_shape(shape)
-    # get the correct shape from inference
-    inf_sp = _shape_infer(org_sp, tgt_sp)
-
-    if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
-        new_shape = (cur_sp[0],) + tgt_sp[1:]
-        res = self.view(*new_shape)
-    elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
-        new_shape = tgt_sp[:-1] + (cur_sp[-1],)
-        res = self.view(*new_shape)
-    else:
-        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return ColoTensor.from_torch_tensor(
-            tensor=replicated_t.view(*shape),
-            spec=ColoTensorSpec(self.get_process_group()))
-
-    return ColoTensor.from_torch_tensor(
-        tensor=res,
-        spec=ColoTensorSpec(
-            pg=self.get_process_group(),
-            dist_attr=self.dist_spec))
-
-
-@colo_op_impl(torch.Tensor.size)
-def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
-    size = self.size_global()
-    if dim is None:
-        return size
-    else:
-        return size[dim]
+import operator
+from functools import reduce
+from typing import Optional, Union
+
+import torch
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+
+def _all_int(my_iter):
+    return all(isinstance(i, int) for i in my_iter)
+
+
+def _get_valid_shape(shape):
+    if isinstance(shape, list):
+        if _all_int(shape):
+            return tuple(shape)
+        else:
+            raise RuntimeError("expects type(int) but finds an other type")
+    elif isinstance(shape, tuple):
+        if _all_int(shape):
+            return shape
+        else:
+            return _get_valid_shape(shape[0])
+    else:
+        raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
+
+
+def _shape_infer(org_sp, tgt_sp):
+    cnt = 0
+    pos = 0
+    for idx, dim in enumerate(tgt_sp):
+        if dim < -1:
+            raise RuntimeError("invalid shape dimension {}".format(dim))
+        elif dim == -1:
+            cnt += 1
+            pos = idx
+
+    if cnt > 1:
+        raise RuntimeError("only one dimension can be inferred")
+
+    org_prod = reduce(operator.mul, org_sp, 1)
+    tgt_prod = reduce(operator.mul, tgt_sp, 1)
+
+    if cnt == 0:
+        if org_prod != tgt_prod:
+            raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+        else:
+            return tgt_sp
+    elif org_prod % tgt_prod != 0:
+        raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+
+    infer_dim = -(org_prod // tgt_prod)
+    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
+
+
+@colo_op_impl(torch.Tensor.view)
+def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
+    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
+    Changes the shape of the current tensor.
+    """
+    assert isinstance(self, ColoTensor)
+    # apply original `view` function for replicated colo tensors
+    if self.is_replicate():
+        return self.view(*shape)
+
+    cur_sp = self.size()
+    org_sp = self.size_global()
+    # parse the passed arguments
+    tgt_sp = _get_valid_shape(shape)
+    # get the correct shape from inference
+    inf_sp = _shape_infer(org_sp, tgt_sp)
+
+    if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
+        new_shape = (cur_sp[0],) + tgt_sp[1:]
+        res = self.view(*new_shape)
+    elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
+        new_shape = tgt_sp[:-1] + (cur_sp[-1],)
+        res = self.view(*new_shape)
+    else:
+        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
+        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
+                                            spec=ColoTensorSpec(self.get_process_group()))
+
+    return ColoTensor.from_torch_tensor(tensor=res,
+                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
+
+
+@colo_op_impl(torch.Tensor.size)
+def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
+    size = self.size_global()
+    if dim is None:
+        return size
+    else:
+        return size[dim]
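
For reference, a worked pass through _shape_infer's -1 handling as implemented above (the concrete shapes are illustrative, not from the commit): the -1 placeholder flips the sign of tgt_prod, and negating the floor division recovers the missing dimension.

import operator
from functools import reduce

org_sp = (4, 6)     # global shape: 24 elements
tgt_sp = (3, -1)    # ask for the second dimension to be inferred

org_prod = reduce(operator.mul, org_sp, 1)    # 24
tgt_prod = reduce(operator.mul, tgt_sp, 1)    # -3, negative because of the -1

# 24 % -3 == 0, so the request is valid, and the negated floor division
# yields the inferred dimension: -(24 // -3) == -(-8) == 8.
assert -(org_prod // tgt_prod) == 8           # _shape_infer returns (3, 8)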
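
The two shard-aware branches in colo_view skip the expensive redistribute when the view leaves the sharded dimension's global size unchanged. A plain-torch sketch of why that is safe, with an assumed two-rank layout and illustrative shapes (no ColossalAI APIs involved):

import torch

world_size = 2                                      # assumed number of ranks
global_t = torch.arange(24).reshape(4, 6)           # global shape (4, 6)
shards = list(global_t.chunk(world_size, dim=0))    # each rank holds a (2, 6) shard

# view(4, 3, 2) keeps dim 0's global size of 4, so each rank can apply the
# same view with its local dim-0 extent substituted in -- no communication:
local_views = [shard.view(2, 3, 2) for shard in shards]
assert torch.equal(torch.cat(local_views, dim=0), global_t.view(4, 3, 2))

This mirrors the is_shard_1drow() branch above, where new_shape = (cur_sp[0],) + tgt_sp[1:] swaps the local extent of the sharded dimension into the target shape.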