[autoparallel] adapt runtime passes (#1703)

* [autoparallel] adapt runtime passes v2

* polish code
pull/1704/head
YuliangLiu0306 2022-10-14 10:14:07 +08:00 committed by GitHub
parent 21962e1593
commit 451cd72dea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 204 additions and 6 deletions

View File

@ -58,9 +58,6 @@ class CostGraph:
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
if strategies_vector[i].resharding_costs is None:
print(strategies_vector.node.name)
assert False
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
if self.forward_only:
edge_cost[(j, i)] = resharding_cost_item.fwd

View File

@ -90,8 +90,8 @@ class NodeHandler(ABC):
# compute the resharding costs based on the previous node
# strategies if specified
if compute_resharding_cost:
updated_strategies = map(self.update_resharding_cost, strategies)
strategies = list(updated_strategies)
updated_strategies = map(self.update_resharding_cost, post_processed_strategies)
post_processed_strategies = list(updated_strategies)
self.strategies_vector.extend(post_processed_strategies)

View File

@ -52,7 +52,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {

View File

@ -0,0 +1,115 @@
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy
def apply(*args, **kwargs):
shape_consistency_manager = ShapeConsistencyManager()
return shape_consistency_manager.apply(*args, **kwargs)
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
# apply the sharding spec of parameters
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
for name, param in target_module.named_parameters():
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
apply(param, target_weight_sharding_spec)
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
for index, node in enumerate(nodes):
target_sharding_specs = []
if node.name == 'bn1':
print(node.strategies_vector.successor_nodes)
assert False
for user_node in node.strategies_vector.successor_nodes:
# node_index = user_node.strategies_vector.predecessor_nodes.index(node)
# target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
def shape_consistency_pass(gm: torch.fx.GraphModule):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
input_dict_node = None
origin_dict_node = None
# mapping the node into the origin graph index
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
origin_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
continue
node_to_index_dict[node] = index
index += 1
assert input_dict_node is not None
# add shape consistency apply function into graph
for node in nodes:
if not hasattr(node, 'best_strategy'):
continue
with mod_graph.inserting_after(node):
origin_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(origin_dict_node, node_to_index_dict[node]))
with mod_graph.inserting_after(origin_spec_node):
set_sharding_spec_node = mod_graph.create_node('call_function',
builtins.setattr,
args=(node, 'sharding_spec', origin_spec_node))
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
input_specs_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_dict_node, node_to_index_dict[node]))
with mod_graph.inserting_before(user_node):
sharding_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_specs_node, node_index))
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
return gm

View File

@ -0,0 +1,86 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver.solver import Solver_V2
from colossalai.auto_parallel.solver.options import SolverOptions
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = self.conv(x)
return x
def check_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
input = torch.rand(4, 4, 4, 4).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
entire_shape = torch.Size((4, 4, 8, 8))
tracer = ColoTracer()
model = ConvModel(4, 4).cuda()
origin_output = model(input)
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()
nodes = [node for node in gm.graph.nodes]
# TODO: wrap the gm to avoid the influence of the user training code
output = gm(input, sharding_spec_dict, origin_spec_dict)
assert output.equal(origin_output)
@pytest.mark.skip("for higher testing speed")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
world_size = 4
run_func = partial(check_apply, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_apply()