mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fixed wrong sharding strategy in conv handler (#1747)
* [autoparallel] fixed wrong sharding strategy in conv handler * polish codepull/1748/head
parent
8b8937d901
commit
474111ecb5
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..utils import transpose_partition_dim
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
|
@ -55,20 +56,7 @@ class ConvModuleHandler(ModuleHandler):
|
|||
"""
|
||||
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
||||
if op_data.name == "weight":
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
|
||||
# switch first and second dim of the conv module weight
|
||||
first_dim_partition = dim_partition_dict.pop(1, None)
|
||||
second_dim_partition = dim_partition_dict.pop(0, None)
|
||||
|
||||
if first_dim_partition:
|
||||
dim_partition_dict[0] = first_dim_partition
|
||||
|
||||
if second_dim_partition:
|
||||
dim_partition_dict[1] = second_dim_partition
|
||||
|
||||
# re-init the sharding spec
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
|
||||
transpose_partition_dim(sharding_spec, 0, 1)
|
||||
return strategy
|
||||
|
||||
|
||||
|
@ -110,7 +98,7 @@ class ConvFunctionHandler(NodeHandler):
|
|||
|
||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||
|
||||
if "bias" in self.node.kwargs:
|
||||
if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
|
||||
# check if the other operand is a parameter
|
||||
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
|
@ -128,19 +116,5 @@ class ConvFunctionHandler(NodeHandler):
|
|||
"""
|
||||
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
||||
if op_data.name == str(self.node.args[1]):
|
||||
assert op_data.logical_shape != op_data.data.shape
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
|
||||
# switch first and second dim of the conv function weight
|
||||
first_dim_partition = dim_partition_dict.pop(1, None)
|
||||
second_dim_partition = dim_partition_dict.pop(0, None)
|
||||
|
||||
if first_dim_partition:
|
||||
dim_partition_dict[0] = first_dim_partition
|
||||
|
||||
if second_dim_partition:
|
||||
dim_partition_dict[1] = second_dim_partition
|
||||
|
||||
# re-init the sharding spec
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
||||
transpose_partition_dim(sharding_spec, 0, 1)
|
||||
return strategy
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, List, Union
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
|
||||
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
|
||||
|
||||
|
@ -30,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
|
|||
op_data = strategy.get_op_data_by_name(weight_name)
|
||||
assert op_data.logical_shape != op_data.data.shape, \
|
||||
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
|
||||
tranpose_partition_dim(sharding_spec, 0, -1)
|
||||
transpose_partition_dim(sharding_spec, 0, -1)
|
||||
return strategy
|
||||
|
||||
|
||||
|
|
|
@ -5,13 +5,13 @@ from .sharding import (
|
|||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size,
|
||||
tranpose_partition_dim,
|
||||
transpose_partition_dim,
|
||||
update_partition_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
|
||||
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
|
||||
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
|
||||
]
|
||||
|
|
|
@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
|
|||
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
|
||||
|
||||
# make sure the entire shape matches the physical tensor shape
|
||||
assert sharding_spec.entire_shape == tensor.shape
|
||||
assert sharding_spec.entire_shape == tensor.shape, \
|
||||
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
|
||||
|
|
|
@ -8,12 +8,12 @@ import torch
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = [
|
||||
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
|
||||
]
|
||||
|
||||
|
||||
def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
|
||||
def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
|
||||
"""
|
||||
Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
|
||||
|
||||
|
|
|
@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
|
|||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_conv_module_handler():
|
||||
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
|
||||
@parameterize('bias', [True, False])
|
||||
def test_conv_module_handler(bias):
|
||||
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
|
@ -49,11 +49,12 @@ def test_conv_module_handler():
|
|||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
|
@ -99,6 +100,24 @@ def test_conv_module_handler():
|
|||
# RS01 = RR x RS01
|
||||
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
|
||||
|
||||
if bias:
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
|
||||
assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
|
||||
|
||||
if bias:
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
|
@ -110,8 +129,8 @@ class ConvModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_conv_function_handler():
|
||||
@parameterize('bias', [True, False])
|
||||
def test_conv_function_handler(bias):
|
||||
model = ConvModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
|
@ -119,18 +138,20 @@ def test_conv_function_handler():
|
|||
# %others : torch.Tensor [#users=1] = placeholder[target=others]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
|
||||
# return conv2d
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(4, 4, 64, 64).to('meta'),
|
||||
"others": torch.rand(16, 4, 3, 3).to('meta'),
|
||||
"bias": torch.rand(16).to('meta')
|
||||
})
|
||||
meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
|
||||
if bias:
|
||||
meta_args['bias'] = torch.rand(16).to('meta')
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
conv_mod_node = list(graph.nodes)[3]
|
||||
|
||||
if bias:
|
||||
conv_mod_node = list(graph.nodes)[3]
|
||||
else:
|
||||
conv_mod_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(conv_mod_node)
|
||||
|
||||
# build handler
|
||||
|
@ -157,11 +178,12 @@ def test_conv_function_handler():
|
|||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "conv2d"
|
||||
assert mapping['output'].data.is_meta
|
||||
|
@ -207,6 +229,24 @@ def test_conv_function_handler():
|
|||
# RS01 = RR x RS01
|
||||
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
|
||||
|
||||
if bias:
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
|
||||
assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
|
||||
|
||||
if bias:
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_module_handler()
|
||||
|
|
Loading…
Reference in New Issue