mirror of https://github.com/hpcaitech/ColossalAI
parent 6135e178b3 · commit 59f100510a
@@ -0,0 +1,181 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding

__all__ = ['WhereHandler']


class WhereHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of torch.where.
    """

    def __init__(self, *args, **kwargs):
        # TODO: x or y could be a scalar
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 3
        self.condition_data = self.predecessor_node[0]._meta_data
        self.x_data = self.predecessor_node[1]._meta_data
        self.y_data = self.predecessor_node[2]._meta_data
        self.condition = self.predecessor_node[0]
        self.x = self.predecessor_node[1]
        self.y = self.predecessor_node[2]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size-one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
            sharding_size = reduce(operator.mul, sharding_list, 1)
            assert shape[dim_index] % sharding_size == 0, \
                f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

    def _generate_compute_cost(self, total_sharding_size):
        # NOTE: this cost model is inherited from the matmul-style handlers and relies on
        # self.lhs_data / self.rhs_data, which WhereHandler does not define; _register_strategy
        # derives compute_cost from the sharded output size instead of calling this method.
        lhs_matrix_shape = self.lhs_data.shape[-2:]
        rhs_matrix_shape = self.rhs_data.shape[-2:]
        batch_dimensions_shape = self.output_data.shape[:-2]
        batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
        compute_cost = reduce(
            operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
        return compute_cost

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding cost of the weight is also counted, to cover weight-sharing cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
                # If the input rank is smaller than the target's, pad the input shape to the same
                # length as the target, then use the padded sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

    def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
        sharding_spec_list = []
        check_duplicated_list = []
        for output_dim_partition_dict in dim_partition_list:
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                break
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                sharding_spec_list.append(output_sharding_spec)

        return sharding_spec_list

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.

        output_dim_partition_list = []
        dim_size = self.output_data.dim()
        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add an empty dict for the fully replicated case
        output_dim_partition_list.append({})
        output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)

        return output_sharding_spec_list

    @exception_handler
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
        sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
        sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate the resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs(
            [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
                                                                sharding_spec_for_y))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        MESH_DIM_LIST = [0, 1]
        output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
        for output_sharding_spec in output_sharding_specs:
            self._register_strategy(output_sharding_spec)
        # return the populated vector to match the annotated StrategiesVector return type
        return self.strategies_vector
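
To make the broadcasting logic in _generate_sharding_spec concrete, here is a minimal, self-contained sketch of the pad-and-filter step it performs before building a ShardingSpec. The helper name pad_and_filter_dim_partition and the example shapes are illustrative assumptions, not part of ColossalAI's API.

from copy import deepcopy
from functools import reduce
import operator
from typing import Dict, List, Tuple


def pad_and_filter_dim_partition(input_shape: Tuple[int, ...], output_rank: int,
                                 dim_partition_dict: Dict[int, List[int]],
                                 mesh_shape: Tuple[int, ...]) -> Tuple[Tuple[int, ...], Dict[int, List[int]]]:
    """Pad input_shape on the left up to output_rank, then drop shardings that land on size-1 dims."""
    shape = list(input_shape)
    while len(shape) < output_rank:
        shape.insert(0, 1)

    processed = deepcopy(dim_partition_dict)
    for dim_index in list(dim_partition_dict):
        if shape[dim_index] == 1:
            # a size-1 (broadcast) dimension stays replicated, i.e. it is recorded as R
            processed.pop(dim_index)

    for dim_index, mesh_dims in processed.items():
        sharding_size = reduce(operator.mul, (mesh_shape[m] for m in mesh_dims), 1)
        assert shape[dim_index] % sharding_size == 0, \
            f'cannot shard dimension {dim_index} into {sharding_size} partitions'
    return tuple(shape), processed


# a hypothetical condition of shape (32,) broadcast against a (16, 32) output on a 2x2 mesh,
# with the output sharded as [S0, S1] (tensor dim 0 on mesh axis 0, tensor dim 1 on mesh axis 1)
shape, processed = pad_and_filter_dim_partition((32,), 2, {0: [0], 1: [1]}, (2, 2))
print(shape)      # (1, 32)
print(processed)  # {1: [1]} -- the sharding on the padded size-1 dim is dropped, so the spec becomes [R, S1]
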
@@ -0,0 +1,65 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

    def forward(self, condition, x, y):
        output = torch.where(condition, x, y)

        return output


@pytest.mark.skip("temporarily skipped")
def test_where_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {
        'condition': torch.rand(16, 32).to('meta'),
        'x': torch.rand(16, 32).to('meta'),
        'y': torch.rand(16, 32).to('meta')
    }
    # graph():
    #     %condition : torch.Tensor [#users=1] = placeholder[target=condition]
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %y : torch.Tensor [#users=1] = placeholder[target=y]
    #     %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
    #     return where
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)

    # [condition, x, y, where, output]
    nodes = [node for node in gm.graph.nodes]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)

    strategies_constructor.build_strategies_and_cost()
    strategy_map = strategies_constructor.strategy_map
    # check the sharding strategies registered for the torch.where node
    where_node = strategy_map[nodes[3]]
    # ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]',
    #  '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]',
    #  '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]']
    assert len(where_node) == 9


if __name__ == '__main__':
    test_where_handler()
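
The test above expects exactly 9 strategies. As a quick sanity check, the sketch below re-enumerates the output dim_partition_dicts that _enumerate_all_possible_output would produce for a 2-D output on a 2x2 mesh: both mesh axes on different tensor dims, both flattened onto one dim, each axis alone on each dim, plus the fully replicated case. The helper is a plain-Python assumption about what enumerate_all_possible_2d_sharding / enumerate_all_possible_1d_sharding return, written only for illustration.

from itertools import permutations


def enumerate_dim_partitions(mesh_dim_0: int, mesh_dim_1: int, dim_size: int):
    """Illustrative enumeration of output dim_partition_dicts for a 2-D device mesh."""
    partitions = []
    # 2D cases: both mesh axes used, either on two different tensor dims or flattened onto one
    for d0, d1 in permutations(range(dim_size), 2):
        partitions.append({d0: [mesh_dim_0], d1: [mesh_dim_1]})    # e.g. [S0, S1], [S1, S0]
    for d in range(dim_size):
        partitions.append({d: [mesh_dim_0, mesh_dim_1]})           # e.g. [S01, R], [R, S01]
    # 1D cases: a single mesh axis on a single tensor dim
    for mesh_dim in (mesh_dim_0, mesh_dim_1):
        for d in range(dim_size):
            partitions.append({d: [mesh_dim]})                     # e.g. [S0, R], [R, S1]
    # fully replicated case
    partitions.append({})                                          # [R, R]
    return partitions


print(len(enumerate_dim_partitions(0, 1, 2)))    # 9, matching `assert len(where_node) == 9`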