[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py code style (#1829)

Co-authored-by: siqi <siqi@siqis-MacBook-Pro.local>
pull/1849/head
Sze-qq 2 years ago committed by binmakeswell
parent 5da03c936d
commit 95ac4f88ea

@@ -3,9 +3,9 @@ import warnings
 from functools import reduce
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
@@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
         Argument:
             sharding_size_forward(int): The forward activation will be divided
                 into sharding_size_forward number partions.
             sharding_size_backward_activation(int): The backward activation will
                 be divided into sharding_size_backward_activation number partions.
             sharding_size_weight(int): The backward weight will be divided
                 into sharding_size_weight number partions.
         Return:
             memory_cost(Tuple[float]): Memory cost per device with this
                 specific strategy, the first element of this tuple is forward
                 memory cost, and the second element of this tuple is backward
                 memory cost.
             memory_cost_forward(float): Memory cost of forward activation per
                 device with this specific strategy.
             memory_cost_backward_activation(float): Memory cost of backward activation
                 per device with this specific strategy.
         '''
         # compute the memory cost of this strategy
@@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
         # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
         strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
         setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
         strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
         conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
                                    device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
         conv_handler.register_strategy_into_strategies_vector()
         for strategy in conv_handler.strategies_vector:
             print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
     Output:
         S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
         S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
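
For context on what the re-wrapped docstring in the second hunk describes: each sharding size is the number of partitions a tensor is split into, so the per-device memory cost is the total tensor size divided by the corresponding sharding size. A minimal sketch of that relationship follows (a hypothetical standalone helper, not the patched method; the total_* parameters are assumed inputs):

def estimate_memory_cost(total_forward_activation, total_backward_activation,
                         total_weight, sharding_size_forward,
                         sharding_size_backward_activation, sharding_size_weight):
    # Each tensor is partitioned across its sharding size, so the per-device
    # footprint is the total size divided by the number of partitions.
    memory_cost_forward = total_forward_activation / sharding_size_forward
    memory_cost_backward_activation = total_backward_activation / sharding_size_backward_activation
    memory_cost_weight = total_weight / sharding_size_weight
    # Matches the docstring's return layout: (forward cost, backward cost).
    return (memory_cost_forward, memory_cost_backward_activation + memory_cost_weight)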
