mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt solver with gpt (#1653)
parent c638bec028
commit 1e7816a460
@@ -24,7 +24,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
     """
 
     if isinstance(input_, Node):
-        assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
+        assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
         meta_tensor = input_._meta_data
         assert meta_tensor is not None, "The given node's _meta_data attribute is None"
         shape = meta_tensor.shape
@@ -47,7 +47,8 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
 def generate_resharding_costs(nodes: List[Node],
                               sharding_specs: List[ShardingSpec],
                               count_backward: Optional[bool] = True,
-                              dtype: Optional[torch.dtype] = None):
+                              dtype: Optional[torch.dtype] = None,
+                              index=None):
     '''
     Compute the resharding costs with this specific strategy.
 
@@ -68,6 +69,9 @@ def generate_resharding_costs(nodes: List[Node],
         resharding_costs[input_node] = []
         for strategy in input_node.strategies_vector:
             input_sharding_spec = strategy.output_sharding_spec
+            if not isinstance(input_sharding_spec, ShardingSpec):
+                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+                input_sharding_spec = input_sharding_spec[index]
             assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
             try:
                 # compute the resharding cost
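The new index argument covers predecessors whose traced output is a tuple of tensors (which shows up in the traced GPT-2 graph), where a strategy stores a list of ShardingSpec objects, one per output element; index picks the spec of the tensor the current node actually consumes. A minimal, self-contained sketch of that selection step, using a hypothetical helper name that is not part of the repository:

from typing import List, Optional, Union

def select_input_spec(output_sharding_spec: Union[object, List[object]],
                      index: Optional[int] = None):
    # Mirrors the branch added above: a list of specs means the producer returns a
    # tuple of tensors, so `index` chooses the spec of the consumed element.
    if isinstance(output_sharding_spec, list):
        assert index is not None, 'index is required when the producer output is a tuple'
        return output_sharding_spec[index]
    return output_sharding_spec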
@@ -9,12 +9,22 @@ __all__ = [
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 ELEMENTWISE_FUNC_OP = [
-    torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
-    torch.nn.functional.dropout, torch.flatten
+    torch.abs,
+    torch.cos,
+    torch.exp,
+    operator.neg,
+    torch.multiply,
+    torch.nn.functional.relu,
+    torch.nn.functional.dropout,
+    torch.flatten,
+    # softmax should not be here
+    torch.nn.functional.softmax
 ]
 ELEMENTWISE_METHOD_OP = [
     torch.Tensor.to,
     torch.Tensor.type,
+    # TODO: contiguous maybe need some extra processes.
+    torch.Tensor.contiguous
 ]
 RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
 RESHAPE_METHOD_OP = [
@@ -26,7 +36,7 @@ RESHAPE_METHOD_OP = [
 ]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
-    operator.mul, operator.floordiv, operator.truediv, torch.matmul
+    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
 ]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@@ -41,6 +51,34 @@ LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
 LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
-NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
+NON_PARAM_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.abs,
+    torch.cos,
+    torch.exp,
+    operator.neg,
+    torch.multiply,
+    torch.nn.functional.relu,
+    torch.nn.functional.dropout,
+    torch.flatten,
+    torch.where,
+    operator.pow,
+    torch.pow,
+    torch.tanh,
+    torch.add,
+    torch.sub,
+    torch.mul,
+    torch.div,
+    torch.floor_divide,
+    torch.true_divide,
+    operator.add,
+    operator.sub,
+    operator.mul,
+    operator.floordiv,
+    operator.truediv,
+    # softmax should not be here
+    torch.nn.functional.softmax
+]
 
 INFINITY_COST = 1e13
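Set-wise, the expanded NON_PARAM_FUNC_OP is the union of RESHAPE_FUNC_OP, the new ELEMENTWISE_FUNC_OP and the broadcast ops, with torch.matmul left out because matrix multiplication is handled by the dot handler. A sketch (not the repository's definition) of an equivalent, composition-based form; the constants are restated here only so the snippet runs on its own:

import operator
import torch

RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
ELEMENTWISE_FUNC_OP = [torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply,
                       torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten,
                       torch.nn.functional.softmax]
BCAST_FUNC_OP = [torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide,
                 operator.add, operator.sub, operator.mul, operator.floordiv, operator.truediv,
                 torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh]

# Same membership as the hard-coded list above (ordering and the duplicated
# torch.flatten entry aside); matmul is excluded because it has its own handler.
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP + [
    op for op in BCAST_FUNC_OP if op is not torch.matmul
]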
@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute computation cost
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
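Every DotHandler strategy now passes the weight's target spec alongside the input's, so the cost of resharding a (possibly shared) weight is charged to the strategy instead of being ignored, matching the comment in OperatorHandler below. A self-contained toy sketch of the resulting cost table; the types, spec strings and unit costs are stand-ins, not the library's:

from typing import Dict, List

def toy_resharding_costs(predecessors: List[str],
                         target_specs: List[str],
                         candidate_specs: Dict[str, List[str]]) -> Dict[str, List[float]]:
    # One cost per candidate strategy of each predecessor: 0 if the layout already
    # matches the target, a stand-in positive cost otherwise.
    costs: Dict[str, List[float]] = {}
    for node, target in zip(predecessors, target_specs):
        costs[node] = [0.0 if spec == target else 1.0 for spec in candidate_specs[node]]
    return costs

# Both the activation input and the weight now appear as keys of the cost table.
print(toy_resharding_costs(['input', 'weight'], ['S0R', 'RS1'],
                           {'input': ['S0R', 'RR'], 'weight': ['RR', 'RS1']}))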
@@ -121,7 +121,13 @@ class OperatorHandler(ABC):
 
     def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
-        dtype = self.node._meta_data.dtype
+        if hasattr(self.node._meta_data, 'dtype'):
+            dtype = self.node._meta_data.dtype
+        else:
+            assert isinstance(self.node._meta_data,
+                              tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected'
+            dtype = self.node._meta_data[0].dtype
+
         nodes = self.predecessor_node
         return generate_resharding_costs(nodes=nodes,
                                          sharding_specs=sharding_specs,
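The dtype lookup above tolerates nodes whose _meta_data is a tuple of tensors (ops that return several outputs) by falling back to the first element's dtype. A runnable miniature of that branch, outside the handler class:

import torch

def infer_dtype(meta_data):
    # Single tensor: use its dtype; tuple of tensors: use the first element's dtype.
    if hasattr(meta_data, 'dtype'):
        return meta_data.dtype
    assert isinstance(meta_data, tuple), 'expected a tensor or a tuple of tensors'
    return meta_data[0].dtype

print(infer_dtype(torch.empty(2, 2)))                 # torch.float32
print(infer_dtype((torch.empty(2), torch.empty(3))))  # torch.float32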
@@ -1,9 +1,14 @@
+import colorsys
 from .operator_handler import OperatorHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from copy import deepcopy
 import math
+from colossalai.auto_parallel.solver._utils import exception_handler
+import warnings
+import torch
+from ..constants import INFINITY_COST
 
 
 class ReshapeHandler(OperatorHandler):
@@ -19,6 +24,7 @@ class ReshapeHandler(OperatorHandler):
     def _generate_compute_cost(self, *args, **kwargs):
         return super()._generate_compute_cost(*args, **kwargs)
 
+    @exception_handler
     def register_strategy(self):
         # TODO: add strategies with more output sharding specs other than only fully replicated.
         input_node = self.strategies_vector.predecessor_nodes[0]
@@ -37,11 +43,23 @@ class ReshapeHandler(OperatorHandler):
                 continue
             sharding_spec_checklist.append(input_sharding_spec)
             dim_partition_dict_for_output = {}
-            output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            if isinstance(self.output_data, tuple):
+                dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
+            try:
+                if isinstance(self.output_data, tuple):
+                    output_sharding_spec = []
+                    for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
+                        output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
+                else:
+                    output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                continue
             name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
             # TODO: use meta_info_prop to profile memory cost and compute cost
             compute_cost = 0
-            memory_cost = self.node._meta_data.numel()
+            # consider node._meta_data is in type of tuple
+            memory_cost = 0
 
             # compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
             dim_partition_dict_for_replicate_input = {}
@@ -56,7 +74,7 @@ class ReshapeHandler(OperatorHandler):
             resharding_costs = self._generate_resharding_costs([input_sharding_spec])
 
             # to prevent the resharding happening, set their resharding cost to inf.
-            resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
+            resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
             sharding_strategy = ShardingStrategy(name,
                                                  output_sharding_spec,
                                                  compute_cost=compute_cost,
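ReshapeHandler now builds a list of fully replicated output specs when the traced op returns a tuple, skips the strategy with a warning if spec generation fails, and swaps math.inf for the finite INFINITY_COST constant, presumably so the penalty stays a usable numeric coefficient in the downstream solver. A toy sketch of the tuple-aware replication step, with a stand-in spec constructor instead of ShardingSpec:

def replicated_output_specs(output_data, make_spec):
    # A single empty dim-partition dict for one tensor, a list of dicts for a tuple.
    if isinstance(output_data, tuple):
        return [make_spec(out, {}) for out in output_data]
    return make_spec(output_data, {})

# usage with a stand-in constructor
print(replicated_output_specs(('hidden_states', 'presents'), lambda out, dims: (out, dims)))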
@@ -0,0 +1,80 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from copy import deepcopy
+from colossalai.auto_parallel.solver import Solver
+import transformers
+from colossalai.auto_parallel.solver.constants import *
+from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
+
+BATCH_SIZE = 8
+SEQ_LENGHT = 8
+
+
+@pytest.mark.skip("for higher testing speed")
+def test_cost_graph():
+    physical_mesh_id = torch.arange(0, 8)
+    mesh_shape = (2, 4)
+    # [[0, 1, 2, 3],
+    #  [4, 5, 6, 7]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12)
+    model = transformers.GPT2LMHeadModel(config=config)
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+    kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    print(graph)
+    strategies_constructor.build_strategies_and_cost()
+    for check_node, strategies_vector in strategies_constructor.strategy_map.items():
+        print(check_node, len(strategies_vector))
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+
+    ret = solver.call_solver_serialized_args()
+    print(ret)
+    strategies_list = list(ret[0])
+    print(strategies_list)
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_cost_graph()
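The new GPT-2 solver test is skipped by default ("for higher testing speed"); locally one would drop the skip mark or run the file directly via its __main__ guard. A hypothetical pytest invocation follows; the path is an assumption for illustration only, since the new file's name is not visible in this view:

import pytest

# Substitute the real test file path; this one is a guess.
pytest.main(['-s', 'tests/test_auto_parallel/test_gpt_solver.py::test_cost_graph'])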