From 95ac4f88eac53a30f399926bdd7e39c345fb7da8 Mon Sep 17 00:00:00 2001
From: Sze-qq <68757353+Sze-qq@users.noreply.github.com>
Date: Tue, 8 Nov 2022 17:09:16 +0800
Subject: [PATCH] [NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py code style (#1829)

Co-authored-by: siqi
---
 .../deprecated/op_handler/conv_handler.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py
index c41ca6370..d8952040d 100644
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py
@@ -3,9 +3,9 @@ import warnings
 from functools import reduce
 
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 
 from .operator_handler import OperatorHandler
 
@@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
 
         Argument:
             sharding_size_forward(int): The forward activation will be divided into sharding_size_forward number partions.
-            sharding_size_backward_activation(int): The backward activation will 
+            sharding_size_backward_activation(int): The backward activation will
             be divided into sharding_size_backward_activation number partions.
             sharding_size_weight(int): The backward weight will be divided into sharding_size_weight number partions.
 
         Return:
-            memory_cost(Tuple[float]): Memory cost per device with this 
+            memory_cost(Tuple[float]): Memory cost per device with this
                 specific strategy, the first element of this tuple is forward memory cost, and the second element of this tuple is backward memory cost.
-            memory_cost_forward(float): Memory cost of forward activation per 
+            memory_cost_forward(float): Memory cost of forward activation per
                 device with this specific strategy.
-            memory_cost_backward_activation(float): Memory cost of backward activation 
+            memory_cost_backward_activation(float): Memory cost of backward activation
                 per device with this specific strategy.
         '''
         # compute the memory cost of this strategy
@@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
            # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
            strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
            setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
-           
+
            strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
            conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
            conv_handler.register_strategy_into_strategies_vector()
            for strategy in conv_handler.strategies_vector:
                print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
-           
+
        Output:
            S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
            S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
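Context note on the docstring reformatted in the second hunk: _generate_memory_cost splits the forward activation, the backward activation gradient, and the weight gradient across the given numbers of partitions and returns per-device forward and backward memory costs. The sketch below is illustrative only and is not the code touched by this patch; the element counts and the 4-byte element size are assumed inputs rather than names taken from the file.

def generate_memory_cost_sketch(numel_output,
                                numel_input,
                                numel_weight,
                                sharding_size_forward,
                                sharding_size_backward_activation,
                                sharding_size_weight,
                                size_per_elem_bytes=4):
    # Forward pass: each device holds its shard of the output activation.
    memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
    # Backward pass: each device holds its shard of the input-activation gradient
    # plus its shard of the weight gradient.
    memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
    memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
    memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
    # First tuple element is the forward cost, second is the backward cost,
    # matching the Return section of the docstring above.
    memory_cost = (memory_cost_forward, memory_cost_backward)
    return memory_cost, memory_cost_forward, memory_cost_backward_activation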