@@ -4,7 +4,6 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
-from .._utils import generate_sharding_spec

 __all__ = ['ConvHandler']

@@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RR x RR'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
-                                                         dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
-                                                          dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
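Note on the pattern above: every hunk makes the same substitution, replacing the module-level generate_sharding_spec(tensor, self.device_mesh, dim_partition_dict) call with a _generate_sharding_spec method that is presumably provided by the OperatorHandler base class, so the device mesh no longer has to be threaded through each call site. The standalone Python sketch below only illustrates what such a wrapper could look like; the class name, the stand-in helper, and the returned structure are illustrative assumptions, not code taken from this diff or from ColossalAI.

# Hypothetical sketch of the wrapper assumed by the '+' lines above.
def generate_sharding_spec(tensor, device_mesh, dim_partition_dict):
    # Stand-in for the removed module-level helper: the real one builds a
    # ShardingSpec from the tensor shape, the device mesh, and the
    # {tensor_dim: [mesh_dims]} partition mapping.
    return {"shape": tuple(tensor.shape), "mesh": device_mesh, "partition": dim_partition_dict}


class OperatorHandlerSketch:
    """Minimal stand-in for the base class that ConvHandler inherits from."""

    def __init__(self, input_data, weight, output_data, device_mesh):
        self.input_data = input_data
        self.weight = weight
        self.output_data = output_data
        self.device_mesh = device_mesh

    def _generate_sharding_spec(self, tensor, dim_partition_dict):
        # Bind the handler's own device mesh once, so subclasses can write
        # self._generate_sharding_spec(tensor, dim_partition_dict), matching
        # the added lines in the hunks above.
        return generate_sharding_spec(tensor, self.device_mesh, dim_partition_dict)

# Example call site, mirroring the refactored code:
#   spec = handler._generate_sharding_spec(handler.input_data, {0: [0]})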