mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt solver with mlp (#1638)
parent 04443605a5
commit b2b2a4af98
@@ -4,9 +4,10 @@ from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
 from .reshape_handler import ReshapeHandler
 from .bcast_op_handler import BcastOpHandler
+from .embedding_handler import EmbeddingHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler

 __all__ = [
     'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler'
 ]
@@ -410,9 +410,9 @@ class DotHandler(OperatorHandler):
         self.weight = self.module_named_parameters['weight']
         self.output_data = self.node._meta_data

-    def _generate_compute_cost(self, input_shape, weight_shape):
+    def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
         # TODO: consider bias addition
-        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
+        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
         return compute_cost

     @exception_handler
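The extra `total_sharding_size` argument amortizes the dense matmul FLOP estimate over the number of devices the computation is sharded across, so the solver compares per-device rather than global compute cost. A minimal standalone sketch of the same arithmetic, using a hypothetical Linear(32, 128) layer on a 2 x 4 mesh (the function name and example shapes are illustrative, not part of the handler API):

from functools import reduce
import operator

def per_device_matmul_flops(input_shape, weight_shape, total_sharding_size):
    # 2 * prod(input_shape) * out_features FLOPs for a dense linear layer,
    # split evenly across every shard of the computation
    return reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size

# example: a (16, 32) input through a weight of shape (128, 32) on 8 devices
print(per_device_matmul_flops((16, 32), (128, 32), total_sharding_size=2 * 4))    # 16384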
@@ -434,15 +434,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute computation cost
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost
-        # no all-reduce required for this case
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_backward + communication_cost_weight_backward

         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
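This strategy previously assumed no collective communication at all; the updated estimate now charges for the backward-pass all-reduces, reducing activation gradients along `mesh_dim_1` and weight gradients along `mesh_dim_0` and summing the two. A rough self-contained sketch of that combination, with a toy ring all-reduce model standing in for `DeviceMesh.all_reduce_cost` (the toy model is an assumption, not the library's cost function):

MESH_SHAPE = (2, 4)

def toy_all_reduce_cost(nbytes, mesh_dim):
    # ring all-reduce moves roughly 2 * (n - 1) / n times the message size
    n = MESH_SHAPE[mesh_dim]
    return 2 * (n - 1) / n * nbytes

def backward_comm_cost(activation_grad_bytes, weight_grad_bytes, mesh_dim_0=0, mesh_dim_1=1):
    comm_activation_backward = toy_all_reduce_cost(activation_grad_bytes, mesh_dim_1)
    comm_weight_backward = toy_all_reduce_cost(weight_grad_bytes, mesh_dim_0)
    return comm_activation_backward + comm_weight_backward

print(backward_comm_cost(activation_grad_bytes=8192, weight_grad_bytes=4096))    # 16384.0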
@@ -474,14 +476,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -508,14 +513,18 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
+        communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
+
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -542,11 +551,12 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
@@ -576,14 +586,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -610,14 +622,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
+        communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
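In this strategy the weight is replicated on every device, so its gradient has to be reduced across the entire mesh; the cost is therefore charged on `flatten_device_mesh`, i.e. all devices treated as one group along axis 0. A toy sketch of what flattening implies, assuming a simple bandwidth-only ring model (an assumption, not the DeviceMesh alpha-beta cost model):

def flattened_all_reduce_cost(mesh_shape, nbytes):
    # treat the whole 2D mesh as a single flat group of devices
    num_devices = 1
    for dim in mesh_shape:
        num_devices *= dim
    return 2 * (num_devices - 1) / num_devices * nbytes

# reducing a 4 KB weight gradient over all 8 devices of a (2, 4) mesh
print(flattened_all_reduce_cost((2, 4), nbytes=4096))    # 7168.0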
@@ -644,14 +658,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)

         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
+        communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            activation_memory_cost, 0)
+        communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -678,14 +695,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)

         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
-
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            input_grad_memory_cost, 0)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -64,7 +64,8 @@ class OperatorHandler(ABC):
         """
         pass

-    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
+    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
+                              sharding_spec_for_input):
        '''
        Compute the memory cost per device with this specific strategy.

@@ -102,9 +103,21 @@ class OperatorHandler(ABC):
             weight_sharding_size *= self.device_mesh.shape[mesh_dim]
         weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes

-        total_memory_cost = activation_memory_cost + weight_memory_cost
+        # compute the memory cost of input grad
+        input_grad_numel = self.input_data.numel()
+        input_grad_sharding_size = 1
+        input_grad_mesh_dims = []
+        for sharding_dim, mesh_dims in sharding_spec_for_input.items():
+            input_grad_mesh_dims.extend(mesh_dims)
+        for mesh_dim in input_grad_mesh_dims:
+            input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
+        input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes

-        return total_memory_cost, activation_memory_cost, weight_memory_cost
+        memory_cost_forward = activation_memory_cost + weight_memory_cost
+        memory_cost_backward = input_grad_memory_cost + weight_memory_cost
+
+        return (memory_cost_forward,
+                memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost

     def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
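With the new `sharding_spec_for_input` argument (which, judging by the call sites above, receives a dim-partition dict for the input), `_generate_memory_cost` now also estimates the input-gradient footprint and reports forward memory (activation + weight) and backward memory (input grad + weight) as a pair. A self-contained sketch of that accounting with toy shapes and a (2, 4) mesh; all names and shapes are illustrative, not the handler's API:

def memory_costs(input_numel, output_numel, weight_numel,
                 dim_partition_input, dim_partition_output, dim_partition_weight,
                 mesh_shape=(2, 4), size_per_elem_bytes=4):
    def sharded_bytes(numel, dim_partition):
        sharding_size = 1
        for mesh_dims in dim_partition.values():
            for mesh_dim in mesh_dims:
                sharding_size *= mesh_shape[mesh_dim]
        return numel / sharding_size * size_per_elem_bytes

    activation = sharded_bytes(output_numel, dim_partition_output)
    weight = sharded_bytes(weight_numel, dim_partition_weight)
    input_grad = sharded_bytes(input_numel, dim_partition_input)
    # forward keeps the activation and the weight; backward keeps the input grad and the weight
    return (activation + weight, input_grad + weight), activation, weight, input_grad

# batch dim of input/output sharded over mesh dim 0, weight replicated
print(memory_costs(16 * 32, 16 * 128, 128 * 32, {0: [0]}, {0: [0]}, {}))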
@@ -0,0 +1,93 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from copy import deepcopy
+from colossalai.auto_parallel.solver import Solver
+from torchvision.models import resnet34, resnet50
+from colossalai.auto_parallel.solver.constants import *
+from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim * 4)
+        self.linear2 = torch.nn.Linear(dim * 4, dim)
+        self.dropout = torch.nn.Dropout(0)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.dropout(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+@pytest.mark.skip("for higher testing speed")
+def test_cost_graph():
+    physical_mesh_id = torch.arange(0, 8)
+    mesh_shape = (2, 4)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    model = MLP(32)
+
+    input_sample = {'x': torch.rand(16, 32).to('meta')}
+
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
+    #     %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
+    #     %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
+    #     %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
+    #     return linear2
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    # # megatron mode if no memory constraints
+    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    # all sharding on out feature dim if memory budget is not sufficient for megatron mode
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
+
+    ret = solver.call_solver_serialized_args()
+    strategies_list = list(ret[0])
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(graph.nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_cost_graph()
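Since a node's `memory_cost` may now be the `(forward, backward)` pair produced by the handler change above, the test keeps only the forward component when summing per-node memory. The same unpacking pattern in isolation (the tuple layout follows the diff; the helper name is illustrative):

def forward_memory(node_memory_cost):
    # handlers may return either a scalar or a (forward, backward) pair
    if isinstance(node_memory_cost, tuple):
        return node_memory_cost[0]
    return node_memory_cost

print(forward_memory((1024.0, 2048.0)))    # -> 1024.0
print(forward_memory(512.0))               # -> 512.0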