@@ -3,9 +3,9 @@ import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
@@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
Argument:
    sharding_size_forward(int): The forward activation will be divided
        into sharding_size_forward number of partitions.
    sharding_size_backward_activation(int): The backward activation will
        be divided into sharding_size_backward_activation number of partitions.
    sharding_size_weight(int): The backward weight will be divided
        into sharding_size_weight number of partitions.

Return:
    memory_cost(Tuple[float]): Memory cost per device with this
        specific strategy; the first element of this tuple is the forward
        memory cost, and the second element is the backward memory cost.
    memory_cost_forward(float): Memory cost of the forward activation per
        device with this specific strategy.
    memory_cost_backward_activation(float): Memory cost of the backward
        activation per device with this specific strategy.
'''
# compute the memory cost of this strategy
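# For orientation, a minimal sketch of the arithmetic described above, assuming
# hypothetical numel_output / numel_input / numel_weight element counts and a
# size_per_elem_bytes factor (illustrative names, not necessarily the handler's
# actual locals):
#
#     memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
#     memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
#     memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
#     memory_cost = (memory_cost_forward,
#                    memory_cost_backward_activation + memory_cost_backward_weight)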
@@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight,
                           output_node=nodes[2], device_mesh=device_mesh, strategies_vector=strategies_vector,
                           shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector()
for strategy in conv_handler.strategies_vector:
    print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')

Output:
    S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
    S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
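# A hypothetical follow-up, not part of ConvHandler's API: once strategies are
# registered, the cheapest candidate could be picked by summing the reported
# costs. The equal-weight sum and scalar memory_cost (as printed above) are
# assumptions for illustration only.
best_strategy = min(conv_handler.strategies_vector,
                    key=lambda s: s.compute_cost + s.communication_cost + s.memory_cost)
print(f'cheapest strategy: {best_strategy.name}')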