From 177c374401c7c507c3b4b5ec31cf7830c9b75c50 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 23 Jun 2022 16:35:05 +0800
Subject: [PATCH] remove gather out in parallel action (#1163)

---
 colossalai/nn/_ops/addmm.py         | 17 ++++-------------
 colossalai/nn/_ops/embedding.py     |  4 +---
 colossalai/nn/_ops/embedding_bag.py |  5 ++---
 colossalai/nn/_ops/linear.py        |  5 +----
 colossalai/tensor/colo_tensor.py    | 32 +++++++++++++++++++++++++++++++---
 colossalai/tensor/spec.py           |  5 ++---
 tests/test_tensor/test_linear_tp.py |  1 +
 tests/test_tensor/test_model.py     |  6 +++---
 8 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index bcfdd72ae..7d7f84a0d 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -37,10 +37,10 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
                              ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if parallel_action.gather_out:
-        # All-Gather(Output)
-        output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
-    return output
+
+    # TODO(jiaruifang) addmm is a special case,
+    # since GPT calls view after the op.
+    return output.to_replicate()
 
 
 def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -62,11 +62,6 @@ def colo_addmm(input_tensor: GeneralTensor,
     """
     input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))
 
-    # building the computing graph, inputs -> op
-    # if GraphGlobalEnv().graph_building:
-    #     cur_op_node = GraphOpNode('linear', [weight, bias])
-    #     cur_op_node.add_prev_tensor(input_tensor)
-
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_spec():    # No Model Parallel Applied
@@ -84,8 +79,4 @@ def colo_addmm(input_tensor: GeneralTensor,
         else:
             raise NotImplementedError
 
-    # building the computing graph, op -> output
-    # if GraphGlobalEnv().graph_building:
-    #     cur_op_node.add_post_tensor(ret_tensor)
-
     return ret_tensor
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 18b59eb34..9fb053e1d 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -30,9 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if weight.spec.parallel_action.gather_out:
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+    return output.to_replicate()
 
 
 def colo_embedding_1Drow(input_tensor: ColoTensor,
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
index bf9dcbdd1..35221d213 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -36,9 +36,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if weight.spec.parallel_action.gather_out:
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+
+    return output.to_replicate()
 
 
 def colo_embedding_bag_1d(tp_mode: str,
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index eccb1b467..2ec53bec3 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -42,10 +42,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         ParallelAction(ComputePattern.TP1D)))
 
-    if parallel_action.gather_out:
-        # All-Gather(Output)
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+    return output.to_replicate()
 
 
 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 377e6b575..f507aebb7 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -92,10 +92,13 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor: {super().__repr__()}'
 
-    def is_model_data(self) -> bool:
-        return self._type == TensorType.MODEL
-
     def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
+        """_convert_to_dist_spec
+        Note that the function will not handle the logic of backward propagation!
+        It is used during model tensor initializations as an internal function.
+        Args:
+            dist_spec (_DistSpec): the target dist spec.
+        """
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
         self._tensor_spec.dist_spec = dist_spec
@@ -106,6 +109,19 @@ class ColoTensor(torch.Tensor):
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
         return ColoTensor.from_torch_tensor(ret, tensor_spec)
 
+    def to_replicate_(self):
+        """to_replicate_
+        an inplace member function, converting the dist spec of the tensor to REPLICATE
+        """
+        self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
+        self._tensor_spec.dist_spec = distspec.replicate()
+
+    def to_replicate(self) -> 'ColoTensor':
+        """to_replicate
+        converting the dist spec of the tensor to REPLICATE
+        """
+        return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
+
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
         tensor = tensor.as_subclass(ColoTensor)
@@ -121,3 +137,13 @@ class ColoTensor(torch.Tensor):
         tensor = ColoTensor(data, spec=copy(self.spec))
         memo[id(self)] = tensor
         return tensor
+
+    # TODO(jiaruifang) a patch for the GPT test.
+    # We need to override the member functions that must operate on a replicated tensor.
+    # def view(self, *args, **kwargs):
+    #     self.data = DistSpecManager.handle_trans_spec(self,
+    #                                                   self.spec.dist_spec,
+    #                                                   distspec.replicate(self.spec.get_process_group()))
+    #     # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
+    #     self.data.view(*args, **kwargs)
+    #     return ColoTensor.from_torch_tensor(self.data)
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index 5de6b2c85..5e238dd61 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -13,13 +13,12 @@ class ComputePattern(Enum):
 
 class ParallelAction(object):
 
-    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
+    def __init__(self, compute_pattern: ComputePattern) -> None:
         assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
-        self.gather_out = gather_out
 
     def __repr__(self):
-        return f'compute pattern: {self.compute_pattern}, gather out: {self.gather_out}'
+        return f'compute pattern: {self.compute_pattern}'
 
 
 class TensorSpec(object):
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
index f673687ea..834fcd4c1 100644
--- a/tests/test_tensor/test_linear_tp.py
+++ b/tests/test_tensor/test_linear_tp.py
@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index ba30549bc..10ca53121 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
         weight.set_spec(spec)
 
 
-def init_1d_col_linear(weight, gather_out=True):
+def init_1d_col_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
 
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
             if 'proj2' in name and 'weight' in name:
                 init_1d_row_linear(p)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p, gather_out=False)
+                init_1d_col_linear(p)
 
     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
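
Usage sketch (illustrative, not part of the commit): after this patch, ParallelAction only records the compute pattern, and gathering a sharded op output is requested explicitly by converting the ColoTensor with to_replicate() (or in place with to_replicate_()), as the updated tests do. The snippet below restates that calling pattern. It assumes a 1D tensor-parallel process group has already been initialized (e.g. via colossalai.launch); the helper name shard_1d_col, the tensor shapes, and the exact import paths are assumptions based on the test files above, not content of this commit.

    import torch
    import torch.nn.functional as F

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import (ColoTensor, ComputePattern, DistSpecManager, ParallelAction, TensorSpec, distspec)


    def shard_1d_col(param: ColoTensor) -> None:
        # Shard dim 0 across the 1D parallel group, mirroring init_1d_col_linear above.
        spec = TensorSpec(
            distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                           [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
            ParallelAction(ComputePattern.TP1D))    # gather_out no longer exists here
        with DistSpecManager.no_grad():
            param.set_spec(spec)


    weight = ColoTensor.from_torch_tensor(torch.randn(8, 4).cuda())
    bias = ColoTensor.from_torch_tensor(torch.randn(8).cuda())
    shard_1d_col(weight)
    shard_1d_col(bias)    # test_model.py shards the classifier bias the same way

    x = torch.rand(2, 4).cuda()
    colo_out = F.linear(x, weight, bias)    # dispatched to the ColoTensor linear op
    # Ask for a replicated (all-gathered) result explicitly, as the updated
    # test_linear_tp.py does, instead of configuring gather_out on the spec.
    colo_out = colo_out.to_replicate()

The design point is that the layout decision now lives at the tensor level rather than in the op's ParallelAction spec, which is what lets ops such as colo_addmm_1Dcol simply end with return output.to_replicate().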