mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] update_getattr_handler (#2193)
parent f10ce01e31
commit 4851f2d607
@@ -6,6 +6,7 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
@@ -96,27 +97,23 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
         # to the same strategy of the user node.
         if node.op == 'get_attr':
             assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
-            new_sharding_spec = target_sharding_specs[0]
-            user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
-            op_data_in_user = user_strategy.get_op_data_by_name(str(node))
-            origin_node_sharding_spec_dict[index] = new_sharding_spec
+            target_node = node.strategies_vector.successor_nodes[0]
+            node_name = str(node)
+            if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+                node_name = str(target_node)
+                target_node = target_node.strategies_vector.successor_nodes[0]
+            user_strategy = target_node.best_strategy
+            op_data_in_user = user_strategy.get_op_data_by_name(node_name)
             origin_pending_strategy = node.best_strategy
             origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
-            new_sharding_specs = origin_pending_strategy.sharding_specs
-            new_sharding_specs[origin_op_data] = new_sharding_spec
+
             new_communication_actions = {}
             if op_data_in_user in user_strategy.communication_actions:
                 new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
                 new_communication_action.arg_index = 0
                 new_communication_actions[origin_op_data] = new_communication_action
-            new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
-                                            sharding_specs=new_sharding_specs,
-                                            compute_cost=origin_pending_strategy.compute_cost,
-                                            communication_cost=origin_pending_strategy.communication_cost,
-                                            memory_cost=origin_pending_strategy.memory_cost,
-                                            communication_actions=new_communication_actions)
-            setattr(node, 'best_strategy', new_strategy)
-            setattr(node, 'sharding_spec', new_sharding_spec)
+            node.best_strategy.communication_actions = new_communication_actions
+
         comm_action_dict = {}
         for op_data, comm_action in node.best_strategy.communication_actions.items():
             comm_action_dict[op_data.name] = comm_action
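
The hunk above reworks how `_solution_annotatation` treats `get_attr` nodes: instead of rebuilding a fresh `ShardingStrategy` from the user's sharding spec, it keeps the node's own best strategy and only takes over the user's communication action, and when the direct successor is a reshape-style `call_function` (listed in `RESHAPE_FUNC_OP`), it hops over the reshape and reads the strategy of the consumer behind it. A minimal, illustrative torch.fx trace of that graph shape follows; the toy module, its shapes, and the printed output are assumptions for this sketch, not part of the commit.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class ToyModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(4, 8))

    def forward(self, x):
        # the get_attr node for self.weight feeds a reshape-style call_function first,
        # so its real consumer (the matmul) sits one hop further downstream
        w = torch.reshape(self.weight, (8, 4))
        return torch.matmul(x, w)


traced = symbolic_trace(ToyModule())
for node in traced.graph.nodes:
    print(node.op, node.target)
# roughly: placeholder x, get_attr weight, call_function reshape, call_function matmul, output
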
@@ -86,12 +86,7 @@ class NodeHandler(ABC):
             if prev_sharding_spec is None:
                 return TrainCycleItem(fwd=0, bwd=0, total=0)
             elif isinstance(prev_sharding_spec, ShardingSpec):
-                if isinstance(data, torch.nn.parameter.Parameter):
-                    # we won't compute the resharding cost for the parameters,
-                    # since the parameters will be sharded before runtime and
-                    # not converted during runtime.
-                    return TrainCycleItem(fwd=0, bwd=0, total=0)
-                elif isinstance(data, torch.Tensor):
+                if isinstance(data, torch.Tensor):
                     dtype = data.dtype
                     size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                     _, _, consistency_cost = shape_consistency_manager.shape_consistency(
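
This hunk collapses the `Parameter` special case in the resharding-cost helper: since `torch.nn.Parameter` is a subclass of `torch.Tensor`, the single `isinstance(data, torch.Tensor)` branch now covers parameters as well, so their resharding cost is computed instead of being forced to zero. A quick check of the subclass relationship (plain PyTorch, nothing repo-specific):

import torch
import torch.nn as nn

p = nn.Parameter(torch.rand(4, 4))
print(isinstance(p, torch.nn.parameter.Parameter))    # True -> matched the old, removed branch
print(isinstance(p, torch.Tensor))                    # True -> matched by the new, broader branch
# element size used for the byte count in the cost computation: 4 bytes for default float32
print(torch.tensor([], dtype=p.dtype).element_size())
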
@@ -1,6 +1,12 @@
 from typing import List
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
+from colossalai.tensor.sharding_spec import ShardingSpecException
 
 from .strategy_generator import StrategyGenerator
 
@@ -37,17 +43,47 @@ class GetattrGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
+    def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
+        # we check for the output logical shape to get the number of dimensions
+        dim_partition_list = []
+        dim_size = len(self.op_data['output'].logical_shape)
+
+        # enumerate all the 2D sharding cases
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_2d)
+
+        # enumerate all the 1D sharding cases
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_0)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_1)
+
+        # add empty dict for fully replicated case
+        dim_partition_list.append({})
+
+        # sharding strategy bookkeeping
+        strategy_list = []
+
+        # convert these dim partition dict to sharding strategy
+        for dim_partition_dict in dim_partition_list:
+            dim_partition_dict_mapping = dict(output=dim_partition_dict)
+
+            try:
+                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+                communication_action_mapping = {}
+
+                # get name
+                name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}"
+                sharding_strategy = self.get_sharding_strategy(
+                    name=name,
+                    sharding_spec_mapping=sharding_spec_mapping,
+                    communication_action_mapping=communication_action_mapping)
+                strategy_list.append(sharding_strategy)
+            except ShardingSpecException:
+                continue
+
+        return strategy_list
+
     def collate_strategies(self) -> List[ShardingStrategy]:
-        dim_partition_dict_mapping = {
-            "output": {},
-        }
-        communication_action_mapping = {}
-        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-
-        name = 'Replica Attribute'
-
-        strategy = self.get_sharding_strategy(name=name,
-                                              sharding_spec_mapping=sharding_spec_mapping,
-                                              communication_action_mapping=communication_action_mapping)
-
-        return [strategy]
+        return self.enumerate_all_possible_output(0, 1)
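
`collate_strategies` now delegates to the new `enumerate_all_possible_output`, which proposes every 2D, 1D, and fused sharding of the attribute's output plus the fully replicated case, and discards candidates that raise `ShardingSpecException`. A hedged sketch of how a dim-partition dict maps onto the `get_attr [S0, S1, R, R]`-style names seen in the updated test; the helper below is hypothetical and only mimics the naming convention:

def sharding_sequence(dim_partition_dict, dim_size=4):
    # S<i> = sharded along device-mesh dim i, R = replicated on that tensor dim
    seq = []
    for dim in range(dim_size):
        mesh_dims = dim_partition_dict.get(dim)
        seq.append('R' if not mesh_dims else 'S' + ''.join(str(m) for m in mesh_dims))
    return '[' + ', '.join(seq) + ']'


print(sharding_sequence({}))                  # [R, R, R, R]   fully replicated
print(sharding_sequence({0: [0], 1: [1]}))    # [S0, S1, R, R] 2D sharding
print(sharding_sequence({0: [0, 1]}))         # [S01, R, R, R] both mesh dims on dim 0
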
@@ -35,25 +35,59 @@ class AddmmModel(nn.Module):
         return x
 
 
-def check_linear_function_handler(rank, input_shape, world_size, port):
+class AddmmModel_with_param(nn.Module):
+
+    def __init__(self, weight_shape, bias_shape):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.rand(weight_shape))
+        self.bias = torch.nn.Parameter(torch.rand(bias_shape))
+
+    def forward(self, m1):
+        x = torch.addmm(self.bias, m1, self.weight, beta=3, alpha=2)
+        return x
+
+
+def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = AddmmModel().cuda()
+    if model_cls == AddmmModel:
+        model = AddmmModel().cuda()
+    else:
+        model = AddmmModel_with_param(weight_shape=(8, 16), bias_shape=input_shape).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
-    input = torch.rand(input_shape).cuda()
-    m1 = torch.rand(4, 8).cuda()
-    m2 = torch.rand(8, 16).cuda()
-    # the index of addmm node in computation graph
-    node_index = 4
-    # strategy number of linear node
-    strategy_number = 14
-    # construct input args
-    input_args = [input, m1, m2]
-    # construct meta arg names
-    meta_arg_names = ['input', 'm1', 'm2']
+    if model_cls == AddmmModel:
+        input = torch.rand(input_shape).cuda()
+        m1 = torch.rand(4, 8).cuda()
+        m2 = torch.rand(8, 16).cuda()
+        # construct input args
+        input_args = [input, m1, m2]
+        # construct meta arg names
+        meta_arg_names = ['input', 'm1', 'm2']
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+
+        # the index of addmm node in computation graph
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
+    else:
+        m1 = torch.rand(4, 8).cuda()
+        # construct input args
+        input_args = [m1]
+        # construct meta arg names
+        meta_arg_names = ['m1']
+        # the index of addmm node in computation graph
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
 
     numerical_test_for_node_strategy(model=model,
                                      device_mesh=device_mesh,
                                      node_index=node_index,
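
`AddmmModel_with_param` keeps the weight and bias as `nn.Parameter`s so the handler now has to deal with `get_attr` inputs to `torch.addmm`. As a reminder of the op's semantics, `torch.addmm(input, mat1, mat2, beta=b, alpha=a)` computes `b * input + a * (mat1 @ mat2)`; the shapes below mirror the test ((4, 8) x (8, 16) with a broadcast (16,) bias), but the snippet itself is only an illustrative check:

import torch

bias = torch.rand(16)        # broadcastable, like the (16,) input_shape case
m1 = torch.rand(4, 8)
weight = torch.rand(8, 16)

out = torch.addmm(bias, m1, weight, beta=3, alpha=2)
print(out.shape)                                            # torch.Size([4, 16])
print(torch.allclose(out, 3 * bias + 2 * (m1 @ weight)))    # True
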
@@ -73,12 +107,7 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
     # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
     # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
     # return add
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(input_shape).to('meta'),
-                             'm1': torch.rand(4, 8).to('meta'),
-                             'm2': torch.rand(8, 16).to('meta'),
-                         })
+    graph = tracer.trace(model, meta_args=meta_args_for_tracer)
     gm = ColoGraphModule(model, graph)
     # [input_1, m1, m2, addmm, output]
     node_list = list(graph.nodes)
@@ -155,11 +184,13 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @parameterize('input_shape', [(16,), (4, 16)])
+@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
 @rerun_if_address_is_in_use()
-def test_addmm_handler(input_shape):
+def test_addmm_handler(input_shape, model_cls):
     world_size = 4
-    run_func_function = partial(check_linear_function_handler,
+    run_func_function = partial(check_addmm_function_handler,
                                 input_shape=input_shape,
+                                model_cls=model_cls,
                                 world_size=world_size,
                                 port=free_port())
     mp.spawn(run_func_function, nprocs=world_size)
@@ -39,6 +39,7 @@ def test_getattr_handler():
                                      strategies_vector=getattr_strategies_vector)
 
     getattr_handler.register_strategy(compute_resharding_cost=False)
 
     # check operation data mapping
     mapping = getattr_handler.get_operation_data_mapping()
@@ -51,7 +52,15 @@ def test_getattr_handler():
     assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3))
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in getattr_handler.strategies_vector]
-    assert "Replica Attribute" in strategy_name_list
+    assert 'get_attr [S0, S1, R, R]' in strategy_name_list
+    assert 'get_attr [S1, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S01, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S01, R, R]' in strategy_name_list
+    assert 'get_attr [S0, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S1, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S1, R, R]' in strategy_name_list
+    assert 'get_attr [R, R, R, R]' in strategy_name_list
 
 
 if __name__ == '__main__':
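
The assertions now expect the enumerated strategy names instead of the single 'Replica Attribute' entry. Only the first two dimensions of the (16, 4, 3, 3) weight ever appear sharded, presumably because the trailing size-3 dims are not divisible by the mesh extents on the usual (2, 2) device mesh, so those candidates are dropped via `ShardingSpecException`. A back-of-the-envelope divisibility check (the mesh shape and the filtering rule are assumptions here):

weight_shape = (16, 4, 3, 3)
mesh_shape = (2, 2)

for dim, size in enumerate(weight_shape):
    options = {f'S{i}': size % extent == 0 for i, extent in enumerate(mesh_shape)}
    options['S01'] = size % (mesh_shape[0] * mesh_shape[1]) == 0
    print(dim, size, options)
# dims 0 and 1 can host S0/S1/S01, dims 2 and 3 cannot -> exactly the 9 names asserted above
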
@@ -149,10 +149,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                 param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
             else:
                 if 'weight' in name:
-                    param_sharding_spec = list(graph.nodes)[4].sharding_spec
-                elif 'bias' in name:
-                    param_sharding_spec = list(graph.nodes)[5].sharding_spec
+                    param_sharding_spec = None
+
+                    for node in list(graph.nodes):
+                        if 'weight' in node.name:
+                            param_sharding_spec = node.sharding_spec
+
+                elif 'bias' in name:
+                    param_sharding_spec = None
+
+                    for node in list(graph.nodes):
+                        if 'bias' in node.name:
+                            param_sharding_spec = node.sharding_spec
 
+            assert param_sharding_spec is not None
             grad_sharded = param_to_shard_dict[name].grad
             grad_to_compare = param_to_compare_dict[name].grad
             global_grad = to_global(grad_sharded, param_sharding_spec)
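
Instead of hard-coding `list(graph.nodes)[4]` and `[5]`, the helper now scans the traced graph for a node whose name contains 'weight' or 'bias', which keeps working when the node order shifts (as it does once `AddmmModel_with_param` introduces extra `get_attr` nodes). An illustrative trace showing that torch.fx names `get_attr` nodes after the attribute they read; the toy module and the exact printed order are assumptions:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyAddmm(nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(8, 16))
        self.bias = nn.Parameter(torch.rand(16))

    def forward(self, m1):
        return torch.addmm(self.bias, m1, self.weight)


graph = symbolic_trace(TinyAddmm()).graph
print([(node.op, node.name) for node in graph.nodes])
# e.g. [('placeholder', 'm1'), ('get_attr', 'bias'), ('get_attr', 'weight'),
#       ('call_function', 'addmm'), ('output', 'output')]
weight_node = [node for node in graph.nodes if 'weight' in node.name][0]
print(weight_node.op, weight_node.target)    # get_attr weight
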