mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add reshape handler (#1594)
* [autoparallel] add reshape handler
* polish code

pull/1604/head
parent: c938dda028
commit: faa23b9d9a
@@ -2,16 +2,17 @@ import torch
 import operator
 
 __all__ = [
-    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
-    'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
+    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
 ]
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 # TODO: flatten should not be added into this group
 ELEMENTWISE_FUNC_OP = [
     torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
+    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
 ]
+RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
     torch.nn.ConvTranspose3d
@@ -23,5 +24,6 @@ LINEAR_MODULE_OP = [torch.nn.Linear]
 LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
+NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
 
 INFINITY_COST = 1e13
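The two groups introduced above are plain lists of callables, so routing an FX node to a handler reduces to a membership test on the node's target. A minimal standalone sketch of that idea (the helper pick_handler_kind is illustrative, not part of ColossalAI's API):

    import operator
    import torch

    RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
    ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.nn.functional.relu, torch.nn.functional.dropout]
    NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP

    def pick_handler_kind(target):
        # hypothetical dispatcher: map a call_function target to a handler label
        if target in RESHAPE_FUNC_OP:
            return 'reshape'
        if target in ELEMENTWISE_FUNC_OP:
            return 'elementwise'
        return 'other'

    assert pick_handler_kind(torch.flatten) == 'reshape'
    assert pick_handler_kind(torch.add) == 'elementwise'
    assert pick_handler_kind(torch.matmul) == 'other'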
@@ -2,5 +2,6 @@ from .operator_handler import OperatorHandler
 from .dot_handler import DotHandler
 from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
+from .reshape_handler import ReshapeHandler
 
-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
@@ -8,6 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from .._utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.solver.constants import *
 
 from ..sharding_strategy import StrategiesVector
 
@@ -44,7 +45,7 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
-        elif self.node.op == 'call_function':
+        elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
            module = None
            parameters = list(self.node.args)[1]
            named_parameters = {'weight': parameters._meta_data}
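The added guard matters because parameter-free targets such as torch.flatten carry no weight in node.args, so the branch that reads list(self.node.args)[1] would otherwise index past the argument list. A standalone illustration of that reasoning (extract_weight and the placeholder args tuples are made up for the example; the real code operates on FX nodes):

    import torch

    NON_PARAM_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]

    def extract_weight(target, args):
        # mirror the guarded branch: only parameter-carrying functions expose a weight argument
        if target in NON_PARAM_FUNC_OP:
            return None
        return list(args)[1]

    print(extract_weight(torch.nn.functional.linear, ('x', 'w')))  # 'w'
    print(extract_weight(torch.flatten, ('x',)))                   # None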
@@ -0,0 +1,66 @@
+from .operator_handler import OperatorHandler
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from copy import deepcopy
+import math
+
+
+class ReshapeHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of reshape operators, such as torch.reshape and torch.flatten.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.output_data = self.node._meta_data
+
+    def _generate_compute_cost(self, *args, **kwargs):
+        return super()._generate_compute_cost(*args, **kwargs)
+
+    def register_strategy(self):
+        input_node = self.strategies_vector.predecessor_nodes[0]
+        # For a reshape function, to keep the computation correct we require the sharding
+        # spec of the input to be fully replicated. In addition, we keep the output in
+        # replica status and let the successor node choose the way to reshard the
+        # output node. Therefore, different strategies of the input node with the same
+        # output sharding spec will generate the same strategy for the reshape function.
+        sharding_spec_checklist = []
+        for strategy in input_node.strategies_vector:
+            # It may look a little bit confusing: the input of the processing node
+            # is the output of the input_node.
+            input_sharding_spec = strategy.output_sharding_spec
+            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
+            if input_sharding_spec in sharding_spec_checklist:
+                continue
+            sharding_spec_checklist.append(input_sharding_spec)
+            dim_partition_dict_for_output = {}
+            output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
+            # TODO: use meta_info_prop to profile memory cost and compute cost
+            compute_cost = 0
+            memory_cost = self.node._meta_data.numel()
+
+            # compute the communication cost: in a reshape op, communication happens when casting the input sharding spec to the fully replicated one.
+            dim_partition_dict_for_replicate_input = {}
+            replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
+                                                                         dim_partition_dict_for_replicate_input)
+            # shape consistency manager is a singleton class
+            shape_consistency_manager = ShapeConsistencyManager()
+            _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
+                                                                                   replicate_input_sharding_spec)
+
+            # generate resharding cost
+            resharding_costs = self._generate_resharding_costs([input_sharding_spec])
+
+            # to prevent resharding from happening, set the nonzero resharding costs to inf
+            resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
+            sharding_strategy = ShardingStrategy(name,
+                                                 output_sharding_spec,
+                                                 compute_cost=compute_cost,
+                                                 communication_cost=communication_cost,
+                                                 memory_cost=memory_cost,
+                                                 resharding_costs=resharding_costs,
+                                                 input_shardings=[input_sharding_spec])
+            self.strategies_vector.append(sharding_strategy)
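To summarize the cost model above: the output is always generated fully replicated (an empty dim_partition_dict means no dimension is sharded), compute cost is treated as zero until meta_info_prop can profile it, memory cost is the element count of the replicated output, and every nonzero input resharding cost is pushed to infinity so the solver cannot choose a different resharding path. A rough standalone sketch of that bookkeeping, with made-up numbers instead of ColossalAI's ShapeConsistencyManager:

    import math

    def reshape_strategy_costs(output_numel, candidate_resharding_costs):
        compute_cost = 0                  # no profiling yet, mirrors the TODO above
        memory_cost = output_numel        # fully replicated output holds every element
        # gate the resharding costs: zero stays zero, anything else becomes infinite
        gated = [0 if c == 0 else math.inf for c in candidate_resharding_costs]
        return compute_cost, memory_cost, gated

    print(reshape_strategy_costs(4 * 32 * 62 * 62, [0.0, 3.2, 1.7]))
    # (0, 492032, [0, inf, inf])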
@@ -61,12 +61,17 @@ class StrategiesVector(list):
             root_module = self.node.graph.owning_module
             submod = root_module.get_submodule(target)
             submod_type = type(submod)
-            # merge elementwise module node into following nodes
+            # merge elementwise module node into source nodes
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
             if submod_type in ELEMENTWISE_MODULE_OP:
                 merge_label = True
 
         if self.node.op == 'call_function':
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
             if self.node.target in ELEMENTWISE_FUNC_OP:
                 merge_label = True
+            # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+            if self.node.target in RESHAPE_FUNC_OP:
+                merge_label = True
 
         return merge_label
@@ -157,8 +157,7 @@ class StrategiesConstructor:
                 # print(node, node.op, node.target, node.args)
                 # create sharding strategy for element-wise module
                 # input_node = strategies_vector.predecessor_nodes[0]
-                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
                 norm_handler.register_strategy()
                 # for strategy in norm_handler.strategies_vector:
                 #     print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
@@ -214,18 +213,22 @@ class StrategiesConstructor:
                 if target in CONV_FUNC_OP:
                     # use ConvHandler to create sharding strategies for conv node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
-                                               self.shape_consistency_manager)
+                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                     conv_handler.register_strategy()
 
                 # linear function
                 elif target in LINEAR_FUNC_OP:
                     # use DotHandler to create sharding strategies for linear node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                     linear_handler.register_strategy()
 
+                # reshape function
+                elif target in RESHAPE_FUNC_OP:
+                    # use ReshapeHandler to create sharding strategies for reshape node
+                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
+                    reshape_handler.register_strategy()
+
                 # element-wise function
                 elif target in ELEMENTWISE_FUNC_OP:
                     # TODO: integrate element-wise func and module together
@@ -0,0 +1,55 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class ConvModel(nn.Module):
+
+    def __init__(self, c_in, c_out):
+        super().__init__()
+        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = torch.flatten(x)
+        return x
+
+
+def test_conv_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+    tracer = ColoTracer()
+    model = ConvModel(16, 32)
+    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
+    #     return flatten
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    # [x, conv, flatten, output]
+    nodes = [node for node in gm.graph.nodes]
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+
+    strategies_constructor.build_strategies_and_cost()
+    strategy_map = strategies_constructor.strategy_map
+    conv_strategies = strategy_map[nodes[1]]
+    flatten_strategies = strategy_map[nodes[2]]
+    flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
+    for strategy in conv_strategies:
+        assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
+
+
+if __name__ == '__main__':
+    test_conv_handler()
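For reference, the shapes the test pushes through the reshape node can be checked directly: Conv2d(16, 32, kernel_size=3) with no padding shrinks the 64x64 input to 62x62, and torch.flatten collapses the result into a single dimension. A standalone sanity check (not part of the test file above):

    import torch
    import torch.nn as nn

    with torch.no_grad():
        x = torch.rand(4, 16, 64, 64)
        conv = nn.Conv2d(16, 32, kernel_size=3)
        y = conv(x)
    print(y.shape)                 # torch.Size([4, 32, 62, 62])
    print(torch.flatten(y).shape)  # torch.Size([492032])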
@@ -14,6 +14,7 @@ from colossalai.auto_parallel.solver import Solver
 from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
 
 
 class ConvModel(nn.Module):
@@ -81,8 +82,8 @@ def test_cost_graph():
     liveness_list = graph_analyser.liveness_analysis()
     # print(len(liveness_dict[0].unique_live_vars))
     # assert False
-    solver_options = {'fast_mode': True}
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
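The last two hunks show the caller-side migration: the dict-style options and the explicit ShapeConsistencyManager argument are replaced by a SolverOptions object. A hedged usage sketch, assuming graph and device_mesh are already built as in the tests above:

    from colossalai.auto_parallel.solver.options import SolverOptions
    from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor

    # old call (removed in this commit):
    #   StrategiesConstructor(graph, device_mesh, shape_consistency_manager, {'fast_mode': True})

    # new call:
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()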