[autoparallel] shard param and buffer as expected (#1753)

* [autoparallel] shard param and buffer as expected

* fix unit test issue
YuliangLiu0306 2022-10-21 15:45:13 +08:00 committed by GitHub
parent cdb7d5e7d2
commit 980ed21723
6 changed files with 129 additions and 106 deletions

View File

@@ -1,4 +1,5 @@
import builtins
import copy
import operator
from ast import NodeTransformer
from copy import deepcopy
@@ -11,34 +12,13 @@ from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.comm_spec import CommSpec, _all_reduce
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
shape_consistency_manager = ShapeConsistencyManager()
class ConsistencyApply(torch.autograd.Function):
@staticmethod
def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
ctx.origin_sharding_spec = origin_sharding_spec
ctx.target_sharding_spec = target_sharding_spec
return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
ctx.target_sharding_spec)
@staticmethod
def backward(ctx, node_grad):
return shape_consistency_manager.apply_for_autoparallel_runtime(
node_grad, ctx.target_sharding_spec, ctx.origin_sharding_spec), None, None, None, None
def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
@@ -53,7 +33,7 @@ def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
else:
origin_sharding_spec = comm_action.comm_spec['src_spec']
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
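The removed ConsistencyApply wrapper followed the standard torch.autograd.Function pattern: convert source-to-target layout in forward, replay the conversion target-to-source in backward. A minimal self-contained sketch of that pattern, with a hypothetical identity convert() standing in for shape_consistency_manager.apply_for_autoparallel_runtime:

import torch


def convert(tensor, src_spec, tgt_spec):
    # Stand-in for the runtime layout conversion; identity here.
    return tensor


class ConsistencySketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, src_spec, tgt_spec):
        ctx.src_spec, ctx.tgt_spec = src_spec, tgt_spec
        return convert(tensor, src_spec, tgt_spec)

    @staticmethod
    def backward(ctx, grad):
        # Gradients travel the reverse direction: tgt -> src.
        # The two None entries match the non-tensor forward inputs.
        return convert(grad, ctx.tgt_spec, ctx.src_spec), None, None


x = torch.randn(2, requires_grad=True)
out = ConsistencySketch.apply(x, 'src_spec', 'tgt_spec')
out.sum().backward()

Since each CollectiveCommPattern already pairs a forward op with its backward counterpart (e.g. SPLIT_FWD_GATHER_BWD, ALLREDUCE_FWD_IDENTITY_BWD), the wrapper is presumably redundant, which is why the pass now calls the manager directly.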
@@ -75,12 +55,17 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
for name, param in target_module.named_parameters():
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
shape_consistency_manager.apply(param, target_sharding_spec)
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
else:
param_sharded = param
setattr(target_module, name, param_sharded)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
@@ -88,13 +73,18 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
def hook_fn(grad):
_all_reduce(grad, comm_spec_to_use)
param.register_hook(hook_fn)
param_sharded.register_hook(hook_fn)
sharded_buffer_dict = {}
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
shape_consistency_manager.apply(buffer, target_sharding_spec)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
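The new branch only materializes a sharded copy when the target spec actually partitions the tensor, and it replaces the module attribute with a fresh torch.nn.Parameter so the gradient hook lands on the shard rather than on the discarded full parameter. A minimal sketch of the shard-then-replace pattern, assuming a flat group of world_size ranks and using narrow() as a stand-in for the shape-consistency conversion:

import torch
import torch.nn as nn


def shard_module_params(module: nn.Module, rank: int, world_size: int, dim: int = 0):
    # Snapshot the iterator before mutating the module's attributes.
    for name, param in list(module.named_parameters(recurse=False)):
        chunk = param.shape[dim] // world_size
        shard = param.data.narrow(dim, rank * chunk, chunk)
        # Wrap the shard in a fresh Parameter, detached from the full tensor.
        sharded = nn.Parameter(shard.detach().clone())
        # Any gradient hook must be registered on the new Parameter,
        # not on the old full-size one.
        setattr(module, name, sharded)

Buffers get the same treatment above, except the sharded copies are first collected into sharded_buffer_dict and assigned back after the loop, so named_buffers() is not mutated while it is being iterated.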
@@ -157,19 +147,11 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
for user_node in node.strategies_vector.successor_nodes:
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
if user_node.op != "output":
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
else:
# we need to call an autograd.Function for leaf node
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply_for_leaf_node,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
origin_index_args = user_node.args.index(node)
new_args = list(user_node.args)
@@ -179,21 +161,29 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]
if op_data.type == OperationDataType.ARG:
if comm_action.comm_type == CommType.BEFORE:
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
if op_data.type == OperationDataType.PARAM:
continue
if comm_action.comm_type == CommType.BEFORE:
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(node, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
# TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
return gm
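Both passes rewrite the FX graph the same way: create a call_function node next to an existing node, then rewire arguments so downstream code consumes the inserted node. A minimal runnable sketch of the CommType.BEFORE case, with a hypothetical annotate() standing in for runtime_apply / runtime_comm_spec_apply:

import torch
from torch.fx import symbolic_trace


def annotate(x):
    # Stand-in for runtime_apply / runtime_comm_spec_apply.
    return x


class Tiny(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


gm = symbolic_trace(Tiny())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == 'call_module':
        # Insert a wrapper node before the target node ...
        with graph.inserting_before(node):
            wrapped = graph.create_node('call_function', annotate, args=(node.args[0],))
        # ... and rewire the target's first argument to consume it.
        args = list(node.args)
        args[0] = wrapped
        node.args = tuple(args)
gm.recompile()

The CommType.AFTER branch above is the mirror image: insert with inserting_after(node) and repoint every user of node (except the inserted node itself) at the new node.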

View File

@@ -345,9 +345,9 @@ class CommSpec:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
if self.comm_pattern in pattern_to_func_dict:
tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
tensor.data = tensor
tensor = tensor
return tensor
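This is the pivot of the commit: covert_spec_to_action now returns the communication result instead of writing it through tensor.data, which is why every caller in the diff rebinds, e.g. tensor = comm_spec.covert_spec_to_action(tensor). One likely motivation, sketched below with torch.cat standing in for a gather: an in-place .data write silently bypasses autograd, while returning keeps the op in the graph.

import torch


def gather_inplace(tensor):
    # Mutates the original tensor's storage; invisible to autograd.
    tensor.data = torch.cat([tensor, tensor], dim=0)


def gather_returning(tensor):
    # Tracked by autograd; the caller rebinds to the result.
    return torch.cat([tensor, tensor], dim=0)


x = torch.ones(2, requires_grad=True)
y = gather_returning(x)
y.sum().backward()
print(x.grad)  # tensor([2., 2.]) -- each element used twice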

View File

@@ -511,13 +511,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
'''
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec.sharding_spec = target_spec
return tensor_with_sharding_spec
def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
_, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor)
tensor = comm_spec.covert_spec_to_action(tensor)
tensor.sharding_spec = target_spec
return tensor

View File

@@ -1,28 +1,32 @@
import copy
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai import device
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from torch.fx import GraphModule
from torchvision.models import resnet34, resnet50
from colossalai import device
from colossalai.auto_parallel.tensor_shard.constants import *
from colossalai.testing import assert_close_loose, assert_close
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
shape_consistency_pass,
solution_annotatation_pass,
)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
seed = 128
cudnn_benchmark = False
@@ -108,16 +112,17 @@ class Bottleneck(nn.Module):
def check_apply_bottleneck(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
input = torch.rand(256, 64, 64, 64).cuda()
input = torch.rand(4, 4, 4, 4).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
entire_shape = torch.Size((4, 4, 8, 8))
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
test_model = copy.deepcopy(model)
test_input = copy.deepcopy(input)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
@@ -130,9 +135,8 @@ def check_apply_bottleneck(rank, world_size, port):
# %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
# %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
# return relu_2
input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
cuda_rng_state = torch.cuda.get_rng_state()
origin_output = model(input)
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
@@ -147,16 +151,42 @@ def check_apply_bottleneck(rank, world_size, port):
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
print(solution)
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[solution[index]].name)
sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()
nodes = [node for node in gm.graph.nodes]
# TODO: wrap the gm to avoid the influence of the user training code
cuda_rng_state = torch.cuda.get_rng_state()
origin_output = test_model(test_input)
torch.cuda.set_rng_state(cuda_rng_state)
output = gm(input, sharding_spec_dict, origin_spec_dict)
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert output.shape == origin_output.shape
assert output.equal(origin_output)
assert_close(output, origin_output)
print("*******************backward starting*******************")
cuda_rng_state = torch.cuda.get_rng_state()
output.sum().backward()
torch.cuda.set_rng_state(cuda_rng_state)
origin_output.sum().backward()
if rank == 0:
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
if rank == 1:
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
if rank == 2:
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
if rank == 3:
print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
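The per-rank narrow() offsets encode the chosen sharding, assuming the [[0, 1], [2, 3]] mesh above: bn3.weight (16 channels for Bottleneck(4, 4, 1)) is split four ways across the flattened mesh, while conv3.weight's output dimension is split two ways along mesh axis 0. A tiny sketch of that mapping:

# Hypothetical helper reproducing the rank -> gradient-slice mapping
# that the four assertions above spell out by hand.
def expected_narrow(rank):
    bn3_slice = (rank * 4, 4)           # narrow(0, rank * 4, 4): 4-way split
    conv3_slice = ((rank // 2) * 8, 8)  # ranks {0, 1} -> rows 0:8, {2, 3} -> 8:16
    return bn3_slice, conv3_slice


for r in range(4):
    print(r, expected_narrow(r))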
@run_on_environment_flag(name='AUTO_PARALLEL')

View File

@@ -1,17 +1,19 @@
import torch
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def check_all_gather(device_mesh, rank):
@@ -37,7 +39,7 @@ def check_all_gather(device_mesh, rank):
sharding_spec,
gather_dim=1,
logical_process_axis=1)
comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -60,7 +62,7 @@ def check_shard(device_mesh, rank):
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
comm_spec.covert_spec_to_action(tensor_to_shard)
tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)
if rank in (0, 2):
assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
@@ -110,7 +112,7 @@ def check_all_to_all(device_mesh, rank):
gather_dim=0,
shard_dim=1,
logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -137,7 +139,7 @@ def check_all_reduce_fwd(device_mesh, rank):
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -155,7 +157,7 @@ def check_all_reduce_bwd(device_mesh, rank):
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -178,7 +180,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
comm_spec.covert_spec_to_action(tensor_to_comm)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)

View File

@@ -1,15 +1,16 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def check_apply(rank, world_size, port):
@@ -63,7 +64,7 @@ def check_apply(rank, world_size, port):
tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
tensor_to_comm.sharding_spec = sharding_spec_source
shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
assert tensor_to_comm.equal(tensor_to_check)
assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)