mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt runtime passes (#1703)
* [autoparallel] adapt runtime passes v2 * polish codepull/1704/head
parent
21962e1593
commit
451cd72dea
|
@ -58,9 +58,6 @@ class CostGraph:
|
||||||
edge_cost = {}
|
edge_cost = {}
|
||||||
for i in range(len(strategies_vector)):
|
for i in range(len(strategies_vector)):
|
||||||
for j in range(len(src_node.strategies_vector)):
|
for j in range(len(src_node.strategies_vector)):
|
||||||
if strategies_vector[i].resharding_costs is None:
|
|
||||||
print(strategies_vector.node.name)
|
|
||||||
assert False
|
|
||||||
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
|
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
|
||||||
if self.forward_only:
|
if self.forward_only:
|
||||||
edge_cost[(j, i)] = resharding_cost_item.fwd
|
edge_cost[(j, i)] = resharding_cost_item.fwd
|
||||||
|
|
|
@ -90,8 +90,8 @@ class NodeHandler(ABC):
|
||||||
# compute the resharding costs based on the previous node
|
# compute the resharding costs based on the previous node
|
||||||
# strategies if specified
|
# strategies if specified
|
||||||
if compute_resharding_cost:
|
if compute_resharding_cost:
|
||||||
updated_strategies = map(self.update_resharding_cost, strategies)
|
updated_strategies = map(self.update_resharding_cost, post_processed_strategies)
|
||||||
strategies = list(updated_strategies)
|
post_processed_strategies = list(updated_strategies)
|
||||||
|
|
||||||
self.strategies_vector.extend(post_processed_strategies)
|
self.strategies_vector.extend(post_processed_strategies)
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
|
||||||
total_compute_cost = forward_compute_cost + backward_compute_cost
|
total_compute_cost = forward_compute_cost + backward_compute_cost
|
||||||
|
|
||||||
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
|
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
|
||||||
return compute_cost
|
strategy.compute_cost = compute_cost
|
||||||
|
|
||||||
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
||||||
forward_size_mapping = {
|
forward_size_mapping = {
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
from torch.fx import symbolic_trace
|
||||||
|
from torch.fx.node import Node
|
||||||
|
from colossalai.fx.passes.split_module import split_module
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||||
|
import builtins
|
||||||
|
import operator
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
def apply(*args, **kwargs):
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
return shape_consistency_manager.apply(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
# the dict to get origin sharding spec of node
|
||||||
|
origin_node_sharding_spec_dict = {}
|
||||||
|
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
|
||||||
|
strategies_vector = node.strategies_vector
|
||||||
|
setattr(node, 'best_strategy', strategies_vector[strategy_index])
|
||||||
|
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
|
||||||
|
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||||
|
str(node))
|
||||||
|
|
||||||
|
# apply the sharding spec of parameters
|
||||||
|
for node in nodes:
|
||||||
|
if node.op == 'call_module':
|
||||||
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
|
for name, param in target_module.named_parameters():
|
||||||
|
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
||||||
|
setattr(param, 'sharding_spec', origin_sharding_spec)
|
||||||
|
target_weight_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||||
|
apply(param, target_weight_sharding_spec)
|
||||||
|
|
||||||
|
# the dict to get input sharding specs of user node
|
||||||
|
sharding_spec_convert_dict = {}
|
||||||
|
for index, node in enumerate(nodes):
|
||||||
|
target_sharding_specs = []
|
||||||
|
if node.name == 'bn1':
|
||||||
|
print(node.strategies_vector.successor_nodes)
|
||||||
|
assert False
|
||||||
|
for user_node in node.strategies_vector.successor_nodes:
|
||||||
|
# node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
||||||
|
# target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
|
||||||
|
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
|
||||||
|
target_sharding_specs.append(target_sharding_spec)
|
||||||
|
sharding_spec_convert_dict[index] = target_sharding_specs
|
||||||
|
|
||||||
|
# add above dicts into graph
|
||||||
|
for node in nodes:
|
||||||
|
if node.op != 'placeholder':
|
||||||
|
with mod_graph.inserting_before(node):
|
||||||
|
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
|
||||||
|
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
|
||||||
|
break
|
||||||
|
|
||||||
|
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
|
||||||
|
|
||||||
|
|
||||||
|
def shape_consistency_pass(gm: torch.fx.GraphModule):
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
input_dict_node = None
|
||||||
|
origin_dict_node = None
|
||||||
|
|
||||||
|
# mapping the node into the origin graph index
|
||||||
|
node_to_index_dict = {}
|
||||||
|
index = 0
|
||||||
|
for node in nodes:
|
||||||
|
if node.target == 'sharding_spec_convert_dict':
|
||||||
|
input_dict_node = node
|
||||||
|
continue
|
||||||
|
if node.target == 'origin_node_sharding_spec_dict':
|
||||||
|
origin_dict_node = node
|
||||||
|
continue
|
||||||
|
if not hasattr(node, 'best_strategy'):
|
||||||
|
continue
|
||||||
|
node_to_index_dict[node] = index
|
||||||
|
index += 1
|
||||||
|
assert input_dict_node is not None
|
||||||
|
|
||||||
|
# add shape consistency apply function into graph
|
||||||
|
for node in nodes:
|
||||||
|
if not hasattr(node, 'best_strategy'):
|
||||||
|
continue
|
||||||
|
with mod_graph.inserting_after(node):
|
||||||
|
origin_spec_node = mod_graph.create_node('call_function',
|
||||||
|
operator.getitem,
|
||||||
|
args=(origin_dict_node, node_to_index_dict[node]))
|
||||||
|
with mod_graph.inserting_after(origin_spec_node):
|
||||||
|
set_sharding_spec_node = mod_graph.create_node('call_function',
|
||||||
|
builtins.setattr,
|
||||||
|
args=(node, 'sharding_spec', origin_spec_node))
|
||||||
|
|
||||||
|
for user_node in node.strategies_vector.successor_nodes:
|
||||||
|
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
||||||
|
with mod_graph.inserting_before(user_node):
|
||||||
|
input_specs_node = mod_graph.create_node('call_function',
|
||||||
|
operator.getitem,
|
||||||
|
args=(input_dict_node, node_to_index_dict[node]))
|
||||||
|
with mod_graph.inserting_before(user_node):
|
||||||
|
sharding_spec_node = mod_graph.create_node('call_function',
|
||||||
|
operator.getitem,
|
||||||
|
args=(input_specs_node, node_index))
|
||||||
|
with mod_graph.inserting_before(user_node):
|
||||||
|
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
|
||||||
|
|
||||||
|
return gm
|
|
@ -0,0 +1,86 @@
|
||||||
|
from functools import partial
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
from colossalai.initialize import launch
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||||
|
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||||
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||||
|
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
|
||||||
|
from colossalai.auto_parallel.solver.solver import Solver_V2
|
||||||
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
|
|
||||||
|
|
||||||
|
class ConvModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_in, c_out):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def check_apply(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
input = torch.rand(4, 4, 4, 4).cuda()
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
# [[0, 1]
|
||||||
|
# [2, 3]]
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
|
||||||
|
entire_shape = torch.Size((4, 4, 8, 8))
|
||||||
|
|
||||||
|
tracer = ColoTracer()
|
||||||
|
model = ConvModel(4, 4).cuda()
|
||||||
|
origin_output = model(input)
|
||||||
|
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
|
||||||
|
# graph():
|
||||||
|
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||||
|
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||||
|
# return conv
|
||||||
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
solver_options = SolverOptions(fast=True)
|
||||||
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||||
|
cost_graph.simplify_graph()
|
||||||
|
graph_analyser = GraphAnalyser(gm)
|
||||||
|
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||||
|
ret = solver.call_solver_serialized_args()
|
||||||
|
solution = list(ret[0])
|
||||||
|
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
|
||||||
|
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
|
||||||
|
shape_consistency_pass(gm)
|
||||||
|
gm.recompile()
|
||||||
|
nodes = [node for node in gm.graph.nodes]
|
||||||
|
# TODO: wrap the gm to avoid the influence of the user training code
|
||||||
|
output = gm(input, sharding_spec_dict, origin_spec_dict)
|
||||||
|
assert output.equal(origin_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("for higher testing speed")
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_apply():
|
||||||
|
world_size = 4
|
||||||
|
run_func = partial(check_apply, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_apply()
|
Loading…
Reference in New Issue