[autoparallel] adapt solver with mlp (#1638)

YuliangLiu0306 2022-09-26 15:26:14 +08:00 committed by GitHub
parent 04443605a5
commit b2b2a4af98
4 changed files with 165 additions and 39 deletions

View File

@@ -4,9 +4,10 @@ from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler
from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler
+ from .embedding_handler import EmbeddingHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
__all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
- 'UnaryElementwiseHandler'
+ 'UnaryElementwiseHandler', 'EmbeddingHandler'
]

View File

@@ -410,9 +410,9 @@ class DotHandler(OperatorHandler):
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
- def _generate_compute_cost(self, input_shape, weight_shape):
+ def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
# TODO: consider bias addition
- compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
+ compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost
@exception_handler
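Illustration (not part of the diff): a minimal worked example of the amended per-device compute cost. The shapes and the 8-way sharding below are made up.
from functools import reduce
import operator

# Hypothetical linear layer: input of shape (16, 32), weight of shape (128, 32)
input_shape = (16, 32)
weight_shape = (128, 32)
# Number of devices the matmul is split over, e.g. both dims of a 2 x 4 mesh
total_sharding_size = 2 * 4

# Same formula as the hunk above: 2 * prod(input_shape) * out_features, divided across the shards
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
print(compute_cost)  # (16 * 32 * 128 * 2) // 8 = 16384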
@@ -434,15 +434,17 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute computation cost
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost
# no all-reduce required for this case
- communication_cost = 0
+ communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+ communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+ communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
# create and register strategy
sharding_strategies = ShardingStrategy(name,
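Illustration (not part of the diff): how the new backward communication terms combine for the strategy above, which needs no all-reduce in the forward pass. The all_reduce_cost_estimate helper is a hypothetical alpha-beta stand-in; the real numbers come from DeviceMesh.all_reduce_cost, whose formula is not shown in this diff.
# Hypothetical ring all-reduce estimate (alpha-beta model); a stand-in only,
# not the actual colossalai DeviceMesh.all_reduce_cost implementation.
def all_reduce_cost_estimate(num_bytes, num_devices, latency=1e-5, bandwidth=1e10):
    if num_devices <= 1:
        return 0.0
    return latency + 2 * (num_devices - 1) / num_devices * num_bytes / bandwidth

# The backward pass all-reduces the activation grad along mesh_dim_1 (size 4 on a 2 x 4 mesh)
# and the weight grad along mesh_dim_0 (size 2); the byte counts are made up.
activation_memory_cost = 16 * 128 * 4
weight_memory_cost = 128 * 32 * 4
communication_cost = (all_reduce_cost_estimate(activation_memory_cost, 4)
                      + all_reduce_cost_estimate(weight_memory_cost, 2))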
@@ -474,14 +476,17 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+ communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+ communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+ communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -508,14 +513,18 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+ communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
+ communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
+ communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -542,11 +551,12 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
@@ -576,14 +586,16 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
+ communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
+ communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -610,14 +622,16 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = 0
+ communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
+ communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -644,14 +658,17 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
+ communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+ activation_memory_cost, 0)
+ communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -678,14 +695,16 @@ class DotHandler(OperatorHandler):
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+ total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+ compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight)
+ toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+ dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
- communication_cost = 0
+ communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+ input_grad_memory_cost, 0)
+ communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
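Illustration (not part of the diff): the three strategies above that call flatten_device_mesh.all_reduce_cost(..., 0) pay for a collective over every device in the mesh rather than along a single mesh axis. A rough sketch of the difference, assuming a 2 x 4 mesh and a ring all-reduce that moves roughly 2*(n-1)/n of the tensor per device (an assumption; the real cost model is not shown in this diff).
mesh_shape = (2, 4)
num_devices_flat = mesh_shape[0] * mesh_shape[1]
grad_bytes = 128 * 32 * 4   # made-up gradient size in bytes
# per-device traffic of a ring all-reduce over n devices: ~ 2 * (n - 1) / n * bytes
traffic_one_axis = 2 * (mesh_shape[0] - 1) / mesh_shape[0] * grad_bytes          # along mesh_dim_0 only
traffic_flat_mesh = 2 * (num_devices_flat - 1) / num_devices_flat * grad_bytes   # whole flattened mesh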

View File

@@ -64,7 +64,8 @@ class OperatorHandler(ABC):
"""
pass
- def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
+ def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
+ sharding_spec_for_input):
'''
Compute the memory cost per device with this specific strategy.
@@ -102,9 +103,21 @@ class OperatorHandler(ABC):
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
- total_memory_cost = activation_memory_cost + weight_memory_cost
+ # compute the memory cost of input grad
+ input_grad_numel = self.input_data.numel()
+ input_grad_sharding_size = 1
+ input_grad_mesh_dims = []
+ for sharding_dim, mesh_dims in sharding_spec_for_input.items():
+ input_grad_mesh_dims.extend(mesh_dims)
+ for mesh_dim in input_grad_mesh_dims:
+ input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
+ input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
- return total_memory_cost, activation_memory_cost, weight_memory_cost
+ memory_cost_forward = activation_memory_cost + weight_memory_cost
+ memory_cost_backward = input_grad_memory_cost + weight_memory_cost
+ return (memory_cost_forward,
+ memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
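Illustration (not part of the diff): a condensed sketch of what the reworked _generate_memory_cost now returns. The sharding sizes below are made up; the real method derives them from the partition dicts passed in, as shown above.
size_per_elem_bytes = 4                                        # fp32
activation_memory_cost = 16 * 128 * size_per_elem_bytes / 8    # output sharded over 8 devices (made up)
weight_memory_cost = 128 * 32 * size_per_elem_bytes / 4        # weight sharded over 4 devices (made up)
input_grad_memory_cost = 16 * 32 * size_per_elem_bytes / 2     # input grad sharded over 2 devices (made up)

memory_cost_forward = activation_memory_cost + weight_memory_cost
memory_cost_backward = input_grad_memory_cost + weight_memory_cost
# The (forward, backward) pair comes first, which is why the test below unpacks
# memory_cost as a tuple before summing it.
memory_cost = ((memory_cost_forward, memory_cost_backward),
               activation_memory_cost, weight_memory_cost, input_grad_memory_cost)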

View File

@@ -0,0 +1,93 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
class MLP(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim * 4)
self.linear2 = torch.nn.Linear(dim * 4, dim)
self.dropout = torch.nn.Dropout(0)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.linear2(x)
return x
@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = MLP(32)
input_sample = {'x': torch.rand(16, 32).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
# %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
# %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
# return linear2
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
# # megatron mode if no memory constraints
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
# all sharding on out feature dim if memory budget is not sufficient for megatron mode
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
ret = solver.call_solver_serialized_args()
strategies_list = list(ret[0])
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_cost_graph()
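Note (not part of the diff): the file path is not shown in this view, so the commands below are only indicative. Because of the @pytest.mark.skip marker, the check runs when the module is executed directly; under pytest it runs once the marker is removed.
# Indicative only; substitute the real path of this test file in the repository.
# python <path/to/this_test_file>.py
# pytest -s <path/to/this_test_file>.py -k test_cost_graph    # after removing the skip marker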