[autoparallel] adapt solver with gpt (#1653)

pull/1664/head
YuliangLiu0306 2022-09-28 11:17:26 +08:00 committed by GitHub
parent c638bec028
commit 1e7816a460
6 changed files with 164 additions and 18 deletions


@@ -24,7 +24,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
@@ -47,7 +47,8 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None):
dtype: Optional[torch.dtype] = None,
index=None):
'''
Compute the resharding costs with this specific strategy.
@@ -68,6 +69,9 @@ def generate_resharding_costs(nodes: List[Node],
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
try:
# compute the resharding cost
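
The new index argument covers predecessor nodes whose strategies expose a list of output sharding specs (one per output) rather than a single spec. A minimal sketch of that selection logic, using a hypothetical helper rather than the real generate_resharding_costs:

# Hypothetical sketch, not the actual colossalai implementation.
def pick_input_spec(output_sharding_spec, index=None):
    # A predecessor strategy may expose a single ShardingSpec or a list of them
    # (one per output of a multi-output node); `index` selects the one feeding
    # the current node.
    if isinstance(output_sharding_spec, list):
        assert index is not None, 'index is required for multi-output predecessors'
        return output_sharding_spec[index]
    return output_sharding_spec

print(pick_input_spec('spec_single'))                             # spec_single
print(pick_input_spec(['spec_for_out0', 'spec_for_out1'], index=1))  # spec_for_out1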


@@ -9,12 +9,22 @@ __all__ = [
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
torch.nn.functional.dropout, torch.flatten
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
# softmax should not be here
torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
@@ -26,7 +36,7 @@ RESHAPE_METHOD_OP = [
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@@ -41,6 +51,34 @@ LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
NON_PARAM_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
torch.where,
operator.pow,
torch.pow,
torch.tanh,
torch.add,
torch.sub,
torch.mul,
torch.div,
torch.floor_divide,
torch.true_divide,
operator.add,
operator.sub,
operator.mul,
operator.floordiv,
operator.truediv,
# softmax should not be here
torch.nn.functional.softmax
]
INFINITY_COST = 1e13
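
These lists act as lookup tables when a traced node is matched to a handler. A rough sketch of that membership check, assuming a simplified dispatch (the category names here are placeholders, not the actual StrategiesConstructor logic):

import operator
import torch

ELEMENTWISE_FUNC_OP = [torch.abs, torch.cos, torch.exp, operator.neg, torch.nn.functional.relu]
BCAST_FUNC_OP = [torch.add, torch.mul, torch.where, torch.pow, torch.tanh]

def classify_call_function(target):
    # Simplified: classify a call_function node's target by membership
    # in the constant lists defined above.
    if target in BCAST_FUNC_OP:
        return 'broadcast_elementwise'
    if target in ELEMENTWISE_FUNC_OP:
        return 'unary_elementwise'
    return 'unhandled'

print(classify_call_function(torch.where))  # broadcast_elementwise
print(classify_call_function(torch.cos))    # unary_elementwise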


@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute computation cost
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
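
Passing sharding_spec_for_weight alongside the input spec means each matmul strategy also records resharding costs for the weight predecessor, which matters when one weight is shared by several operators (GPT-2 ties the embedding and LM-head weights, for example). A made-up illustration of the resulting cost table; the node names and numbers are invented:

# resharding_costs maps each predecessor node to one cost per strategy of that
# predecessor; the weight entry was previously missing for these strategies.
resharding_costs = {
    'input_activation': [0.0, 12.5, 3.1],
    'shared_weight':    [0.0, 7.8],
}
# the solver can now price the weight conversion instead of treating it as free
print(sum(min(costs) for costs in resharding_costs.values()))  # 0.0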


@@ -121,7 +121,13 @@ class OperatorHandler(ABC):
def _generate_resharding_costs(self, sharding_specs):
# The resharding cost of the weight is counted because weights may be shared across nodes.
dtype = self.node._meta_data.dtype
if hasattr(self.node._meta_data, 'dtype'):
dtype = self.node._meta_data.dtype
else:
assert isinstance(self.node._meta_data,
tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor are expected'
dtype = self.node._meta_data[0].dtype
nodes = self.predecessor_node
return generate_resharding_costs(nodes=nodes,
sharding_specs=sharding_specs,
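
A small sketch of the dtype lookup added above, assuming _meta_data is either a meta tensor or a tuple of meta tensors (the multi-output case):

import torch

def resolve_dtype(meta_data):
    # Mirrors the new branch: multi-output nodes carry a tuple of meta tensors,
    # and the first element's dtype is used for the resharding-cost estimate.
    if hasattr(meta_data, 'dtype'):
        return meta_data.dtype
    assert isinstance(meta_data, tuple), 'expected a tensor or a tuple of tensors'
    return meta_data[0].dtype

print(resolve_dtype(torch.empty(2, 4, device='meta', dtype=torch.float16)))          # torch.float16
print(resolve_dtype((torch.empty(2, device='meta'), torch.empty(3, device='meta'))))  # torch.float32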


@@ -1,9 +1,14 @@
import colorsys
from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math
from colossalai.auto_parallel.solver._utils import exception_handler
import warnings
import torch
from ..constants import INFINITY_COST
class ReshapeHandler(OperatorHandler):
@@ -19,6 +24,7 @@ class ReshapeHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
@@ -37,11 +43,23 @@
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict_for_output = {}
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
if isinstance(self.output_data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
try:
if isinstance(self.output_data, tuple):
output_sharding_spec = []
for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
else:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
except AssertionError as e:
warnings.warn(f'{e}')
continue
name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = 0
memory_cost = self.node._meta_data.numel()
# node._meta_data may be a tuple of tensors, so numel() cannot be used directly here
memory_cost = 0
# compute the communication cost; for a reshape op, communication happens when casting the input sharding spec to the fully replicated spec.
dim_partition_dict_for_replicate_input = {}
@@ -56,7 +74,7 @@
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent resharding from happening, set their resharding cost to inf.
resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
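
When the reshape node produces a tuple, the handler now emits one fully replicated spec per output. A condensed sketch of that branch, with make_replicated_spec as a hypothetical stand-in for self._generate_sharding_spec:

import torch

def make_replicated_spec(output, dim_partition_dict):
    # Hypothetical stand-in: an empty dim_partition_dict means fully replicated.
    return {'shape': tuple(output.shape), 'partition': dict(dim_partition_dict)}

def build_output_specs(output_data):
    if isinstance(output_data, tuple):
        # one empty partition dict (fully replicated) per output tensor
        dim_partition_dicts = [{} for _ in range(len(output_data))]
        return [make_replicated_spec(o, d) for o, d in zip(output_data, dim_partition_dicts)]
    return make_replicated_spec(output_data, {})

outs = (torch.empty(4, 8, device='meta'), torch.empty(4, device='meta'))
print(build_output_specs(outs))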


@@ -0,0 +1,80 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
import transformers
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
BATCH_SIZE = 8
SEQ_LENGTH = 8
@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
# [[0, 1, 2, 3]
#  [4, 5, 6, 7]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
config = transformers.GPT2Config(n_positions=1024, n_layer=1, n_head=12)
model = transformers.GPT2LMHeadModel(config=config)
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
print(graph)
strategies_constructor.build_strategies_and_cost()
for check_node, strategies_vector in strategies_constructor.strategy_map.items():
print(check_node, len(strategies_vector))
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret)
strategies_list = list(ret[0])
print(strategies_list)
computation_cost = 0
communication_cost = 0
memory_cost = 0
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_cost_graph()