[autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler

* polish code
Frank Lee 2022-10-20 16:12:39 +08:00 committed by GitHub
parent 8b8937d901
commit 474111ecb5
6 changed files with 75 additions and 60 deletions


@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import transpose_partition_dim
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -55,20 +56,7 @@ class ConvModuleHandler(ModuleHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
-                dim_partition_dict = sharding_spec.dim_partition_dict
-                # switch first and second dim of the conv module weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
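
For context, the removed block above is what the new transpose_partition_dim call encapsulates. A minimal sketch of the same swap, using only the ShardingSpec attributes that appear in the removed code; the sketch name is hypothetical and the real helper in ..utils may differ in detail:

def transpose_partition_dim_sketch(sharding_spec, dim1, dim2):
    # move the mesh axes that shard tensor dim `dim1` onto `dim2` and vice versa
    dim_partition_dict = sharding_spec.dim_partition_dict
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)
    if dim2_partition is not None:
        dim_partition_dict[dim1] = dim2_partition
    if dim1_partition is not None:
        dim_partition_dict[dim2] = dim1_partition
    # re-init the sharding spec in place against the same shape, with the swapped mapping
    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
    return sharding_spec
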
@@ -110,7 +98,7 @@ class ConvFunctionHandler(NodeHandler):
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
-        if "bias" in self.node.kwargs:
+        if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
             if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
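
The extra "is not None" check matters because a traced call can carry an explicit None bias, e.g. a node like the following (hypothetical graph line):

#     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {'bias': None})

Here "bias" is in self.node.kwargs, but kwargs['bias'] is None and has no _meta_data attribute, so the old condition would crash on the isinstance check.
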
@@ -128,19 +116,5 @@ class ConvFunctionHandler(NodeHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
-                assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-                # switch first and second dim of the conv function weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy


@@ -3,7 +3,7 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
@@ -30,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    tranpose_partition_dim(sharding_spec, 0, -1)
+    transpose_partition_dim(sharding_spec, 0, -1)
     return strategy


@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    tranpose_partition_dim,
+    transpose_partition_dim,
     update_partition_dim,
 )
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]


@@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape
+    assert sharding_spec.entire_shape == tensor.shape, \
+        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
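
With the added message, a shape mismatch now reports both sides instead of a bare AssertionError, e.g. (shapes hypothetical):

# AssertionError: The entire_shape of the sharding spec torch.Size([4, 16, 3, 3]) does not match the tensor shape torch.Size([16, 4, 3, 3])

which makes logical-vs-physical layout mix-ups like the one addressed above much easier to spot.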


@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 __all__ = [
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
-def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
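
An illustrative call with assumed dim_partition_dict values, showing the in-place swap the conv handlers above rely on:

# before: tensor dim 0 sharded over mesh axis 0, tensor dim 1 over mesh axis 1
# sharding_spec.dim_partition_dict == {0: [0], 1: [1]}
transpose_partition_dim(sharding_spec, 0, 1)
# after:  sharding_spec.dim_partition_dict == {0: [1], 1: [0]}
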


@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import parameterize
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module_handler():
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
+@parameterize('bias', [True, False])
+def test_conv_module_handler(bias):
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -49,11 +49,12 @@ def test_conv_module_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.PARAM
+        assert mapping['bias'].logical_shape == torch.Size([16])
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
@@ -99,6 +100,24 @@ def test_conv_module_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 class ConvModel(nn.Module):
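
The new consistency checks follow PyTorch's standard conv2d layouts:

# input : (N, C_in, H, W)
# weight: (C_out, C_in, kH, kW)    # physical layout
# bias  : (C_out,)
# output: (N, C_out, H_out, W_out)

so the output channel dim (1) and the bias dim (-1) must carry the same sharding as weight dim 0 (C_out), the input channel dim (1) must match weight dim 1 (C_in), and the batch and spatial dims of input and output must match each other.
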
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
         return x
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_function_handler():
+@parameterize('bias', [True, False])
+def test_conv_function_handler(bias):
     model = ConvModel()
     tracer = ColoTracer()
@@ -119,18 +138,20 @@ def test_conv_function_handler():
     #     %others : torch.Tensor [#users=1] = placeholder[target=others]
     #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
     #     return conv2d
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "others": torch.rand(16, 4, 3, 3).to('meta'),
-                             "bias": torch.rand(16).to('meta')
-                         })
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
+    if bias:
+        meta_args['bias'] = torch.rand(16).to('meta')
+    graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[3]
+    if bias:
+        conv_mod_node = list(graph.nodes)[3]
+    else:
+        conv_mod_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(conv_mod_node)
     # build handler
@@ -157,11 +178,12 @@ def test_conv_function_handler():
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.ARG
+        assert mapping['bias'].logical_shape == torch.Size([16])
     assert mapping['output'].name == "conv2d"
     assert mapping['output'].data.is_meta
@@ -207,6 +229,24 @@ def test_conv_function_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 if __name__ == '__main__':
     test_conv_module_handler()
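
If colossalai.testing.parameterize loops the wrapped test over its listed values when the function is called directly (its usual behaviour, not something shown in this diff), the plain call in __main__ exercises both bias=True and bias=False.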