[autoparallel] use pytree map style to process data (#1989)

pull/1999/head
YuliangLiu0306 2022-11-21 10:44:22 +08:00 committed by GitHub
parent 35e6b9ec82
commit 155891113e
7 changed files with 178 additions and 66 deletions

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from torch.fx.node import Node
@@ -7,6 +7,7 @@ from torch.fx.node import Node
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
+    ShardingSpec,
     ShardingStrategy,
     StrategiesVector,
     TrainCycleItem,
@@ -52,12 +53,14 @@ class NodeHandler(ABC):
             node_name = str(node)

             # get the current sharding spec generated by this node handler
-            # TODO: we need to check this in future
-            if not isinstance(node._meta_data, torch.Tensor):
+            # we will not compute the resharding costs for the node not counted in the strategy.
+            # And the node with tuple or list output need to be handled below.
+            node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
+            if str(node) not in node_in_strategy:
                 continue

             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]

             # get the sharding specs for this node generated
             # in its own node handler
             assert hasattr(node, 'strategies_vector'), \
@@ -68,23 +71,64 @@ class NodeHandler(ABC):
             ]

             # create data structrure to store costs
-            if op_data not in resharding_costs:
+            if node not in resharding_costs:
                 resharding_costs[node] = []

+            def _compute_resharding_cost(
+                    prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+                    current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+                    data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+                """
+                This is a helper function to compute the resharding cost for a specific strategy of a node.
+                """
+                if prev_sharding_spec is None:
+                    return TrainCycleItem(fwd=0, bwd=0, total=0)
+                elif isinstance(prev_sharding_spec, ShardingSpec):
+                    if isinstance(data, torch.nn.parameter.Parameter):
+                        # we won't compute the resharding cost for the parameters,
+                        # since the parameters will be sharded before runtime and
+                        # not converted during runtime.
+                        return TrainCycleItem(fwd=0, bwd=0, total=0)
+                    elif isinstance(data, torch.Tensor):
+                        dtype = data.dtype
+                        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+                        _, _, consistency_cost = shape_consistency_manager.shape_consistency(
+                            prev_sharding_spec, current_sharding_spec)
+
+                        resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
+                                                         bwd=consistency_cost["backward"] * size_per_elem_bytes,
+                                                         total=consistency_cost["total"] * size_per_elem_bytes)
+                        return resharding_cost
+                    else:
+                        # This raise is used to check if we have missed any type of data.
+                        # It could be merged into Parameter branch, which means we won't handle
+                        # non-tensor arguments.
+                        raise ValueError(f'Unsupported data type {type(data)}')
+                else:
+                    assert isinstance(prev_sharding_spec, (tuple, list)), \
+                        f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
+                            or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+
+                    fwd_cost = 0
+                    bwd_cost = 0
+                    total_cost = 0
+                    for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
+                            zip(prev_sharding_spec, current_sharding_spec)):
+                        item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
+                                                             data[index])
+                        fwd_cost += item_cost.fwd
+                        bwd_cost += item_cost.bwd
+                        total_cost += item_cost.total
+                    resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
+                    return resharding_cost
+
             # for each sharding spec generated by the predecessor's node handler
             # compute the resharding cost to switch to the sharding spec generated
             # by the current node handler
             for prev_sharding_spec in prev_sharding_specs:
-                if op_data.type == OperationDataType.PARAM:
-                    resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
-                else:
-                    dtype = op_data.data.dtype
-                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-                    _, _, resharding_cost = shape_consistency_manager.shape_consistency(
-                        prev_sharding_spec, current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
-                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
-                                                     total=resharding_cost["total"] * size_per_elem_bytes)
+                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                 resharding_costs[node].append(resharding_cost)
             strategy.resharding_costs = resharding_costs
             return strategy
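
The change above routes every predecessor spec through a recursive helper, so tuple or list outputs are costed element by element instead of being rejected. A minimal, self-contained sketch of that recursion pattern (ToyCycleItem, leaf_cost and the string specs below are made-up stand-ins, not the real ColossalAI classes):

from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class ToyCycleItem:
    # stand-in for TrainCycleItem
    fwd: float = 0.0
    bwd: float = 0.0
    total: float = 0.0


def leaf_cost(prev_spec: str, cur_spec: str) -> ToyCycleItem:
    # toy per-tensor cost: free if the specs already match, otherwise a unit conversion cost
    return ToyCycleItem() if prev_spec == cur_spec else ToyCycleItem(fwd=1.0, bwd=1.0, total=2.0)


def resharding_cost(prev_spec: Union[str, List, Tuple], cur_spec: Union[str, List, Tuple]) -> ToyCycleItem:
    # recurse the same way _compute_resharding_cost does above
    if prev_spec is None:
        return ToyCycleItem()
    if isinstance(prev_spec, (list, tuple)):
        items = [resharding_cost(p, c) for p, c in zip(prev_spec, cur_spec)]
        return ToyCycleItem(fwd=sum(i.fwd for i in items),
                            bwd=sum(i.bwd for i in items),
                            total=sum(i.total for i in items))
    return leaf_cost(prev_spec, cur_spec)


# a node whose meta data is a 2-tuple of tensors: only the second element needs conversion
print(resharding_cost(("S0R", "RR"), ("S0R", "S1R")).total)    # 2.0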

View File

@@ -68,32 +68,41 @@ class StrategyGenerator(ABC):
         Args:
             mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
+
+        Notes:
+            The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
+            However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as
+            list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned.
         """
         results = {}
         for op_data_name, dim_partition_dict in mapping.items():
             if op_data_name in self.op_data:
                 op_data = self.op_data[op_data_name]
-                if isinstance(op_data.data, tuple):
-                    for data in op_data.data:
-                        assert isinstance(
-                            data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
-                    sharding_spec = []
-                    for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
-                        dim_size = len(logical_shape)
-                        dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                        sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
-                                                             entire_shape=logical_shape,
-                                                             dim_partition_dict=dim_partition_dict_element)
-                        sharding_spec.append(sharding_spec_element)
-                else:
-                    assert isinstance(
-                        op_data.data, torch.Tensor
-                    ), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
-                    dim_size = len(op_data.logical_shape)
-                    dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
-                    sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                                 entire_shape=op_data.logical_shape,
-                                                 dim_partition_dict=dim_partition_dict)
+
+                def _to_sharding_spec(
+                        data: any, logical_shape: any,
+                        dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+                    """
+                    This is a recursive function to convert the dim partition dict to a ShardingSpec object.
+                    """
+                    if isinstance(data, torch.Tensor):
+                        dim_size = len(logical_shape)
+                        dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
+                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                                     entire_shape=logical_shape,
+                                                     dim_partition_dict=dim_partition_dict)
+                        return sharding_spec
+                    elif isinstance(data, (list, tuple)):
+                        sharding_spec = []
+                        for data_element, logical_shape_element, dim_partition_dict_element in zip(
+                                data, logical_shape, dim_partition_dict):
+                            sharding_spec.append(
+                                _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
+                        return sharding_spec
+                    else:
+                        return None
+
+                sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
                 results[op_data_name] = sharding_spec
         return results
@@ -285,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator):

     def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
                  predecessor_nodes: List[Node]):
-        self.op_data = operation_data_mapping
-        self.device_mesh = device_mesh
+        super().__init__(operation_data_mapping, device_mesh)
         self.predecessor_nodes = predecessor_nodes

View File

@@ -44,10 +44,20 @@ class OperationData:
     def __post_init__(self):
         # if no logical shape is specified, use the data shape as the logical shape
         if self.logical_shape is None:
-            if isinstance(self.data, torch.Tensor):
-                self.logical_shape = self.data.shape
-            elif isinstance(self.data, tuple):
-                self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
+
+            def _infer_logical_shape(data: any):
+                """
+                This function is used to infer the logical shape of the data.
+                """
+                if isinstance(data, torch.Tensor):
+                    return data.shape
+                elif isinstance(data, (tuple, list)):
+                    data_type = type(data)
+                    return data_type([_infer_logical_shape(d) for d in data])
+                else:
+                    return None
+
+            self.logical_shape = _infer_logical_shape(self.data)

     def __repr__(self) -> str:
         return f'OperationData(name={self.name}, type={self.type})'
@@ -216,8 +226,6 @@ class StrategiesVector(list):
         # fetch its input and output nodes
         # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
-        if self.node.op == 'output':
-            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
         self.successor_nodes = list(node.users.keys())

     def check_merge(self):
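
The new _infer_logical_shape walks arbitrarily nested tuples and lists and preserves the container type. Copied out as a standalone helper (torch only, no ColossalAI needed), it behaves like this:

import torch


def _infer_logical_shape(data):
    # same logic as the nested helper added to OperationData.__post_init__ above
    if isinstance(data, torch.Tensor):
        return data.shape
    elif isinstance(data, (tuple, list)):
        data_type = type(data)
        return data_type([_infer_logical_shape(d) for d in data])
    else:
        return None


print(_infer_logical_shape(torch.empty(2, 4)))                      # torch.Size([2, 4])
print(_infer_logical_shape((torch.empty(2, 4), [torch.empty(8)])))  # (torch.Size([2, 4]), [torch.Size([8])])
print(_infer_logical_shape(3.14))                                   # None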

View File

@@ -1,13 +1,14 @@
-from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
 import torch

+from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
+

 class CostGraph:
     '''
     A graph data structure to simplify the edge cost graph. It has two main functions:
     1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
        CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
     2. To reduce the searching space, we merge computationally-trivial operators, such as
        element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
        be given by the StrategiesVector depending on the type of target node and following nodes.
@@ -66,8 +67,6 @@ class CostGraph:
             children_nodes = [node for node in strategies_vector.successor_nodes]
             setattr(dst_node, 'parents', parent_nodes)
             setattr(dst_node, 'children', children_nodes)
-            # self._remove_invalid_node(dst_node, 'parents')
-            # self._remove_invalid_node(dst_node, 'children')

             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
@@ -79,14 +78,14 @@ class CostGraph:
     def merge_node(self, src_node, dst_node):
         '''
         To merge dst_node into src_node, we need to do it in following steps:

         1. For each strategy in dst_node, we need to pick an appropriate strategy
            of src_node to merge, it is important because the logical resharding costs
            between the parents node of src_node and merged node depend on the src_node
            strategies dispatching. For example, for the graph 0->1->2, after merging node 1
            into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
            x represents the picking strategy of node 1 merged into node 2 strategy 0.

         2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
            contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
            another is the origin extra costs in src_node strategy.
@@ -98,10 +97,9 @@ class CostGraph:
             src_node(Node): The node will be merged into dst_node.
             dst_node(Node): The node to integrate src_node.
         '''
-        src_node_index = dst_node.parents.index(src_node)
         # build merge_map
         merge_map = {}
-        for src_index, strategy in enumerate(src_node.strategies_vector):
+        for src_index, _ in enumerate(src_node.strategies_vector):
             min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
@@ -139,7 +137,6 @@ class CostGraph:
                 for i in range(self.node_lens[src_node]):
                     for j in range(self.node_lens[child_node]):
                         dst_strate_index = merge_map[i]
-                        # dst_strategy = dst_node.strategies_vector[dst_strate_index]
                         edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
                 if new_node_pair not in self.edge_costs:
                     self.edge_costs[new_node_pair] = edge_cost
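
A toy illustration of how merge_map is built in merge_node above: for every strategy of the node being folded in, pick the other node's strategy with the smallest resharding cost, then use that mapping to re-index the edge costs. The cost table and strategy counts below are made up for the example; the real code reads these values from the strategies' resharding costs.

INFINITY_COST = float('inf')

src_strategy_count, dst_strategy_count = 2, 3
# made-up resharding cost between src strategy i and dst strategy j
resharding = {(0, 0): 4.0, (0, 1): 0.0, (0, 2): 7.0,
              (1, 0): 1.0, (1, 1): 3.0, (1, 2): 0.5}

merge_map = {}
for src_index in range(src_strategy_count):
    min_cost, lowest_cost_index = INFINITY_COST, -1
    for dst_index in range(dst_strategy_count):
        if resharding[(src_index, dst_index)] < min_cost:
            min_cost = resharding[(src_index, dst_index)]
            lowest_cost_index = dst_index
    merge_map[src_index] = lowest_cost_index

print(merge_map)    # {0: 1, 1: 2} -> each src strategy is paired with its cheapest dst strategy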

View File

@@ -1,3 +1,4 @@
+import builtins
 import math
 import operator
 from copy import deepcopy
@@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh

 from .options import DataloaderOption, SolverOptions
@@ -49,10 +51,6 @@ class StrategiesConstructor:
         name_checklist = []
         remove_list = []
         for strategy in strategies_vector:
-            if strategy is None:
-                print(strategies_vector.node.name)
-                print(strategies_vector)
-                assert False
             if strategy.name not in name_checklist:
                 name_checklist.append(strategy.name)
             else:
@@ -64,10 +62,33 @@ class StrategiesConstructor:
         """
        This method is to build the strategy vector for each node in the computation graph.
         """

+        def _check_no_strategy_for_node(node):
+            if node.op in ('placeholder', 'get_attr', 'output'):
+                return False
+
+            def _check_no_strategy_for_data(data):
+                label = True
+                if isinstance(data, torch.Tensor):
+                    return False
+                elif isinstance(data, (tuple, list)):
+                    for d in data:
+                        label = label and _check_no_strategy_for_data(d)
+                return label
+
+            return _check_no_strategy_for_data(node._meta_data)
+
+        no_strategy_node = []
+
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
+            print(node)
+
+            if _check_no_strategy_for_node(node):
+                no_strategy_node.append(node)
+                pass
+
             # placeholder node
-            if node.op == 'placeholder':
+            elif node.op == 'placeholder':
                 if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
                     placeholder_option = 'distributed'
                 else:
@@ -80,7 +101,7 @@ class StrategiesConstructor:
                 placeholder_handler.register_strategy()

             # get_attr node
-            if node.op == 'get_attr':
+            elif node.op == 'get_attr':
                 getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
                 getattr_handler.register_strategy()
@@ -114,10 +135,19 @@ class StrategiesConstructor:
                 output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()

-            if len(strategies_vector) <= 0:
-                print(node.name)
-            assert len(strategies_vector) > 0
             self.remove_duplicated_strategy(strategies_vector)
             setattr(node, 'strategies_vector', strategies_vector)
             self.leaf_strategies.append(strategies_vector)
             self.strategy_map[node] = strategies_vector
+
+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
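
The new _check_no_strategy_for_data predicate decides which nodes are skipped: a node only gets strategies if its (possibly nested) meta data contains at least one tensor. A standalone check of that behaviour (torch only):

import torch


def _check_no_strategy_for_data(data):
    # True means "no strategy": the data contains no tensor anywhere
    label = True
    if isinstance(data, torch.Tensor):
        return False
    elif isinstance(data, (tuple, list)):
        for d in data:
            label = label and _check_no_strategy_for_data(d)
    return label


print(_check_no_strategy_for_data(torch.empty(2)))        # False -> strategies are generated
print(_check_no_strategy_for_data((3, "stride")))         # True  -> node is skipped
print(_check_no_strategy_for_data([3, torch.empty(2)]))   # False -> one nested tensor is enough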

View File

@@ -6,7 +6,7 @@ from .broadcast import (
     recover_sharding_spec_for_broadcast_shape,
 )
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import check_sharding_spec_validity, ignore_sharding_exception
+from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
 from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
@@ -19,5 +19,5 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
 ]

View File

@@ -1,11 +1,12 @@
 import functools
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

 import torch

 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

-__all__ = ['ignore_sharding_exception']
+__all__ = ['ignore_sharding_exception', 'pytree_map']


 def ignore_sharding_exception(func):
@@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
     # make sure the entire shape matches the physical tensor shape
     assert sharding_spec.entire_shape == tensor.shape, \
         f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+
+
+def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
+    """process object recursively, like pytree
+
+    Args:
+        obj (:class:`Any`): object to process
+        fn (:class:`Callable`): a function to process subobject in obj
+        process_types (:class: `type | tuple[type]`): types to determine the type to process
+        map_all (:class: `bool`): if map_all is True, then any type of element will use fn
+
+    Returns:
+        :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
+    """
+    if isinstance(obj, dict):
+        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
+    elif isinstance(obj, tuple):
+        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, list):
+        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, process_types):
+        return fn(obj)
+    else:
+        return fn(obj) if map_all else obj
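
Usage sketch for the new pytree_map helper. This assumes a ColossalAI build that includes this commit, so the re-export added to colossalai.auto_parallel.tensor_shard.utils above is importable; the data values are made up.

import torch

from colossalai.auto_parallel.tensor_shard.utils import pytree_map

data = {'out': (torch.ones(2, 3), [torch.ones(4)]), 'scale': 0.5}

# apply fn only to leaves of the given process_types; other leaves pass through unchanged
shapes = pytree_map(data, lambda t: tuple(t.shape), process_types=torch.Tensor)
print(shapes)    # {'out': ((2, 3), [(4,)]), 'scale': 0.5}

# with map_all=True every leaf goes through fn, regardless of its type
doubled = pytree_map(data, lambda x: x * 2, map_all=True)
print(doubled['scale'])    # 1.0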