[autoparallel] add sequential order to communication actions (#1735)

pull/1753/head
YuliangLiu0306 2022-10-20 18:48:18 +08:00 committed by GitHub
parent b893342f95
commit a4ce180e85
7 changed files with 293 additions and 90 deletions

View File

@ -4,9 +4,18 @@ import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -122,26 +131,28 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"input": input_comm_spec}
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -167,18 +178,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -206,26 +219,30 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -256,16 +273,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -291,12 +312,14 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -324,12 +347,13 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping = {"input": input_comm_spec}
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -375,18 +399,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -411,12 +437,14 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -443,12 +471,14 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_spec}
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
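Taken together, the hunks above swap every get_communication_spec call for get_communication_action, so each conv strategy now records not only how to communicate but when: activations use CommType.BEFORE or CommType.AFTER (with arg_index pointing at the tensor's position in node.args), while parameters use CommType.HOOK for the gradient all-reduce. A minimal standalone sketch of the resulting mapping, using simplified stand-ins rather than the real colossalai classes (the comm_spec strings are placeholders):

from dataclasses import dataclass
from enum import Enum
from typing import Any

class CommType(Enum):    # simplified stand-in for the enum added in sharding_strategy.py
    BEFORE = 0           # run the collective on an argument just before the op
    AFTER = 1            # run the collective on the output just after the op
    HOOK = 2             # register a gradient hook on a parameter
    IMPLICIT = 3         # the kernel itself communicates, no graph edit needed

@dataclass
class CommAction:        # simplified stand-in for the dataclass added in sharding_strategy.py
    comm_spec: Any = None
    comm_type: CommType = None
    arg_index: int = -1

# one conv strategy could end up with a mapping shaped like this
communication_action_mapping = {
    "input": CommAction("IDENTITY_FWD_ALLREDUCE_BWD @ mesh_dim_1", CommType.BEFORE, arg_index=0),
    "output": CommAction("ALLREDUCE_FWD_IDENTITY_BWD @ mesh_dim_1", CommType.AFTER, arg_index=0),
    "other": CommAction("IDENTITY_FWD_ALLREDUCE_BWD @ mesh_dim_0", CommType.HOOK),
}

for operand, action in communication_action_mapping.items():
    print(operand, action.comm_type.name, action.arg_index)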

View File

@ -1,8 +1,15 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import FollowingStrategyGenerator
@ -81,12 +88,23 @@ class ReshapeGenerator(FollowingStrategyGenerator):
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["input"] = input_comm_spec
else:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
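The hunk above keeps the same idea for reshape, with one wrinkle: when only one mesh dimension shards the input, a single GATHER_FWD_SPLIT_BWD comm spec (with gather_dim pinned) is enough, but when several mesh dimensions are involved the generator records only the source spec and a fully replicated target spec, leaving the concrete transform sequence to the ShapeConsistencyManager. A small self-contained sketch of that branch, with mock objects standing in for ShardingSpec and CommAction:

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class MockSpec:                 # stand-in for ShardingSpec: only the partition dict matters here
    dim_partition_dict: Dict[int, List[int]] = field(default_factory=dict)

@dataclass
class CommAction:               # simplified stand-in
    comm_spec: Any = None
    comm_type: str = None
    arg_index: int = -1

def reshape_input_action(source_spec: MockSpec, total_mesh_dim_list: List[int]) -> CommAction:
    if len(total_mesh_dim_list) == 1:
        # a single all-gather suffices; gather_dim follows the sole sharded mesh dim
        comm_spec = {"pattern": "GATHER_FWD_SPLIT_BWD", "gather_dim": total_mesh_dim_list[0]}
    else:
        # record src/tgt specs only; the ShapeConsistencyManager expands them later
        comm_spec = {"src_spec": source_spec, "tgt_spec": MockSpec()}
    return CommAction(comm_spec=comm_spec, comm_type="BEFORE", arg_index=0)

print(reshape_input_action(MockSpec({0: [0]}), [0]))
print(reshape_input_action(MockSpec({0: [0], 1: [1]}), [0, 1]))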

View File

@ -4,17 +4,27 @@ from functools import reduce
from typing import Any, Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx import Node
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""
@ -97,6 +107,21 @@ class StrategyGenerator(ABC):
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def get_communication_action(self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Compute the communication cost involved in the forward and backward iteration.
@ -117,8 +142,21 @@ class StrategyGenerator(ABC):
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, comm_spec in strategy.communication_actions.items():
_compute_and_add(operand, comm_spec)
for operand, comm_action in strategy.communication_actions.items():
if isinstance(comm_action, CommAction):
comm_spec = comm_action.comm_spec
else:
# this branch will be removed once all the handlers are updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
src_spec = comm_spec['src_spec']
tgt_spec = comm_spec['tgt_spec']
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
_compute_and_add(operand, comm_spec_)
else:
_compute_and_add(operand, comm_spec)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
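The cost accounting mirrors that split: an entry may be a CommAction (unwrap its comm_spec), a legacy raw comm spec (cost it directly), or a {src_spec, tgt_spec} dict (expand it into a comm-spec sequence via the ShapeConsistencyManager and cost each step). A compact, self-contained restatement of the branching, with the helpers passed in as stand-ins for _compute_and_add and shape_consistency:

def accumulate_comm_cost(communication_actions, compute_and_add, shape_consistency):
    """Mirror of the loop above; the two callables stand in for the real helpers."""
    for operand, comm_action in communication_actions.items():
        # CommAction entries carry the spec in .comm_spec; legacy entries are the spec itself
        comm_spec = getattr(comm_action, "comm_spec", comm_action)
        if isinstance(comm_spec, dict):
            _, comm_action_sequence, _ = shape_consistency(comm_spec["src_spec"], comm_spec["tgt_spec"])
            for single_spec in comm_action_sequence:
                compute_and_add(operand, single_spec)
        else:
            compute_and_add(operand, comm_spec)

# toy demo with stand-in helpers
costs = []
accumulate_comm_cost(
    {"input": {"src_spec": "S0", "tgt_spec": "RR"}},
    compute_and_add=lambda operand, spec: costs.append((operand, spec)),
    shape_consistency=lambda src, tgt: (None, [f"{src}->{tgt}"], None),
)
print(costs)    # [('input', 'S0->RR')]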
@ -141,7 +179,7 @@ class StrategyGenerator(ABC):
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
Compute the size of a tensor in bytes.
Args:
strategy (ShardingStrategy): the ShardingStrategy generated.
key (str): the name of the operation data defined by the generator.
@ -182,7 +220,7 @@ class StrategyGenerator(ABC):
@abstractmethod
def validate(self) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
@ -190,7 +228,7 @@ class StrategyGenerator(ABC):
class FollowingStrategyGenerator(StrategyGenerator):
"""
FollowingStrategyGenerator is used to generate the sharding strategies which depend on the predecessor node.
TODO: remove the original strategy_generator.py after refactoring
"""

View File

@ -4,11 +4,12 @@ from enum import Enum
from typing import Any, Dict, List, Tuple, Union
import torch
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@ -84,6 +85,38 @@ class MemoryCost:
buffer: int = 0
class CommType(Enum):
"""
CommType describes when a communication action happens relative to its corresponding computation action.
Meaning:
BEFORE: the communication action happens just before the computation operation.
AFTER: the communication action happens after the computation operation.
HOOK: the communication action is used to perform the gradient all-reduce.
IMPLICIT: the communication action happens during the kernel execution, as in SyncBatchNorm.
"""
BEFORE = 0
AFTER = 1
HOOK = 2
IMPLICIT = 3
@dataclass
class CommAction:
"""
CommAction is used to record the communication action.
Args:
comm_spec: expresses the communication pattern and the process groups used to execute the communication action.
comm_type: describes the sequential order of the communication action relative to the computation action.
arg_index: records the position of the tensor that joins the communication; the node or op_data name cannot be used at runtime
because a node's args may be changed by graph transform passes.
"""
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
@dataclass
class ShardingStrategy:
"""
@ -102,7 +135,7 @@ class ShardingStrategy:
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
communication_actions: Dict[OperationData, CommSpec] = None
communication_actions: Dict[OperationData, CommAction] = None
resharding_costs: Dict[Node, List[TrainCycleItem]] = None
@property
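With communication_actions now typed as Dict[OperationData, CommAction], a strategy carries both the collective to run and its placement in the execution order. A brief usage sketch of the new record (assuming this commit's colossalai is importable; the comm specs are left as None placeholders):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType

# gradient all-reduce on a parameter: no positional argument involved, so arg_index stays -1
weight_action = CommAction(comm_spec=None, comm_type=CommType.HOOK)

# collective on the first positional argument, to run before the op executes
input_action = CommAction(comm_spec=None, comm_type=CommType.BEFORE, arg_index=0)

assert weight_action.arg_index == -1 and input_action.comm_type is CommType.BEFORE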

View File

@ -8,8 +8,10 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.comm_spec import CommSpec, _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
@ -19,9 +21,9 @@ shape_consistency_manager = ShapeConsistencyManager()
class ConsistencyApply(torch.autograd.Function):
@staticmethod
def forward(ctx, node, origin_dict, input_dict, node_index, user_node_index):
ctx.origin_sharding_spec = origin_dict[node_index]
ctx.target_sharding_spec = input_dict[node_index][user_node_index]
def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
ctx.origin_sharding_spec = origin_sharding_spec
ctx.target_sharding_spec = target_sharding_spec
return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
ctx.target_sharding_spec)
@ -32,7 +34,9 @@ class ConsistencyApply(torch.autograd.Function):
def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
return ConsistencyApply.apply(node, origin_dict, input_dict, node_index, user_node_index)
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
@ -41,6 +45,18 @@ def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
comm_action = comm_actions_dict[node_index][op_data]
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
origin_sharding_spec = comm_action.comm_spec['src_spec']
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
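runtime_comm_spec_apply is the runtime counterpart: it looks up the recorded action by node index and operand name, then either executes the single CommSpec or, for a {src_spec, tgt_spec} pair, routes the tensor through the differentiable ConsistencyApply transform. A tiny mock of that dispatch, with the two executors passed in as stand-ins:

from types import SimpleNamespace

def apply_comm_action(tensor, comm_actions_dict, node_index, op_data_name,
                      run_comm_spec, run_shape_consistency):
    """Mock of runtime_comm_spec_apply: dispatch on the shape of comm_spec."""
    comm_action = comm_actions_dict[node_index][op_data_name]
    spec = comm_action.comm_spec
    if isinstance(spec, dict):
        return run_shape_consistency(tensor, spec["src_spec"], spec["tgt_spec"])
    return run_comm_spec(tensor, spec)

# toy demo with stand-in executors and a hypothetical node index
actions = {3: {"input": SimpleNamespace(comm_spec={"src_spec": "S0", "tgt_spec": "RR"})}}
print(apply_comm_action("x", actions, 3, "input",
                        run_comm_spec=lambda t, s: ("collective", t, s),
                        run_shape_consistency=lambda t, src, tgt: ("consistency", t, src, tgt)))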
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
@ -63,6 +79,16 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
setattr(param, 'sharding_spec', origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
shape_consistency_manager.apply(param, target_sharding_spec)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def hook_fn(grad):
_all_reduce(grad, comm_spec_to_use)
param.register_hook(hook_fn)
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
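For CommType.HOOK actions on parameters, the pass above registers a tensor hook so the gradient collective runs automatically during backward; in the real code the hook body is _all_reduce(grad, comm_spec). A self-contained illustration of the mechanism, with a fake reduction standing in for the collective (here the hook returns a replacement gradient, while the real hook mutates it in place):

import torch

def fake_all_reduce(grad):
    # stand-in for _all_reduce(grad, comm_spec); pretend four ranks contributed identical grads
    return grad * 4

w = torch.nn.Parameter(torch.ones(2, 2))
w.register_hook(fake_all_reduce)

loss = (w * 3).sum()
loss.backward()
print(w.grad)    # every entry is 12.0: the hook saw the raw grad of 3.0 and scaled it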
@ -79,15 +105,24 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
# the dict to record comm actions of nodes
comm_actions_dict = {}
for index, node in enumerate(nodes):
comm_action_dict = {}
for op_data, comm_action in node.best_strategy.communication_actions.items():
comm_action_dict[op_data.name] = comm_action
comm_actions_dict[index] = comm_action_dict
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def shape_consistency_pass(gm: torch.fx.GraphModule):
@ -106,6 +141,9 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
if node.target == 'origin_node_sharding_spec_dict':
origin_dict_node = node
continue
if node.target == 'comm_actions_dict':
comm_actions_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
continue
node_to_index_dict[node] = index
@ -138,4 +176,24 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]
if op_data.type == OperationDataType.ARG:
if comm_action.comm_type == CommType.BEFORE:
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
# TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
return gm
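shape_consistency_pass therefore turns CommType.BEFORE and CommType.AFTER actions into extra call_function nodes around the op and rewires node.args so the op consumes the communicated tensor. The same torch.fx mechanics in a minimal, self-contained form, with a toy scale function standing in for runtime_comm_spec_apply:

import torch
from torch.fx import symbolic_trace

def scale(t):                      # stand-in for runtime_comm_spec_apply
    return t * 2

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = symbolic_trace(Toy())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == 'call_function' and node.target is torch.relu:
        # BEFORE-style insertion: transform the argument, then point node.args at the new node
        with graph.inserting_before(node):
            pre = graph.create_node('call_function', scale, args=(node.args[0],))
        new_args = list(node.args)
        new_args[0] = pre
        node.args = tuple(new_args)
gm.recompile()
print(gm(torch.ones(2)))           # tensor([2., 2.]) == relu(2 * ones)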

View File

@ -1,8 +1,9 @@
import torch
from enum import Enum
import torch.distributed as dist
from functools import reduce
import operator
from enum import Enum
from functools import reduce
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
__all__ = [
@ -238,7 +239,7 @@ class CommSpec:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
It contains comm_pattern to determine the
communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis
Argument:
@ -296,7 +297,7 @@ class CommSpec:
'''
For all_gather, all2all, and all_reduce operations, the formula provided in DeviceMesh with the alpha-beta model is used to
compute the communication cost.
For the shard operation, which happens on-chip, the communication cost is zero.
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
cost_dict = {}
@ -347,6 +348,7 @@ class CommSpec:
tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
tensor.data = tensor
return tensor
pattern_to_func_dict = {

View File

@ -1,3 +1,4 @@
import copy
from functools import partial
import pytest
@ -6,15 +7,22 @@ import torch.multiprocessing as mp
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
StrategiesConstructor)
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
solution_annotatation_pass)
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
shape_consistency_pass,
solution_annotatation_pass,
)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
@ -27,6 +35,7 @@ class ConvModel(nn.Module):
def forward(self, x):
x = self.conv(x)
x = torch.flatten(x)
return x
@ -38,12 +47,13 @@ def check_apply(rank, world_size, port):
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
entire_shape = torch.Size((4, 4, 8, 8))
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
model = ConvModel(4, 4).cuda()
origin_output = model(input)
test_model = copy.deepcopy(model)
test_input = copy.deepcopy(input)
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
@ -62,16 +72,30 @@ def check_apply(rank, world_size, port):
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()
nodes = [node for node in gm.graph.nodes]
# TODO: wrap the gm to avoid the influence of the user training code
output = gm(input, sharding_spec_dict, origin_spec_dict)
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
origin_output = test_model(test_input)
assert output.equal(origin_output)
origin_loss = origin_output.sum()
loss = output.sum()
origin_loss.backward()
loss.backward()
grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
if rank in (0, 1):
assert_close(gm.conv.weight.grad.data, grad_0.data)
elif rank in (2, 3):
assert_close(gm.conv.weight.grad.data, grad_1.data)
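The new assertions check that, after the HOOK all-reduce, each rank holds the expected shard of the reference weight gradient: the reference grad is split in half along its first (output channel) dimension, so ranks 0 and 1 compare against rows 0-1 and ranks 2 and 3 against rows 2-3. A small illustration of the narrow slicing used in the check above (shapes here are illustrative):

import torch

full_grad = torch.arange(16.).reshape(4, 4)   # stand-in for the unsharded conv.weight.grad
grad_0 = full_grad.narrow(0, 0, 2)            # rows 0-1: the shard expected on ranks 0 and 1
grad_1 = full_grad.narrow(0, 2, 2)            # rows 2-3: the shard expected on ranks 2 and 3
assert torch.equal(torch.cat([grad_0, grad_1], dim=0), full_grad)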
# skip this test because pulp is not installed in the CI environment
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()