@@ -93,6 +93,45 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def split_input_batch(self, mesh_dim_0):
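        # S{mesh_dim_0}R = S{mesh_dim_0}R x RR: shard only the input batch dimension along mesh_dim_0
        # and keep the weight fully replicated, so the output inherits the batch sharding and no
        # collective communication is needed.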
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
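        # With the batch dimension sharded, each device only processes its local shard, so compute_cost
        # is roughly 1/sharding_size of the unsplit cost (17713152 vs 35426304 for RR = RR x RR in the
        # docstring example of register_strategy).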
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
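        # Each device only stores its batch shard of the output activation, e.g. 984064.0 bytes for
        # S0R = S0R x RR versus 1968128 for the replicated RR output in the docstring example.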
        # This strategy does not need an all-reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@@ -161,7 +200,45 @@ class ConvHandler(OperatorHandler):
        memory_cost = numel * size_per_elem_bytes / sharding_size
        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
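        # RR = RS{mesh_dim_0} x S{mesh_dim_0}R: both the input and the weight are sharded along the
        # in-channel dimension, so every device produces a partial result that must be all-reduced
        # over mesh_dim_0 to obtain the replicated output.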
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
        dim_partition_dict_for_output = {}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes
        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
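        # The all-reduce message is the full replicated output, so communication_cost is roughly the
        # output size: 1968129.01 vs a 1968128-byte output for RR = RS0 x S0R in the docstring example.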
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
@@ -250,6 +327,86 @@ class ConvHandler(OperatorHandler):
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
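        # S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR: flatten the two mesh dimensions
        # into one and shard the input batch across all devices; the weight stays replicated, so no
        # collective communication is needed.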
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size
        # This strategy does not need an all-reduce operation
        communication_cost = 0
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
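        # RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R: shard the in-channel dimension
        # of both the input and the weight across the flattened 1D mesh; the partial outputs are then
        # all-reduced over the whole mesh to produce the replicated result.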
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
        dim_partition_dict_for_output = {}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
                                                  self.device_mesh.shape[mesh_dim_1])
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes
        # compute communication cost
        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate all possible sharding strategies for a Conv node and record them in the strategies_vector.
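
        Strategy names follow the pattern '<output> = <input> x <weight>': S{i} marks a tensor dimension
        sharded along mesh dimension i, R marks a replicated dimension, and S01 marks a dimension sharded
        across both mesh dimensions flattened into one.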
@@ -283,24 +440,34 @@ class ConvHandler(OperatorHandler):
        conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
                                   device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
        conv_handler.register_strategy_into_strategies_vector()
        for strategy in conv_handler.strategies_vector:
            print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')

        Output:
        S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
        S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
        S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
        S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
        S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
        S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
        RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
        RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
        RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
        RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
        RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
        RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
        RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
        S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
        RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
        '''
        # SS = SR x RS
        self.split_input_batch_weight_out_channel(0, 1)
        self.split_input_batch_weight_out_channel(1, 0)
        # SR = SR x RR
        self.split_input_batch(0)
        self.split_input_batch(1)
        # SR = SS x SR
        self.split_input_both_dim_weight_in_channel(0, 1)
        self.split_input_both_dim_weight_in_channel(1, 0)
@@ -309,6 +476,10 @@ class ConvHandler(OperatorHandler):
        self.split_input_in_channel_weight_both_channel(0, 1)
        self.split_input_in_channel_weight_both_channel(1, 0)
        # RR = RS x SR
        self.split_input_in_channel_weight_in_channel(0)
        self.split_input_in_channel_weight_in_channel(1)
        # RS = RR x RS
        self.split_weight_out_channel(0)
        self.split_weight_out_channel(1)
@@ -316,4 +487,10 @@ class ConvHandler(OperatorHandler):
        # RR = RR x RR
        self.non_split()
        # S01R = S01R x RR
        self.split_1d_parallel_on_input_batch(0, 1)
        # RR = RS01 x S01R
        self.split_1d_parallel_on_in_channel(0, 1)
        return self.strategies_vector