[autoparallel] fix bias addition module (#1800)

pull/1828/head
YuliangLiu0306 2022-11-08 16:21:25 +08:00 committed by GitHub
parent 6e9730d7ab
commit f6032ddb17
9 changed files with 438 additions and 20 deletions
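For context, the bias-addition pass exercised by this commit splits a biased module call into a weight-only op, a reshape of the bias, and an explicit add, so the bias becomes its own get_attr node in the traced graph (see the traced graphs in the new tests below). A minimal hand-written sketch of the equivalent computation for the Conv2d case follows; the helper name is illustrative and not part of the codebase:

import torch

def conv_with_decomposed_bias(x, weight, bias):
    # weight-only convolution, mirroring the call_function[target=torch.conv2d] node
    out = torch.conv2d(x, weight)
    # reshape the 1-D bias so it broadcasts over (N, C, H, W), mirroring the 'view' node
    bias_view = bias.view(torch.Size([1, -1, 1, 1]))
    # explicit bias addition, mirroring the operator.add node
    return out + bias_view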

View File

@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
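# torch.fx stores Node.args as a tuple, so convert the temporary list back before assigning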
user_node.args = new_args
user_node.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if op_data.type == OperationDataType.PARAM:
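# HOOK-type communications are handled by gradient hooks registered in _module_params_sharding, so skip them here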
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
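# choose the tensor to communicate: the node's own output for OUTPUT-type data, otherwise the recorded kwarg or positional argument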
if comm_action.key_for_kwarg is not None:
if op_data.type == OperationDataType.OUTPUT:
comm_object = node
elif comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
node.args = tuple(new_args)
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm

View File

@@ -5,7 +5,12 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
# the strategy attached to a get_attr node is only a pending strategy, which means we will
# replace it with the strategy of its user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, 'weight sharing is not supported in the current version.'
new_sharding_spec = target_sharding_specs[0]
user_node = node.strategies_vector.successor_nodes[0]
user_strategy = user_node.best_strategy
op_data_in_user = user_strategy.get_op_data_by_name(str(node))
origin_node_sharding_spec_dict[index] = new_sharding_spec
origin_pending_strategy = node.best_strategy
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
new_sharding_specs = origin_pending_strategy.sharding_specs
new_sharding_specs[origin_op_data] = new_sharding_spec
new_communication_actions = {}
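# if the user's strategy carries a communication action for this weight, move it onto the get_attr node and reset its arg index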
if op_data_in_user in user_strategy.communication_actions:
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
new_communication_action.arg_index = 0
new_communication_actions[origin_op_data] = new_communication_action
new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
sharding_specs=new_sharding_specs,
compute_cost=origin_pending_strategy.compute_cost,
communication_cost=origin_pending_strategy.communication_cost,
memory_cost=origin_pending_strategy.memory_cost,
communication_actions=new_communication_actions)
setattr(node, 'best_strategy', new_strategy)
setattr(node, 'sharding_spec', new_sharding_spec)
comm_action_dict = {}
for op_data, comm_action in node.best_strategy.communication_actions.items():
comm_action_dict[op_data.name] = comm_action
@@ -111,6 +141,43 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr':
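# standalone weights and biases show up as get_attr nodes after the bias-addition decomposition; shard their data here according to the chosen sharding spec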
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
if attr_len == 1:
target_module = root
target = getattr(root, atoms[0])
else:
target_module = root.get_submodule(atoms[-2])
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
target_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
target_sharding_spec).detach().clone())
else:
target_sharded = target
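# overwrite the module attribute so later get_attr lookups fetch the (possibly sharded) local parameter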
setattr(target_module, atoms[-1], target_sharded)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register gradient hooks on the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
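# define the hook through a wrapper so each closure captures its own param and comm_spec instead of the loop variables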
def wrapper(param, comm_spec):
def hook_fn(grad):
_all_reduce(grad, comm_spec)
param.register_hook(hook_fn)
wrapper(target_sharded, comm_spec_to_use)
return gm

View File

@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to the original shape in self.post_process
# check if the input operand is a parameter
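# (e.g. a bias fed through a view by the bias-addition pass); the PARAM type lets later passes treat it as a module weight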
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
type=data_type,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

View File

@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
else:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
communication_action_mapping["input"] = input_comm_action
else:
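# the input is not sharded on any mesh dimension, so no communication is required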
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

View File

@@ -43,7 +43,7 @@ class BiasAdditionConv(BiasAdditionModule):
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, bias_shape)
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
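# wrap the shape in torch.Size so the view argument recorded in the graph is a fixed shape rather than a mutable list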
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
bias_reshape_node_args, {})
return bias_reshape_proxy

View File

@@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):
@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, *, out=None):
def torch_linear(input, mat2, bias=None, *, out=None):
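# accept the optional bias argument so tracing F.linear(input, weight, bias) works under meta execution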
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
output_shape = list(input.shape)

View File

@@ -0,0 +1,172 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
class LinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)
def forward(self, x):
x = self.linear(x)
x = x * 2
return x
class ConvModel(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
bias=bias)
def forward(self, x):
x = self.conv(x)
x = x * 2
return x
def check_linear_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel(4, 8).cuda()
input = torch.rand(4, 4).cuda()
output_compare = model(input)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
# def forward(self, x : torch.Tensor):
# linear_weight = self.linear.weight
# linear_bias = self.linear.bias
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
linear_node = node_list[3]
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, output_compare)
def check_conv_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel(3, 6, 2).cuda()
input = torch.rand(4, 3, 64, 64).cuda()
output_compare = model(input)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
# def forward(self, x : torch.Tensor):
# conv_weight = self.conv.weight
# conv_bias = self.conv.bias
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
# add = conv2d + view; conv2d = view = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)
conv_node = node_list[3]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, output_compare)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
world_size = 4
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)
if __name__ == '__main__':
test_bias_addition_module()

View File

@@ -0,0 +1,146 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
class LinearModule(torch.nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
def forward(self, x):
x = self.linear(x)
return x
def check_linear_module_handler(rank, bias, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModule(16, 32, bias=bias).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(2, 2, 4, 16).cuda()
# the index of linear node in computation graph
node_index = 3
# strategy number of linear node
strategy_number = 10
# construct input args
input_args = [input]
# construct meta arg names
meta_arg_names = ['x']
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names,
node_type='bias_module')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)
# build handler
handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "x"
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([16, 16])
assert mapping['other'].name == "linear_weight"
assert mapping['other'].data.shape == torch.Size([32, 16])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert 'bias' not in mapping
assert mapping['output'].name == "linear"
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# one logical strategy will be converted to multiple physical sharding specs
assert len(strategy_name_list) > 8
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
# RS = RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('x')
weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight')
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__':
test_linear_handler()

View File

@@ -7,6 +7,9 @@ from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
@@ -56,7 +59,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
strategy_number: int,
input_args: List[torch.Tensor],
meta_arg_names: List[str],
input_kwargs: Dict[str, torch.Tensor] = {}):
input_kwargs: Dict[str, torch.Tensor] = {},
node_type: str = 'normal'):
for strategy_index in range(strategy_number):
print(f'#strategy_index: {strategy_index}')
# We need to copy the model to avoid doing backward more than once on the same graph
@@ -79,11 +83,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
target_node = list(graph.nodes)[node_index]
# solution construction
solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len
solution[node_index] = strategy_index
if node_type == 'normal':
solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len
solution[node_index] = strategy_index
else:
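# for decomposed bias-addition modules the graph gains extra get_attr/view/add nodes, so restrict the target node to the strategy under test and let the ILP solver complete the solution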
node_vector = strategies_constructor.leaf_strategies[node_index]
strategy_to_keep = node_vector[strategy_index]
node_vector = [strategy_to_keep]
# solution construction
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
@@ -110,11 +124,18 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# extract the strategy used in this iter
strategy_in_use = target_node.strategies_vector[strategy_index]
param_to_shard_dict = dict(model_to_shard.named_parameters())
param_to_shard_dict = dict(gm.named_parameters())
param_to_compare_dict = dict(model_to_compare.named_parameters())
for name in param_to_shard_dict.keys():
param_name = name.split('.')[-1]
param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
if node_type == 'normal':
param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
else:
if 'weight' in name:
param_sharding_spec = list(graph.nodes)[4].sharding_spec
elif 'bias' in name:
param_sharding_spec = list(graph.nodes)[5].sharding_spec
grad_sharded = param_to_shard_dict[name].grad
grad_to_compare = param_to_compare_dict[name].grad
global_grad = to_global(grad_sharded, param_sharding_spec)