mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactor the runtime apply pass and add docstring to passes (#1757)
* [autoparallel] refactor the runtime apply pass and add doc string to passes * fix unit test * polishpull/1759/head
parent
f9a613d660
commit
314d8c497f
|
@ -0,0 +1,151 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
|
CommAction,
|
||||||
|
CommType,
|
||||||
|
OperationData,
|
||||||
|
OperationDataType,
|
||||||
|
)
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.tensor.comm_spec import CommSpec
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
|
||||||
|
"""
|
||||||
|
This method will be invoked during runtime to do the shape consistency, which make sure the activations is converted into
|
||||||
|
the user node expected form.
|
||||||
|
"""
|
||||||
|
origin_sharding_spec = origin_dict[node_index]
|
||||||
|
target_sharding_spec = input_dict[node_index][user_node_index]
|
||||||
|
|
||||||
|
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
|
||||||
|
"""
|
||||||
|
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
|
||||||
|
"""
|
||||||
|
comm_action = comm_actions_dict[node_index][op_data_name]
|
||||||
|
if isinstance(comm_action.comm_spec, CommSpec):
|
||||||
|
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
|
||||||
|
else:
|
||||||
|
origin_sharding_spec = comm_action.comm_spec['src_spec']
|
||||||
|
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
|
||||||
|
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
|
||||||
|
return rst
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_graph(nodes: List[Node]):
|
||||||
|
"""
|
||||||
|
This method is used to extract all the placeholders with sharding information,
|
||||||
|
and mapping the nodes into the index of the origin graph.
|
||||||
|
"""
|
||||||
|
# mapping the node into the origin graph index
|
||||||
|
node_to_index_dict = {}
|
||||||
|
index = 0
|
||||||
|
for node in nodes:
|
||||||
|
if node.target == 'sharding_spec_convert_dict':
|
||||||
|
input_dict_node = node
|
||||||
|
continue
|
||||||
|
if node.target == 'origin_node_sharding_spec_dict':
|
||||||
|
origin_dict_node = node
|
||||||
|
continue
|
||||||
|
if node.target == 'comm_actions_dict':
|
||||||
|
comm_actions_dict_node = node
|
||||||
|
continue
|
||||||
|
if not hasattr(node, 'best_strategy'):
|
||||||
|
continue
|
||||||
|
node_to_index_dict[node] = index
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||||
|
"""
|
||||||
|
This pass is used to add the shape consistency node to the origin graph.
|
||||||
|
"""
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||||
|
continue
|
||||||
|
|
||||||
|
for user_node in node.strategies_vector.successor_nodes:
|
||||||
|
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
||||||
|
with mod_graph.inserting_before(user_node):
|
||||||
|
shape_consistency_node = mod_graph.create_node('call_function',
|
||||||
|
runtime_apply,
|
||||||
|
args=(node, origin_dict_node, input_dict_node,
|
||||||
|
node_to_index_dict[node], user_node_index))
|
||||||
|
|
||||||
|
origin_index_args = user_node.args.index(node)
|
||||||
|
new_args = list(user_node.args)
|
||||||
|
new_args[origin_index_args] = shape_consistency_node
|
||||||
|
user_node.args = new_args
|
||||||
|
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||||
|
"""
|
||||||
|
This pass is used to add the comm spec apply node to the origin graph.
|
||||||
|
"""
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||||
|
continue
|
||||||
|
|
||||||
|
comm_actions = node.best_strategy.communication_actions
|
||||||
|
for op_data, comm_action in comm_actions.items():
|
||||||
|
comm_object = node.args[comm_action.arg_index]
|
||||||
|
if op_data.type == OperationDataType.PARAM:
|
||||||
|
continue
|
||||||
|
if comm_action.comm_type == CommType.BEFORE:
|
||||||
|
with mod_graph.inserting_before(node):
|
||||||
|
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||||
|
runtime_comm_spec_apply,
|
||||||
|
args=(comm_object, comm_actions_dict_node,
|
||||||
|
node_to_index_dict[node], op_data.name))
|
||||||
|
new_args = list(node.args)
|
||||||
|
new_args[comm_action.arg_index] = comm_spec_apply_node
|
||||||
|
node.args = new_args
|
||||||
|
elif comm_action.comm_type == CommType.AFTER:
|
||||||
|
with mod_graph.inserting_after(node):
|
||||||
|
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||||
|
runtime_comm_spec_apply,
|
||||||
|
args=(node, comm_actions_dict_node,
|
||||||
|
node_to_index_dict[node], op_data.name))
|
||||||
|
user_list = list(node.users.keys())
|
||||||
|
for user in user_list:
|
||||||
|
if user == comm_spec_apply_node:
|
||||||
|
continue
|
||||||
|
new_args = list(user.args)
|
||||||
|
new_args[new_args.index(node)] = comm_spec_apply_node
|
||||||
|
user.args = tuple(new_args)
|
||||||
|
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_apply_pass(gm: torch.fx.GraphModule):
|
||||||
|
"""
|
||||||
|
The method manages all the passes acting on the distributed training runtime.
|
||||||
|
"""
|
||||||
|
gm = _shape_consistency_apply(gm)
|
||||||
|
gm = _comm_spec_apply(gm)
|
||||||
|
|
||||||
|
return gm
|
|
@ -0,0 +1,130 @@
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.fx import symbolic_trace
|
||||||
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.tensor.comm_spec import _all_reduce
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
|
||||||
|
def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
|
||||||
|
"""
|
||||||
|
This method is used to stick the solution strategy to the nodes and add the information
|
||||||
|
required in runtime into graph as placeholder nodes.
|
||||||
|
"""
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
# the dict to get origin sharding spec of node
|
||||||
|
origin_node_sharding_spec_dict = {}
|
||||||
|
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
|
||||||
|
strategies_vector = node.strategies_vector
|
||||||
|
# stick the solution strategy to the corresponding node
|
||||||
|
setattr(node, 'best_strategy', strategies_vector[strategy_index])
|
||||||
|
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
|
||||||
|
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||||
|
str(node))
|
||||||
|
|
||||||
|
# the dict to get input sharding specs of user node
|
||||||
|
sharding_spec_convert_dict = {}
|
||||||
|
# the dict to record comm actions of nodes
|
||||||
|
comm_actions_dict = {}
|
||||||
|
for index, node in enumerate(nodes):
|
||||||
|
target_sharding_specs = []
|
||||||
|
for user_node in node.strategies_vector.successor_nodes:
|
||||||
|
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
|
||||||
|
target_sharding_specs.append(target_sharding_spec)
|
||||||
|
sharding_spec_convert_dict[index] = target_sharding_specs
|
||||||
|
|
||||||
|
comm_action_dict = {}
|
||||||
|
for op_data, comm_action in node.best_strategy.communication_actions.items():
|
||||||
|
comm_action_dict[op_data.name] = comm_action
|
||||||
|
comm_actions_dict[index] = comm_action_dict
|
||||||
|
|
||||||
|
# add above dicts into graph
|
||||||
|
for node in nodes:
|
||||||
|
if node.op != 'placeholder':
|
||||||
|
with mod_graph.inserting_before(node):
|
||||||
|
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
|
||||||
|
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
|
||||||
|
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
|
||||||
|
break
|
||||||
|
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
|
||||||
|
"""
|
||||||
|
Apply the sharding action to the module parameters and buffers following the
|
||||||
|
instructions of solver solution.
|
||||||
|
"""
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if node.op == 'call_module':
|
||||||
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
|
|
||||||
|
for name, param in target_module.named_parameters():
|
||||||
|
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||||
|
# apply the sharding spec of parameters
|
||||||
|
if target_sharding_spec.dim_partition_dict != {}:
|
||||||
|
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
||||||
|
setattr(param, 'sharding_spec', origin_sharding_spec)
|
||||||
|
param_sharded = torch.nn.Parameter(
|
||||||
|
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
|
||||||
|
target_sharding_spec).detach().clone())
|
||||||
|
else:
|
||||||
|
param_sharded = param
|
||||||
|
setattr(target_module, name, param_sharded)
|
||||||
|
comm_actions = node.best_strategy.communication_actions
|
||||||
|
for operation_data, comm_action in comm_actions.items():
|
||||||
|
comm_spec_to_use = comm_action.comm_spec
|
||||||
|
# register hook to the parameters
|
||||||
|
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
||||||
|
|
||||||
|
def wrapper(param, comm_spec):
|
||||||
|
|
||||||
|
def hook_fn(grad):
|
||||||
|
_all_reduce(grad, comm_spec)
|
||||||
|
|
||||||
|
param.register_hook(hook_fn)
|
||||||
|
|
||||||
|
wrapper(param_sharded, comm_spec_to_use)
|
||||||
|
|
||||||
|
sharded_buffer_dict = {}
|
||||||
|
# apply the sharding spec of buffers
|
||||||
|
for name, buffer in target_module.named_buffers():
|
||||||
|
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
|
||||||
|
setattr(buffer, 'sharding_spec', origin_sharding_spec)
|
||||||
|
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||||
|
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
|
||||||
|
sharded_buffer_dict[name] = buffer_sharded
|
||||||
|
|
||||||
|
for name, buffer_sharded in sharded_buffer_dict.items():
|
||||||
|
setattr(target_module, name, buffer_sharded.detach().clone())
|
||||||
|
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def implicit_comm_action_apply(gm: torch.fx.GraphModule):
|
||||||
|
"""
|
||||||
|
replace the origin kernel into kernel with implicit communication inside.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
|
||||||
|
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
|
||||||
|
gm, solution)
|
||||||
|
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
|
||||||
|
# gm = implicit_comm_action_apply(gm)
|
||||||
|
gm = _module_params_sharding(gm, device_mesh)
|
||||||
|
|
||||||
|
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
|
@ -1,193 +0,0 @@
|
||||||
import builtins
|
|
||||||
import copy
|
|
||||||
import operator
|
|
||||||
from ast import NodeTransformer
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.fx import symbolic_trace
|
|
||||||
from torch.fx.node import Node
|
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
|
||||||
from colossalai.fx.passes.split_module import split_module
|
|
||||||
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict
|
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
|
||||||
|
|
||||||
shape_consistency_manager = ShapeConsistencyManager()
|
|
||||||
|
|
||||||
|
|
||||||
def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
|
|
||||||
origin_sharding_spec = origin_dict[node_index]
|
|
||||||
target_sharding_spec = input_dict[node_index][user_node_index]
|
|
||||||
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
|
|
||||||
|
|
||||||
|
|
||||||
def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
|
|
||||||
|
|
||||||
comm_action = comm_actions_dict[node_index][op_data]
|
|
||||||
if isinstance(comm_action.comm_spec, CommSpec):
|
|
||||||
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
|
|
||||||
else:
|
|
||||||
origin_sharding_spec = comm_action.comm_spec['src_spec']
|
|
||||||
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
|
|
||||||
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
|
|
||||||
return rst
|
|
||||||
|
|
||||||
|
|
||||||
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
|
|
||||||
mod_graph = gm.graph
|
|
||||||
nodes = tuple(mod_graph.nodes)
|
|
||||||
|
|
||||||
# the dict to get origin sharding spec of node
|
|
||||||
origin_node_sharding_spec_dict = {}
|
|
||||||
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
|
|
||||||
strategies_vector = node.strategies_vector
|
|
||||||
setattr(node, 'best_strategy', strategies_vector[strategy_index])
|
|
||||||
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
|
|
||||||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
|
||||||
str(node))
|
|
||||||
|
|
||||||
# apply the sharding spec of parameters
|
|
||||||
for node in nodes:
|
|
||||||
if node.op == 'call_module':
|
|
||||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
|
||||||
for name, param in target_module.named_parameters():
|
|
||||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
|
||||||
if target_sharding_spec.dim_partition_dict != {}:
|
|
||||||
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
|
||||||
setattr(param, 'sharding_spec', origin_sharding_spec)
|
|
||||||
param_sharded = torch.nn.Parameter(
|
|
||||||
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
|
|
||||||
target_sharding_spec).detach().clone())
|
|
||||||
else:
|
|
||||||
param_sharded = param
|
|
||||||
setattr(target_module, name, param_sharded)
|
|
||||||
comm_actions = node.best_strategy.communication_actions
|
|
||||||
for operation_data, comm_action in comm_actions.items():
|
|
||||||
comm_spec_to_use = comm_action.comm_spec
|
|
||||||
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
|
||||||
|
|
||||||
def wrapper(param, comm_spec):
|
|
||||||
|
|
||||||
def hook_fn(grad):
|
|
||||||
_all_reduce(grad, comm_spec)
|
|
||||||
|
|
||||||
param.register_hook(hook_fn)
|
|
||||||
|
|
||||||
wrapper(param_sharded, comm_spec_to_use)
|
|
||||||
|
|
||||||
sharded_buffer_dict = {}
|
|
||||||
for name, buffer in target_module.named_buffers():
|
|
||||||
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
|
|
||||||
setattr(buffer, 'sharding_spec', origin_sharding_spec)
|
|
||||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
|
||||||
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
|
|
||||||
sharded_buffer_dict[name] = buffer_sharded
|
|
||||||
|
|
||||||
for name, buffer_sharded in sharded_buffer_dict.items():
|
|
||||||
setattr(target_module, name, buffer_sharded.detach().clone())
|
|
||||||
|
|
||||||
# the dict to get input sharding specs of user node
|
|
||||||
sharding_spec_convert_dict = {}
|
|
||||||
for index, node in enumerate(nodes):
|
|
||||||
target_sharding_specs = []
|
|
||||||
for user_node in node.strategies_vector.successor_nodes:
|
|
||||||
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
|
|
||||||
target_sharding_specs.append(target_sharding_spec)
|
|
||||||
sharding_spec_convert_dict[index] = target_sharding_specs
|
|
||||||
|
|
||||||
# the dict to record comm actions of nodes
|
|
||||||
comm_actions_dict = {}
|
|
||||||
for index, node in enumerate(nodes):
|
|
||||||
comm_action_dict = {}
|
|
||||||
for op_data, comm_action in node.best_strategy.communication_actions.items():
|
|
||||||
comm_action_dict[op_data.name] = comm_action
|
|
||||||
comm_actions_dict[index] = comm_action_dict
|
|
||||||
|
|
||||||
# add above dicts into graph
|
|
||||||
for node in nodes:
|
|
||||||
if node.op != 'placeholder':
|
|
||||||
with mod_graph.inserting_before(node):
|
|
||||||
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
|
|
||||||
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
|
|
||||||
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
|
|
||||||
break
|
|
||||||
|
|
||||||
return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
|
||||||
|
|
||||||
|
|
||||||
def shape_consistency_pass(gm: torch.fx.GraphModule):
|
|
||||||
mod_graph = gm.graph
|
|
||||||
nodes = tuple(mod_graph.nodes)
|
|
||||||
input_dict_node = None
|
|
||||||
origin_dict_node = None
|
|
||||||
|
|
||||||
# mapping the node into the origin graph index
|
|
||||||
node_to_index_dict = {}
|
|
||||||
index = 0
|
|
||||||
for node in nodes:
|
|
||||||
if node.target == 'sharding_spec_convert_dict':
|
|
||||||
input_dict_node = node
|
|
||||||
continue
|
|
||||||
if node.target == 'origin_node_sharding_spec_dict':
|
|
||||||
origin_dict_node = node
|
|
||||||
continue
|
|
||||||
if node.target == 'comm_actions_dict':
|
|
||||||
comm_actions_dict_node = node
|
|
||||||
continue
|
|
||||||
if not hasattr(node, 'best_strategy'):
|
|
||||||
continue
|
|
||||||
node_to_index_dict[node] = index
|
|
||||||
index += 1
|
|
||||||
assert input_dict_node is not None
|
|
||||||
|
|
||||||
# add shape consistency apply function into graph
|
|
||||||
for node in nodes:
|
|
||||||
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
|
||||||
continue
|
|
||||||
|
|
||||||
for user_node in node.strategies_vector.successor_nodes:
|
|
||||||
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
|
||||||
with mod_graph.inserting_before(user_node):
|
|
||||||
shape_consistency_node = mod_graph.create_node('call_function',
|
|
||||||
runtime_apply,
|
|
||||||
args=(node, origin_dict_node, input_dict_node,
|
|
||||||
node_to_index_dict[node], user_node_index))
|
|
||||||
|
|
||||||
origin_index_args = user_node.args.index(node)
|
|
||||||
new_args = list(user_node.args)
|
|
||||||
new_args[origin_index_args] = shape_consistency_node
|
|
||||||
user_node.args = new_args
|
|
||||||
|
|
||||||
comm_actions = node.best_strategy.communication_actions
|
|
||||||
for op_data, comm_action in comm_actions.items():
|
|
||||||
comm_object = node.args[comm_action.arg_index]
|
|
||||||
if op_data.type == OperationDataType.PARAM:
|
|
||||||
continue
|
|
||||||
if comm_action.comm_type == CommType.BEFORE:
|
|
||||||
with mod_graph.inserting_before(node):
|
|
||||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
|
||||||
runtime_comm_spec_apply,
|
|
||||||
args=(comm_object, comm_actions_dict_node,
|
|
||||||
node_to_index_dict[node], op_data.name))
|
|
||||||
new_args = list(node.args)
|
|
||||||
new_args[comm_action.arg_index] = comm_spec_apply_node
|
|
||||||
node.args = new_args
|
|
||||||
elif comm_action.comm_type == CommType.AFTER:
|
|
||||||
with mod_graph.inserting_after(node):
|
|
||||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
|
||||||
runtime_comm_spec_apply,
|
|
||||||
args=(node, comm_actions_dict_node,
|
|
||||||
node_to_index_dict[node], op_data.name))
|
|
||||||
user_list = list(node.users.keys())
|
|
||||||
for user in user_list:
|
|
||||||
if user == comm_spec_apply_node:
|
|
||||||
continue
|
|
||||||
new_args = list(user.args)
|
|
||||||
new_args[new_args.index(node)] = comm_spec_apply_node
|
|
||||||
user.args = tuple(new_args)
|
|
||||||
# TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
|
|
||||||
return gm
|
|
|
@ -10,6 +10,8 @@ from torch.fx import GraphModule
|
||||||
from torchvision.models import resnet34, resnet50
|
from torchvision.models import resnet34, resnet50
|
||||||
|
|
||||||
from colossalai import device
|
from colossalai import device
|
||||||
|
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||||
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||||
from colossalai.auto_parallel.tensor_shard.constants import *
|
from colossalai.auto_parallel.tensor_shard.constants import *
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
||||||
|
@ -17,10 +19,6 @@ from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
||||||
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
|
|
||||||
shape_consistency_pass,
|
|
||||||
solution_annotatation_pass,
|
|
||||||
)
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.initialize import launch
|
from colossalai.initialize import launch
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
@ -153,8 +151,8 @@ def check_apply_bottleneck(rank, world_size, port):
|
||||||
print(solution)
|
print(solution)
|
||||||
for index, node in enumerate(graph.nodes):
|
for index, node in enumerate(graph.nodes):
|
||||||
print(node.name, node.strategies_vector[solution[index]].name)
|
print(node.name, node.strategies_vector[solution[index]].name)
|
||||||
sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
|
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
|
||||||
shape_consistency_pass(gm)
|
gm = runtime_apply_pass(gm)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
nodes = [node for node in gm.graph.nodes]
|
nodes = [node for node in gm.graph.nodes]
|
||||||
# TODO: wrap the gm to avoid the influence of the user training code
|
# TODO: wrap the gm to avoid the influence of the user training code
|
||||||
|
|
|
@ -7,6 +7,8 @@ import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||||
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||||
from colossalai.auto_parallel.tensor_shard.solver import (
|
from colossalai.auto_parallel.tensor_shard.solver import (
|
||||||
CostGraph,
|
CostGraph,
|
||||||
GraphAnalyser,
|
GraphAnalyser,
|
||||||
|
@ -15,10 +17,6 @@ from colossalai.auto_parallel.tensor_shard.solver import (
|
||||||
StrategiesConstructor,
|
StrategiesConstructor,
|
||||||
)
|
)
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
|
|
||||||
shape_consistency_pass,
|
|
||||||
solution_annotatation_pass,
|
|
||||||
)
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.initialize import launch
|
from colossalai.initialize import launch
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
@ -72,8 +70,8 @@ def check_apply(rank, world_size, port):
|
||||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||||
ret = solver.call_solver_serialized_args()
|
ret = solver.call_solver_serialized_args()
|
||||||
solution = list(ret[0])
|
solution = list(ret[0])
|
||||||
sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
|
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
|
||||||
shape_consistency_pass(gm)
|
gm = runtime_apply_pass(gm)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
nodes = [node for node in gm.graph.nodes]
|
nodes = [node for node in gm.graph.nodes]
|
||||||
# TODO: wrap the gm to avoid the influence of the user training code
|
# TODO: wrap the gm to avoid the influence of the user training code
|
||||||
|
|
Loading…
Reference in New Issue