mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt solver with gpt (#1653)
parent
c638bec028
commit
1e7816a460
|
@ -24,7 +24,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
|
|||
"""
|
||||
|
||||
if isinstance(input_, Node):
|
||||
assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
|
||||
assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data'
|
||||
meta_tensor = input_._meta_data
|
||||
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
|
||||
shape = meta_tensor.shape
|
||||
|
@ -47,7 +47,8 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
|
|||
def generate_resharding_costs(nodes: List[Node],
|
||||
sharding_specs: List[ShardingSpec],
|
||||
count_backward: Optional[bool] = True,
|
||||
dtype: Optional[torch.dtype] = None):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
index=None):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
|
@ -68,6 +69,9 @@ def generate_resharding_costs(nodes: List[Node],
|
|||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
if not isinstance(input_sharding_spec, ShardingSpec):
|
||||
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
|
||||
input_sharding_spec = input_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
try:
|
||||
# compute the resharding cost
|
||||
|
|
|
@ -9,12 +9,22 @@ __all__ = [
|
|||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||
ELEMENTWISE_FUNC_OP = [
|
||||
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout, torch.flatten
|
||||
torch.abs,
|
||||
torch.cos,
|
||||
torch.exp,
|
||||
operator.neg,
|
||||
torch.multiply,
|
||||
torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout,
|
||||
torch.flatten,
|
||||
# softmax should not be here
|
||||
torch.nn.functional.softmax
|
||||
]
|
||||
ELEMENTWISE_METHOD_OP = [
|
||||
torch.Tensor.to,
|
||||
torch.Tensor.type,
|
||||
# TODO: contiguous maybe need some extra processes.
|
||||
torch.Tensor.contiguous
|
||||
]
|
||||
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
|
||||
RESHAPE_METHOD_OP = [
|
||||
|
@ -26,7 +36,7 @@ RESHAPE_METHOD_OP = [
|
|||
]
|
||||
BCAST_FUNC_OP = [
|
||||
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
|
||||
operator.mul, operator.floordiv, operator.truediv, torch.matmul
|
||||
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
|
||||
]
|
||||
CONV_MODULE_OP = [
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
||||
|
@ -41,6 +51,34 @@ LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
|
|||
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
|
||||
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
|
||||
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
|
||||
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
|
||||
NON_PARAM_FUNC_OP = [
|
||||
torch.flatten,
|
||||
torch.reshape,
|
||||
torch.abs,
|
||||
torch.cos,
|
||||
torch.exp,
|
||||
operator.neg,
|
||||
torch.multiply,
|
||||
torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout,
|
||||
torch.flatten,
|
||||
torch.where,
|
||||
operator.pow,
|
||||
torch.pow,
|
||||
torch.tanh,
|
||||
torch.add,
|
||||
torch.sub,
|
||||
torch.mul,
|
||||
torch.div,
|
||||
torch.floor_divide,
|
||||
torch.true_divide,
|
||||
operator.add,
|
||||
operator.sub,
|
||||
operator.mul,
|
||||
operator.floordiv,
|
||||
operator.truediv,
|
||||
# softmax should not be here
|
||||
torch.nn.functional.softmax
|
||||
]
|
||||
|
||||
INFINITY_COST = 1e13
|
||||
|
|
|
@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute computation cost
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim]
|
||||
|
@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim]
|
||||
|
@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
|
|
|
@ -121,7 +121,13 @@ class OperatorHandler(ABC):
|
|||
|
||||
def _generate_resharding_costs(self, sharding_specs):
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
dtype = self.node._meta_data.dtype
|
||||
if hasattr(self.node._meta_data, 'dtype'):
|
||||
dtype = self.node._meta_data.dtype
|
||||
else:
|
||||
assert isinstance(self.node._meta_data,
|
||||
tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected'
|
||||
dtype = self.node._meta_data[0].dtype
|
||||
|
||||
nodes = self.predecessor_node
|
||||
return generate_resharding_costs(nodes=nodes,
|
||||
sharding_specs=sharding_specs,
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
import colorsys
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from copy import deepcopy
|
||||
import math
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler
|
||||
import warnings
|
||||
import torch
|
||||
from ..constants import INFINITY_COST
|
||||
|
||||
|
||||
class ReshapeHandler(OperatorHandler):
|
||||
|
@ -19,6 +24,7 @@ class ReshapeHandler(OperatorHandler):
|
|||
def _generate_compute_cost(self, *args, **kwargs):
|
||||
return super()._generate_compute_cost(*args, **kwargs)
|
||||
|
||||
@exception_handler
|
||||
def register_strategy(self):
|
||||
# TODO: add strategies with more output sharding specs other than only fully replicated.
|
||||
input_node = self.strategies_vector.predecessor_nodes[0]
|
||||
|
@ -37,11 +43,23 @@ class ReshapeHandler(OperatorHandler):
|
|||
continue
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict_for_output = {}
|
||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
if isinstance(self.output_data, tuple):
|
||||
dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
|
||||
try:
|
||||
if isinstance(self.output_data, tuple):
|
||||
output_sharding_spec = []
|
||||
for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
|
||||
output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
|
||||
else:
|
||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
continue
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = 0
|
||||
memory_cost = self.node._meta_data.numel()
|
||||
# consider node._meta_data is in type of tuple
|
||||
memory_cost = 0
|
||||
|
||||
# compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
|
||||
dim_partition_dict_for_replicate_input = {}
|
||||
|
@ -56,7 +74,7 @@ class ReshapeHandler(OperatorHandler):
|
|||
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
|
||||
resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.solver import Solver
|
||||
import transformers
|
||||
from colossalai.auto_parallel.solver.constants import *
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LENGHT = 8
|
||||
|
||||
|
||||
@pytest.mark.skip("for higher testing speed")
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 8)
|
||||
mesh_shape = (2, 4)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12)
|
||||
model = transformers.GPT2LMHeadModel(config=config)
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
print(graph)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
for check_node, strategies_vector in strategies_constructor.strategy_map.items():
|
||||
print(check_node, len(strategies_vector))
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
print(ret)
|
||||
strategies_list = list(ret[0])
|
||||
print(strategies_list)
|
||||
computation_cost = 0
|
||||
communication_cost = 0
|
||||
memory_cost = 0
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
|
||||
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
|
||||
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
|
||||
if isinstance(node_memory_cost, tuple):
|
||||
node_memory_cost = node_memory_cost[0]
|
||||
memory_cost += node_memory_cost
|
||||
|
||||
print(f'computation cost is {computation_cost}')
|
||||
print(f'communication cost is {communication_cost}')
|
||||
print(f'memory cost is {memory_cost}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cost_graph()
|
Loading…
Reference in New Issue