mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix bias addition module (#1800)
parent 6e9730d7ab
commit f6032ddb17
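For context on the traced graphs shown in the new tests further down: the tracer decomposes a bias-carrying module into a bias-free call plus an explicit add, so the weight and the bias can be sharded independently. The sketch below is not ColossalAI code; it only demonstrates, in plain PyTorch, the numerical equivalence those test graphs rely on (the [1, -1, 1, 1] view for the conv bias is taken directly from the traced graph in the conv test).

import torch
import torch.nn.functional as F

# plain-PyTorch illustration of the decomposition the tracer performs
x = torch.rand(4, 3, 64, 64)
conv = torch.nn.Conv2d(3, 6, 2, bias=True)

# fused form: conv2d(x, weight, bias)
fused = F.conv2d(x, conv.weight, conv.bias)

# decomposed form: bias-free conv2d, broadcastable view of the bias, then add
decomposed = F.conv2d(x, conv.weight) + conv.bias.view([1, -1, 1, 1])
torch.testing.assert_close(fused, decomposed)

# same idea for Linear: linear(x, weight) + bias
lin = torch.nn.Linear(4, 8)
y = torch.rand(4, 4)
torch.testing.assert_close(F.linear(y, lin.weight, lin.bias),
                           F.linear(y, lin.weight) + lin.bias)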
@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with shape_consistency_node
                 origin_index_args = new_args.index(node)
                 new_args[origin_index_args] = shape_consistency_node
-                user_node.args = new_args
+                user_node.args = tuple(new_args)
             elif str(node) in new_kwargs:
                 # substitute the origin node with shape_consistency_node
                 new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
 
+            if op_data.type == OperationDataType.PARAM:
+                if comm_action.comm_type == CommType.HOOK:
+                    continue
             if comm_action.comm_type == CommType.BEFORE:
-                if comm_action.key_for_kwarg is not None:
+                if op_data.type == OperationDataType.OUTPUT:
+                    comm_object = node
+                elif comm_action.key_for_kwarg is not None:
                     comm_object = node.kwargs[comm_action.key_for_kwarg]
                 else:
                     comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_args = list(node.args)
                     new_args[comm_action.arg_index] = comm_spec_apply_node
-                    node.args = new_args
+                    node.args = tuple(new_args)
 
             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_kwargs[str(node)] = comm_spec_apply_node
                     user.kwargs = new_kwargs
 
     return gm
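The list-to-tuple changes above matter because torch.fx expects Node.args to be a tuple; the usual pattern is to copy the arguments into a list, edit the copy, and assign the tuple back. A minimal, standalone sketch of that pattern on a hypothetical toy module (not the ColossalAI pass itself):

import operator
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):

    def forward(self, x):
        return x + 1


gm = symbolic_trace(Toy())
graph = gm.graph

# insert a multiplication after the add node and reroute its users to it,
# mirroring how the pass substitutes the origin node with shape_consistency_node
for node in list(graph.nodes):
    if node.op == 'call_function' and node.target is operator.add:
        with graph.inserting_after(node):
            scaled = graph.call_function(operator.mul, args=(node, 2))
        for user in list(node.users):
            if user is scaled:
                continue
            new_args = list(user.args)              # tuples are immutable, so edit a list copy
            new_args[new_args.index(node)] = scaled
            user.args = tuple(new_args)             # Node.args must be assigned a tuple

gm.recompile()
assert torch.equal(gm(torch.tensor([1.0])), torch.tensor([4.0]))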
@@ -5,7 +5,12 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationDataType,
+    ShardingStrategy,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
 
+        # the get_attr node strategy is kind of pending strategy, which means we will change it
+        # to the same strategy of the user node.
+        if node.op == 'get_attr':
+            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+            new_sharding_spec = target_sharding_specs[0]
+            user_node = node.strategies_vector.successor_nodes[0]
+            user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
+            op_data_in_user = user_strategy.get_op_data_by_name(str(node))
+            origin_node_sharding_spec_dict[index] = new_sharding_spec
+            origin_pending_strategy = node.best_strategy
+            origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
+            new_sharding_specs = origin_pending_strategy.sharding_specs
+            new_sharding_specs[origin_op_data] = new_sharding_spec
+            new_communication_actions = {}
+            if op_data_in_user in user_strategy.communication_actions:
+                new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
+                new_communication_action.arg_index = 0
+                new_communication_actions[origin_op_data] = new_communication_action
+            new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
+                                            sharding_specs=new_sharding_specs,
+                                            compute_cost=origin_pending_strategy.compute_cost,
+                                            communication_cost=origin_pending_strategy.communication_cost,
+                                            memory_cost=origin_pending_strategy.memory_cost,
+                                            communication_actions=new_communication_actions)
+            setattr(node, 'best_strategy', new_strategy)
+            setattr(node, 'sharding_spec', new_sharding_spec)
         comm_action_dict = {}
         for op_data, comm_action in node.best_strategy.communication_actions.items():
             comm_action_dict[op_data.name] = comm_action
@@ -111,6 +141,43 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
             for name, buffer_sharded in sharded_buffer_dict.items():
                 setattr(target_module, name, buffer_sharded.detach().clone())
 
+        if node.op == 'get_attr':
+            root = node.graph.owning_module
+            atoms = node.target.split(".")
+            attr_len = len(atoms)
+            if attr_len == 1:
+                target_module = root
+                target = getattr(root, atoms[0])
+            else:
+                target_module = root.get_submodule(atoms[-2])
+                target = getattr(target_module, atoms[-1])
+
+            target_sharding_spec = node.sharding_spec
+            if target_sharding_spec.dim_partition_dict != {}:
+                origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
+                setattr(target, 'sharding_spec', origin_sharding_spec)
+                # TODO: build a ColoParameter class to manage the distributed parameters
+                target_sharded = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
+            else:
+                target_sharded = target
+            setattr(target_module, atoms[-1], target_sharded)
+
+            comm_actions = node.best_strategy.communication_actions
+            for operation_data, comm_action in comm_actions.items():
+                comm_spec_to_use = comm_action.comm_spec
+                # register hook to the parameters
+                if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+
+                    def wrapper(param, comm_spec):
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec)
+
+                        param.register_hook(hook_fn)
+
+                    wrapper(target_sharded, comm_spec_to_use)
     return gm
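The HOOK branch above defers a parameter's gradient all-reduce to the backward pass by registering a tensor hook. A standalone sketch of the same mechanism, using torch.distributed.all_reduce directly instead of ColossalAI's _all_reduce and CommSpec (it assumes an already initialized process group):

import torch
import torch.distributed as dist


def register_grad_all_reduce(param: torch.nn.Parameter, group=None):
    """Attach a hook so that param's gradient is all-reduced as soon as it is produced."""

    def hook_fn(grad):
        # in-place all-reduce across the given group; the modified grad is kept
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)

    param.register_hook(hook_fn)


# usage, inside an initialized distributed run:
# for p in model.parameters():
#     register_grad_all_reduce(p)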
@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
+
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+
         physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
+                                               type=data_type,
                                                data=self.node.args[0]._meta_data)
         physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                                            arg_index=0)
             input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
 
-        else:
+        elif len(total_mesh_dim_list) >= 2:
             source_spec = sharding_spec_mapping["input"]
             target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                        entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
             comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
             input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
 
-        communication_action_mapping["input"] = input_comm_action
+        else:
+            input_comm_action = None
+
+        if input_comm_action is not None:
+            communication_action_mapping["input"] = input_comm_action
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
@@ -43,7 +43,7 @@ class BiasAdditionConv(BiasAdditionModule):
         bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
-        bias_reshape_node_args = (self.bias_proxy, bias_shape)
+        bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
         bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                       bias_reshape_node_args, {})
         return bias_reshape_proxy
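The only change in this hunk is wrapping the reshape target in torch.Size before handing it to the traced view call. As a quick sanity check (plain PyTorch, not the tracer API), a torch.Size behaves like the equivalent list for view, including -1 inference, so the broadcastable bias shape is unchanged; presumably the wrapper is there so downstream shape handling sees a concrete shape object rather than a bare list:

import torch

bias = torch.nn.Conv2d(3, 6, 2).bias          # shape [6]
bias_shape = [1, -1, 1, 1]                    # broadcast shape used by the traced view node

# Tensor.view accepts a torch.Size (a tuple subclass) just as well as a list,
# and both produce the broadcastable [1, 6, 1, 1] shape used for the bias add
assert bias.view(torch.Size(bias_shape)).shape == bias.view(bias_shape).shape == torch.Size([1, 6, 1, 1])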
@@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):
 
 
 @meta_patched_function.register(torch.nn.functional.linear)
-def torch_linear(input, mat2, *, out=None):
+def torch_linear(input, mat2, bias=None, *, out=None):
     if out is not None:
         raise ValueError("Don't support in-place abs for MetaTensor analysis")
     output_shape = list(input.shape)
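The signature change lets the meta patch accept the bias argument that torch.nn.functional.linear can receive. The hunk cuts off before the rest of the body, but a shape rule of this kind typically looks like the hedged sketch below (an illustration, not the exact ColossalAI implementation): every input dimension is kept except the last, which becomes the weight's leading dimension, and the bias only has to broadcast against that.

import torch


def linear_output_shape(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Size:
    # output[..., i] = sum_k input[..., k] * weight[i, k] (+ bias[i]),
    # so only the last dimension changes: it becomes weight.shape[0]
    output_shape = list(input.shape)
    output_shape[-1] = weight.shape[0]
    return torch.Size(output_shape)


x = torch.empty(2, 2, 4, 16, device='meta')
w = torch.empty(32, 16, device='meta')
assert linear_output_shape(x, w) == torch.Size([2, 2, 4, 32])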
@@ -0,0 +1,172 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class LinearModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.linear(x)
        x = x * 2

        return x


class ConvModel(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    bias=bias)

    def forward(self, x):
        x = self.conv(x)
        x = x * 2

        return x


def check_linear_module(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModel(4, 8).cuda()
    input = torch.rand(4, 4).cuda()
    output_compare = model(input)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    tracer = ColoTracer()
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
    # def forward(self, x : torch.Tensor):
    #     linear_weight = self.linear.weight
    #     linear_bias = self.linear.bias
    #     linear = torch._C._nn.linear(x, linear_weight);  x = linear_weight = None
    #     add = linear + linear_bias;  linear = linear_bias = None
    #     mul = add * 2;  add = None
    #     return mul
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    node_list = list(graph.nodes)

    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    linear_node = node_list[3]
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)

    gm = runtime_apply_pass(gm)
    gm.recompile()
    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, output_compare)


def check_conv_module(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = ConvModel(3, 6, 2).cuda()
    input = torch.rand(4, 3, 64, 64).cuda()
    output_compare = model(input)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    tracer = ColoTracer()
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
    # def forward(self, x : torch.Tensor):
    #     conv_weight = self.conv.weight
    #     conv_bias = self.conv.bias
    #     conv2d = torch.conv2d(x, conv_weight);  x = conv_weight = None
    #     view = conv_bias.view([1, -1, 1, 1]);  conv_bias = None
    #     add = conv2d + view;  conv2d = view = None
    #     mul = add * 2;  add = None
    #     return mul
    gm = ColoGraphModule(model, graph)

    gm.recompile()

    node_list = list(graph.nodes)
    conv_node = node_list[3]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])

    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)

    gm = runtime_apply_pass(gm)
    gm.recompile()
    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, output_compare)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
    world_size = 4
    run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
    mp.spawn(run_func_linear, nprocs=world_size)
    run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
    mp.spawn(run_func_conv, nprocs=world_size)


if __name__ == '__main__':
    test_bias_addition_module()
@@ -0,0 +1,146 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self

from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class LinearModule(torch.nn.Module):

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = self.linear(x)
        return x


def check_linear_module_handler(rank, bias, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModule(16, 32, bias=bias).cuda()

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    input = torch.rand(2, 2, 4, 16).cuda()
    # the index of linear node in computation graph
    node_index = 3
    # strategy number of linear node
    strategy_number = 10
    # construct input args
    input_args = [input]
    # construct meta arg names
    meta_arg_names = ['x']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names,
                                     node_type='bias_module')

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)

    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x"
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "linear_weight"
    assert mapping['other'].data.shape == torch.Size([32, 16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16, 32])

    assert 'bias' not in mapping

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    # one strategy will be converted to different physical sharding spec
    assert len(strategy_name_list) > 8

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
    assert 'S1S0 = S1R x RS0' in strategy_name_list

    # SR = SS x SR
    assert 'S0R = S0S1 x S1R' in strategy_name_list
    assert 'S1R = S1S0 x S0R' in strategy_name_list

    # RS = RS x SS
    assert 'RS0 = RS1 x S1S0' in strategy_name_list
    assert 'RS1 = RS0 x S0S1' in strategy_name_list

    # RR = RS x SR
    assert 'RR = RS0 x S0R' in strategy_name_list
    assert 'RR = RS1 x S1R' in strategy_name_list

    # RS = RR x RS
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('x')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight')
        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
    world_size = 4
    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == '__main__':
    test_linear_handler()
@@ -7,6 +7,9 @@ from torch.fx import GraphModule
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import to_global
@@ -56,7 +59,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                                      strategy_number: int,
                                      input_args: List[torch.Tensor],
                                      meta_arg_names: List[str],
-                                     input_kwargs: Dict[str, torch.Tensor] = {}):
+                                     input_kwargs: Dict[str, torch.Tensor] = {},
+                                     node_type: str = 'normal'):
     for strategy_index in range(strategy_number):
         print(f'#strategy_index: {strategy_index}')
         # We need to copy the model to avoid do backward more than once in same graph
@@ -79,11 +83,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
         strategies_constructor.build_strategies_and_cost()
         target_node = list(graph.nodes)[node_index]
 
         # solution construction
-        solution_len = len(strategies_constructor.leaf_strategies)
-        solution = [0] * solution_len
-        solution[node_index] = strategy_index
+        if node_type == 'normal':
+            solution_len = len(strategies_constructor.leaf_strategies)
+            solution = [0] * solution_len
+            solution[node_index] = strategy_index
+        else:
+            node_vector = strategies_constructor.leaf_strategies[node_index]
+            strategy_to_keep = node_vector[strategy_index]
+            node_vector = [strategy_to_keep]
+            # solution construction
+            cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+            cost_graph.simplify_graph()
+            graph_analyser = GraphAnalyser(gm)
+            solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+            ret = solver.call_solver_serialized_args()
+            solution = list(ret[0])
         gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
             gm, solution, device_mesh)
         gm = runtime_apply_pass(gm)
@@ -110,11 +124,18 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
 
         # extract the strategy used in this iter
         strategy_in_use = target_node.strategies_vector[strategy_index]
-        param_to_shard_dict = dict(model_to_shard.named_parameters())
+        param_to_shard_dict = dict(gm.named_parameters())
         param_to_compare_dict = dict(model_to_compare.named_parameters())
        for name in param_to_shard_dict.keys():
             param_name = name.split('.')[-1]
-            param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            if node_type == 'normal':
+                param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            else:
+                if 'weight' in name:
+                    param_sharding_spec = list(graph.nodes)[4].sharding_spec
+                elif 'bias' in name:
+                    param_sharding_spec = list(graph.nodes)[5].sharding_spec
+
             grad_sharded = param_to_shard_dict[name].grad
             grad_to_compare = param_to_compare_dict[name].grad
             global_grad = to_global(grad_sharded, param_sharding_spec)