From a98319f0231d3a5c241747ca8ccd98813cfedaae Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 7 Jul 2022 18:09:18 +0800
Subject: [PATCH] [tensor] torch function return colotensor (#1229)

---
 colossalai/nn/_ops/element_wise.py |  6 ++---
 colossalai/nn/_ops/linear.py       |  5 +++--
 colossalai/nn/_ops/loss.py         |  2 +-
 colossalai/tensor/colo_tensor.py   | 36 ++++++++++++++++++++++--------
 colossalai/tensor/process_group.py |  8 +++----
 tests/test_tensor/test_model.py    |  1 +
 tests/test_tensor/test_op.py       |  3 ++-
 tests/test_tensor/test_tensor.py   |  2 ++
 8 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py
index 9409b8081..b7b6b7c9c 100644
--- a/colossalai/nn/_ops/element_wise.py
+++ b/colossalai/nn/_ops/element_wise.py
@@ -17,14 +17,12 @@ def register_elementwise_op(op):
         """
         output = op(input_tensor, *args, **kwargs)
-
         if isinstance(input_tensor, ColoTensor):
             if not isinstance(output, torch.Tensor):
                 raise NotImplementedError
             return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.process_group,
-                                                                    dist_attr=input_tensor.dist_spec,
-                                                                    compute_attr=input_tensor.compute_spec))
+                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                    dist_attr=input_tensor.dist_spec))
 
 
 # Tensor op
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 9a77b259a..dea8c1484 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -22,7 +22,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
 
-    pg = input_tensor.get_process_group()
+    pg = weight.get_process_group()
     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
     return output
 
@@ -61,6 +61,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     """
     assert isinstance(weight, ColoTensor)
     pg = weight.get_process_group()
+    assert pg
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
     bias = convert_to_colo_tensor(bias, pg)
     # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
@@ -70,7 +71,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
         assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
-        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
     elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
             mode = 'row'
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py
index 7c47daca8..c17406c18 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/nn/_ops/loss.py
@@ -35,7 +35,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
         else:
             raise NotImplementedError
     else:
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index a01d0b7ac..743401468 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -11,12 +11,30 @@ from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional
 
 
-def _check_output(output):
-    if not isinstance(output, torch.Tensor):
-        raise RuntimeError
+def _convert_output(output, pg: ProcessGroup):
+    if type(output) == torch.Tensor:
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif isinstance(output, (list, tuple)):
-        output = type(output)(_check_output(o) for o in output)
-    return output
+        return type(output)(_convert_output(o, pg) for o in output)
+    else:
+        return output
+
+
+def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
+    for elem in args:
+        if isinstance(elem, ColoTensor):
+            pg = elem.get_process_group()
+            return pg
+        elif isinstance(elem, (list, tuple)):
+            pg = _scan_for_pg_from_args(elem, {})
+            if pg is not None:
+                return pg
+            print(type(elem), elem, isinstance(elem, (list, tuple)))
+    for k, v in kwargs:
+        if isinstance(v, ColoTensor):
+            pg = v.get_process_group()
+            return pg
+    return None
 
 
 class ColoTensor(torch.Tensor):
@@ -108,6 +126,7 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): target dist spec.
         """
         assert isinstance(dist_spec, _DistSpec)
+        assert self.process_group
         self._convert_to_dist_spec(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
@@ -136,12 +155,11 @@ class ColoTensor(torch.Tensor):
         if func in get_default_nowrap_functions():
             return ret
         else:
-            # TODO(jiaruifang) its parallel Op's duty to convert output activations
-            return ret
-            # return _check_output(ret)
+            pg = _scan_for_pg_from_args(args, kwargs)
+            return _convert_output(ret, pg)
 
     def __repr__(self):
-        return f'ColoTensor: {super().__repr__()}'
+        return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}'
 
     def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         """_convert_to_dist_spec
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 3c959395c..90337864f 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -19,6 +19,10 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         pg_key = (backend, rank_tuple)
 
         if pg_key not in self.dict:
+
+            self.logger = get_dist_logger('ProcessGroup')
+            self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
+
             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
         return self.dict[pg_key]
 
@@ -92,10 +96,6 @@ class ProcessGroup:
             self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
             self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
-        self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(
-            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-
         self._has_cpu_groups = False
         self._cpu_dp_process_group = None
         self._cpu_tp_process_group = None
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 8553d0978..031bdc25f 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -113,6 +113,7 @@ def run_1d_hybrid_tp(model_name):
         torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
         torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
 
+        # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 9ac1968da..86d817c7c 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -39,7 +39,7 @@ def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
     for k in dir(tensor.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.dist_spec, k)
+            assert hasattr(other.dist_spec, k), f"{k}"
             assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
 
 
@@ -48,6 +48,7 @@ def check_element_wise_ops():
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
     x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
     check_spec_eq(x, torch.abs(x))
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index 9ed267301..c77ba9d59 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -49,6 +49,8 @@ def _run_operand():
 
     t_ref_res = t_ref + t_ref
     t_res = t + t
+
+    assert isinstance(t_res, ColoTensor)
    assert torch.allclose(t_ref_res, t_res)
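
Illustrative usage sketch (not part of the patch): with the __torch_function__ change above, a plain torch operation applied to a ColoTensor should come back as a ColoTensor whose spec carries the process group scanned from the arguments (_scan_for_pg_from_args) and wrapped by _convert_output. The snippet mirrors the check added to tests/test_tensor/test_tensor.py; it assumes a distributed environment has already been launched, and the import paths, tensor shape, and tp_degree are assumptions taken from the test files touched by this patch.

    import torch
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

    # Assumption: torch.distributed has already been initialized
    # (e.g. through the usual colossalai launch path used by the tests).
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    t = ColoTensor(torch.randn(4, 4), spec=ColoTensorSpec(pg))

    # Before this patch, the result of a plain torch function was returned
    # as-is; after it, the output is wrapped back into a ColoTensor that
    # reuses the ProcessGroup found on the inputs.
    res = t + t
    assert isinstance(res, ColoTensor)

    # The updated __repr__ now also reports the dist spec and process group.
    print(res)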