mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py code style (#1829)
Co-authored-by: siqi <siqi@siqis-MacBook-Pro.local>pull/1849/head
parent
5da03c936d
commit
95ac4f88ea
|
@ -3,9 +3,9 @@ import warnings
|
|||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
|
@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
|
|||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward number partions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation number partions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight number partions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
|
@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
|
|||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
|
||||
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
|
||||
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
|
||||
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy_into_strategies_vector()
|
||||
for strategy in conv_handler.strategies_vector:
|
||||
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
|
||||
|
||||
|
||||
Output:
|
||||
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
|
||||
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
|
||||
|
|
Loading…
Reference in New Issue