[autoparallel] use pytree map style to process data (#1989)

pull/1999/head
YuliangLiu0306 2022-11-21 10:44:22 +08:00 committed by GitHub
parent 35e6b9ec82
commit 155891113e
7 changed files with 178 additions and 66 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
@@ -7,6 +7,7 @@ from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
@@ -52,12 +53,14 @@ class NodeHandler(ABC):
node_name = str(node)
# get the current sharding spec generated by this node handler
# TODO: we need to check this in future
if not isinstance(node._meta_data, torch.Tensor):
# we will not compute the resharding costs for nodes that are not counted in the strategy.
# Nodes with tuple or list output need to be handled below.
node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
if str(node) not in node_in_strategy:
continue
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
assert hasattr(node, 'strategies_vector'), \
@@ -68,23 +71,64 @@ class NodeHandler(ABC):
]
# create data structure to store costs
if op_data not in resharding_costs:
if node not in resharding_costs:
resharding_costs[node] = []
def _compute_resharding_cost(
prev_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
if prev_sharding_spec is None:
return TrainCycleItem(fwd=0, bwd=0, total=0)
elif isinstance(prev_sharding_spec, ShardingSpec):
if isinstance(data, torch.nn.parameter.Parameter):
# we won't compute the resharding cost for the parameters,
# since the parameters will be sharded before runtime and
# not converted during runtime.
return TrainCycleItem(fwd=0, bwd=0, total=0)
elif isinstance(data, torch.Tensor):
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes)
return resharding_cost
else:
# This raise is used to check whether we have missed any data type.
# It could be merged into the Parameter branch above, which would mean
# we do not handle non-tensor arguments.
raise ValueError(f'Unsupported data type {type(data)}')
else:
assert isinstance(prev_sharding_spec, (tuple, list)), \
f'prev_sharding_spec should be of type ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
fwd_cost = 0
bwd_cost = 0
total_cost = 0
for index, (prev_sharding_spec_item,
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
current_sharding_spec)):
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
data[index])
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
return resharding_cost
# for each sharding spec generated by the predecessor's node handler
# compute the resharding cost to switch to the sharding spec generated
# by the current node handler
for prev_sharding_spec in prev_sharding_specs:
if op_data.type == OperationDataType.PARAM:
resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
else:
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, resharding_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
bwd=resharding_cost["backward"] * size_per_elem_bytes,
total=resharding_cost["total"] * size_per_elem_bytes)
resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
resharding_costs[node].append(resharding_cost)
strategy.resharding_costs = resharding_costs
return strategy
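The new `_compute_resharding_cost` closure above is the pytree-style piece of this file: a single ShardingSpec is a leaf, a list/tuple of specs recurses element-wise, and the per-element TrainCycleItems are summed. A minimal standalone sketch of that recursion (not from the commit; the ShardingSpec/shape_consistency_manager leaf cost is replaced by a toy comparison, and all names are illustrative):

    from dataclasses import dataclass

    @dataclass
    class TrainCycleItem:
        fwd: float
        bwd: float
        total: float

    def compute_resharding_cost(prev_spec, cur_spec) -> TrainCycleItem:
        """Recursively accumulate costs over a single spec, a list/tuple of specs, or None."""
        if prev_spec is None:
            # no predecessor spec -> nothing to convert
            return TrainCycleItem(0, 0, 0)
        if isinstance(prev_spec, (list, tuple)):
            # pytree-style branch: recurse element-wise and sum the per-item costs
            fwd = bwd = total = 0.0
            for prev_item, cur_item in zip(prev_spec, cur_spec):
                item = compute_resharding_cost(prev_item, cur_item)
                fwd, bwd, total = fwd + item.fwd, bwd + item.bwd, total + item.total
            return TrainCycleItem(fwd, bwd, total)
        # leaf branch: stand-in for the real shape-consistency cost call
        cost = 0.0 if prev_spec == cur_spec else 1.0
        return TrainCycleItem(cost, cost, 2 * cost)

    # a tuple-output node: two elements need conversion, one already matches
    print(compute_resharding_cost(("S0", "S1", "RR"), ("RR", "RR", "RR")))
    # TrainCycleItem(fwd=2.0, bwd=2.0, total=4.0)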

View File

@@ -68,32 +68,41 @@ class StrategyGenerator(ABC):
Args:
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
Notes:
op_data.data is commonly of type torch.Tensor or torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
However, if op_data.data is of another non-iterable type, such as float or int, we should return None. If op_data.data is of an iterable type, such as
list or tuple, we should return a list of ShardingSpec objects following the same rule as above.
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
if isinstance(op_data.data, tuple):
for data in op_data.data:
assert isinstance(
data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
sharding_spec = []
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict_element)
sharding_spec.append(sharding_spec_element)
else:
assert isinstance(
op_data.data, torch.Tensor
), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
dim_size = len(op_data.logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
return sharding_spec
else:
return None
sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
results[op_data_name] = sharding_spec
return results
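The `_to_sharding_spec` helper above follows the same pattern: a tensor leaf becomes one ShardingSpec, a list/tuple recurses into a list of specs, and anything else (float, int) yields None. A rough sketch of that shape (not from the commit; ShardingSpec construction is replaced by a plain dict so the snippet runs without a DeviceMesh):

    import torch

    def to_sharding_spec_like(data, logical_shape, dim_partition_dict):
        """Tensor -> one spec, list/tuple -> list of specs, anything else -> None."""
        if isinstance(data, torch.Tensor):
            # leaf: in the real generator this builds a ShardingSpec from the logical shape
            return {"entire_shape": tuple(logical_shape), "dim_partition_dict": dim_partition_dict}
        if isinstance(data, (list, tuple)):
            # container: zip the data with its per-element logical shapes and partition dicts
            return [to_sharding_spec_like(d, s, p)
                    for d, s, p in zip(data, logical_shape, dim_partition_dict)]
        return None  # float, int, ... carry no sharding spec

    x, y = torch.rand(4, 8), torch.rand(8, 2)
    print(to_sharding_spec_like((x, y), (x.shape, y.shape), ({0: [0]}, {})))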
@@ -285,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator):
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes

View File

@@ -44,10 +44,20 @@ class OperationData:
def __post_init__(self):
# if no logical shape is specified, use the data shape as the logical shape
if self.logical_shape is None:
if isinstance(self.data, torch.Tensor):
self.logical_shape = self.data.shape
elif isinstance(self.data, tuple):
self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
def _infer_logical_shape(data: any):
"""
This function is used to infer the logical shape of the data.
"""
if isinstance(data, torch.Tensor):
return data.shape
elif isinstance(data, (tuple, list)):
data_type = type(data)
return data_type([_infer_logical_shape(d) for d in data])
else:
return None
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
return f'OperationData(name={self.name}, type={self.type})'
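A quick standalone check of the new `_infer_logical_shape` behaviour (illustrative only; the function is reproduced here so the snippet runs without colossalai): the container type is preserved and non-tensor leaves map to None.

    import torch

    def infer_logical_shape(data):
        # mirrors the nested helper above
        if isinstance(data, torch.Tensor):
            return data.shape
        elif isinstance(data, (tuple, list)):
            return type(data)([infer_logical_shape(d) for d in data])
        return None

    print(infer_logical_shape((torch.rand(2, 3), [torch.rand(4)], 1.5)))
    # (torch.Size([2, 3]), [torch.Size([4])], None)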
@@ -216,8 +226,6 @@ class StrategiesVector(list):
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
if self.node.op == 'output':
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())
def check_merge(self):

View File

@@ -1,13 +1,14 @@
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
import torch
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes.
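For readers new to the solver side: "linearizing" here means flattening the (src strategy) x (dst strategy) cost matrix of each edge into a single mapping keyed by strategy-index pairs, which the solver can consume directly. A toy sketch, not from the commit, with made-up numbers:

    resharding_matrix = [
        [0.0, 2.0, 5.0],   # src strategy 0 -> dst strategies 0..2
        [3.0, 0.0, 4.0],   # src strategy 1 -> dst strategies 0..2
    ]

    # one flat dict keyed by (src_index, dst_index), mirroring edge_costs[(src, dst)]
    edge_cost = {(i, j): c
                 for i, row in enumerate(resharding_matrix)
                 for j, c in enumerate(row)}
    print(edge_cost[(1, 2)])    # 4.0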
@@ -66,8 +67,6 @@ class CostGraph:
children_nodes = [node for node in strategies_vector.successor_nodes]
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
# self._remove_invalid_node(dst_node, 'parents')
# self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -79,14 +78,14 @@ class CostGraph:
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge, it is important because the logical resharding costs
between the parents node of src_node and merged node depend on the src_node
of src_node to merge, it is important because the logical resharding costs
between the parents node of src_node and merged node depend on the src_node
strategies dispatching. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
where x is the strategy of node 1 that was picked for strategy 0 of node 2.
2. We need to accumulate the extra costs introduced by merging nodes. The extra costs
contain two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
the other is the original extra cost in the src_node strategy.
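A toy reading of the formula in the docstring above (not from the commit): for the chain 0 -> 1 -> 2, merge_map picks one node-1 strategy per node-2 strategy, and the old (node 0, node 1) costs are relabeled onto the new (node 0, node 2) edge.

    old_edge = {  # edge_costs[(node0, node1)], keyed by (node0 strategy, node1 strategy)
        (0, 0): 1.0, (0, 1): 7.0,
        (1, 0): 3.0, (1, 1): 2.0,
    }
    merge_map = {0: 1, 1: 0}  # node2 strategy -> chosen node1 strategy (the "x" above)

    new_edge = {(i, j): old_edge[(i, merge_map[j])]
                for i in range(2)      # node0 strategies
                for j in range(2)}     # node2 strategies
    print(new_edge[(0, 0)])  # == old_edge[(0, merge_map[0])] == old_edge[(0, 1)] == 7.0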
@@ -98,10 +97,9 @@ class CostGraph:
src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
'''
src_node_index = dst_node.parents.index(src_node)
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
for src_index, _ in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
@@ -139,7 +137,6 @@ class CostGraph:
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost

View File

@@ -1,3 +1,4 @@
import builtins
import math
import operator
from copy import deepcopy
@@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh
from .options import DataloaderOption, SolverOptions
@@ -49,10 +51,6 @@ class StrategiesConstructor:
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy is None:
print(strategies_vector.node.name)
print(strategies_vector)
assert False
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
@@ -64,10 +62,33 @@
"""
This method is to build the strategy vector for each node in the computation graph.
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
return False
def _check_no_strategy_for_data(data):
label = True
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (tuple, list)):
for d in data:
label = label and _check_no_strategy_for_data(d)
return label
return _check_no_strategy_for_data(node._meta_data)
no_strategy_node = []
for node in self.nodes:
strategies_vector = StrategiesVector(node)
print(node)
if _check_no_strategy_for_node(node):
no_strategy_node.append(node)
pass
# placeholder node
if node.op == 'placeholder':
elif node.op == 'placeholder':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
else:
@@ -80,7 +101,7 @@
placeholder_handler.register_strategy()
# get_attr node
if node.op == 'get_attr':
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
getattr_handler.register_strategy()
@@ -114,10 +135,19 @@
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
if len(strategies_vector) <= 0:
print(node.name)
assert len(strategies_vector) > 0
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
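The `_check_no_strategy_for_node` predicate and the clean-up loop above are what prune tensor-free nodes from leaf_strategies and strategy_map. A standalone check of the nested predicate (reproduced here so it runs without colossalai; illustrative only):

    import torch

    def check_no_strategy_for_data(data):
        # "no strategy" only when nothing in the (possibly nested) meta data is a tensor
        label = True
        if isinstance(data, torch.Tensor):
            return False
        elif isinstance(data, (tuple, list)):
            for d in data:
                label = label and check_no_strategy_for_data(d)
        return label

    print(check_no_strategy_for_data(3))                    # True  -> node is skipped
    print(check_no_strategy_for_data((3, torch.rand(2))))   # False -> at least one tensor leaf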

View File

@@ -6,7 +6,7 @@ from .broadcast import (
recover_sharding_spec_for_broadcast_shape,
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
@@ -19,5 +19,5 @@ __all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity',
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
]

View File

@@ -1,11 +1,12 @@
import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
__all__ = ['ignore_sharding_exception']
__all__ = ['ignore_sharding_exception', 'pytree_map']
def ignore_sharding_exception(func):
@@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
# make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape, \
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types (:class: `type | tuple[type]`): types to determine the type to process
map_all (:class: `bool`): if map_all is True, then any type of element will use fn
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
"""
if isinstance(obj, dict):
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
elif isinstance(obj, tuple):
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, list):
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, process_types):
return fn(obj)
else:
return fn(obj) if map_all else obj
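A short usage sketch for the new `pytree_map` helper, assuming this commit is installed so the re-export from colossalai.auto_parallel.tensor_shard.utils is available; the data values are made up:

    import torch
    from colossalai.auto_parallel.tensor_shard.utils import pytree_map  # exported by this commit

    data = {"weights": (torch.rand(2, 3), torch.rand(3)), "step": 7}

    # only tensor leaves are mapped; other leaves pass through untouched
    shapes = pytree_map(data, lambda t: t.shape, process_types=torch.Tensor)
    print(shapes)   # {'weights': (torch.Size([2, 3]), torch.Size([3])), 'step': 7}

    # map_all=True applies fn to every leaf regardless of type
    print(pytree_map(data, str, map_all=True))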