mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] use pytree map style to process data (#1989)
parent
35e6b9ec82
commit
155891113e
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
@ -7,6 +7,7 @@ from torch.fx.node import Node
|
|||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingSpec,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
|
@ -52,12 +53,14 @@ class NodeHandler(ABC):
|
|||
node_name = str(node)
|
||||
# get the current sharding spec generated by this node handler
|
||||
|
||||
# TODO: we need to check this in future
|
||||
if not isinstance(node._meta_data, torch.Tensor):
|
||||
# we will not compute the resharding costs for the node not counted in the strategy.
|
||||
# And the node with tuple or list output need to be handled below.
|
||||
node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
|
||||
if str(node) not in node_in_strategy:
|
||||
continue
|
||||
|
||||
op_data = strategy.get_op_data_by_name(node_name)
|
||||
current_sharding_spec = strategy.sharding_specs[op_data]
|
||||
|
||||
# get the sharding specs for this node generated
|
||||
# in its own node handler
|
||||
assert hasattr(node, 'strategies_vector'), \
|
||||
|
@ -68,23 +71,64 @@ class NodeHandler(ABC):
|
|||
]
|
||||
|
||||
# create data structrure to store costs
|
||||
if op_data not in resharding_costs:
|
||||
if node not in resharding_costs:
|
||||
resharding_costs[node] = []
|
||||
|
||||
def _compute_resharding_cost(
|
||||
prev_sharding_spec: Union[ShardingSpec,
|
||||
List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
|
||||
List[ShardingSpec]],
|
||||
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
|
||||
"""
|
||||
This is a helper function to compute the resharding cost for a specific strategy of a node.
|
||||
"""
|
||||
if prev_sharding_spec is None:
|
||||
return TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||
elif isinstance(prev_sharding_spec, ShardingSpec):
|
||||
if isinstance(data, torch.nn.parameter.Parameter):
|
||||
# we won't compute the resharding cost for the parameters,
|
||||
# since the parameters will be sharded before runtime and
|
||||
# not converted during runtime.
|
||||
return TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
dtype = data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
|
||||
prev_sharding_spec, current_sharding_spec)
|
||||
|
||||
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
|
||||
bwd=consistency_cost["backward"] * size_per_elem_bytes,
|
||||
total=consistency_cost["total"] * size_per_elem_bytes)
|
||||
return resharding_cost
|
||||
else:
|
||||
# This raise is used to check if we have missed any type of data.
|
||||
# It could be merged into Parameter branch, which means we won't handle
|
||||
# non-tensor arguments.
|
||||
raise ValueError(f'Unsupported data type {type(data)}')
|
||||
else:
|
||||
assert isinstance(prev_sharding_spec, (tuple, list)), \
|
||||
f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
|
||||
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
|
||||
|
||||
fwd_cost = 0
|
||||
bwd_cost = 0
|
||||
total_cost = 0
|
||||
for index, (prev_sharding_spec_item,
|
||||
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
|
||||
current_sharding_spec)):
|
||||
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
|
||||
data[index])
|
||||
fwd_cost += item_cost.fwd
|
||||
bwd_cost += item_cost.bwd
|
||||
total_cost += item_cost.total
|
||||
resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
|
||||
return resharding_cost
|
||||
|
||||
# for each sharding spec generated by the predecessor's node handler
|
||||
# compute the resharding cost to switch to the sharding spec generated
|
||||
# by the current node handler
|
||||
for prev_sharding_spec in prev_sharding_specs:
|
||||
if op_data.type == OperationDataType.PARAM:
|
||||
resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||
else:
|
||||
dtype = op_data.data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
_, _, resharding_cost = shape_consistency_manager.shape_consistency(
|
||||
prev_sharding_spec, current_sharding_spec)
|
||||
resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
|
||||
bwd=resharding_cost["backward"] * size_per_elem_bytes,
|
||||
total=resharding_cost["total"] * size_per_elem_bytes)
|
||||
resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
|
||||
resharding_costs[node].append(resharding_cost)
|
||||
strategy.resharding_costs = resharding_costs
|
||||
return strategy
|
||||
|
|
|
@ -68,32 +68,41 @@ class StrategyGenerator(ABC):
|
|||
|
||||
Args:
|
||||
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
|
||||
|
||||
Notes:
|
||||
The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
|
||||
However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as
|
||||
list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned.
|
||||
"""
|
||||
results = {}
|
||||
for op_data_name, dim_partition_dict in mapping.items():
|
||||
if op_data_name in self.op_data:
|
||||
op_data = self.op_data[op_data_name]
|
||||
if isinstance(op_data.data, tuple):
|
||||
for data in op_data.data:
|
||||
assert isinstance(
|
||||
data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
|
||||
sharding_spec = []
|
||||
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
|
||||
|
||||
def _to_sharding_spec(
|
||||
data: any, logical_shape: any,
|
||||
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
|
||||
"""
|
||||
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
dim_size = len(logical_shape)
|
||||
dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
|
||||
sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=logical_shape,
|
||||
dim_partition_dict=dim_partition_dict_element)
|
||||
sharding_spec.append(sharding_spec_element)
|
||||
else:
|
||||
assert isinstance(
|
||||
op_data.data, torch.Tensor
|
||||
), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
|
||||
dim_size = len(op_data.logical_shape)
|
||||
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=op_data.logical_shape,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=logical_shape,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
return sharding_spec
|
||||
elif isinstance(data, (list, tuple)):
|
||||
sharding_spec = []
|
||||
for data_element, logical_shape_element, dim_partition_dict_element in zip(
|
||||
data, logical_shape, dim_partition_dict):
|
||||
sharding_spec.append(
|
||||
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
|
||||
return sharding_spec
|
||||
else:
|
||||
return None
|
||||
|
||||
sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
|
||||
results[op_data_name] = sharding_spec
|
||||
return results
|
||||
|
||||
|
@ -285,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator):
|
|||
|
||||
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
|
||||
predecessor_nodes: List[Node]):
|
||||
self.op_data = operation_data_mapping
|
||||
self.device_mesh = device_mesh
|
||||
super().__init__(operation_data_mapping, device_mesh)
|
||||
self.predecessor_nodes = predecessor_nodes
|
||||
|
|
|
@ -44,10 +44,20 @@ class OperationData:
|
|||
def __post_init__(self):
|
||||
# if no logical shape is specified, use the data shape as the logical shape
|
||||
if self.logical_shape is None:
|
||||
if isinstance(self.data, torch.Tensor):
|
||||
self.logical_shape = self.data.shape
|
||||
elif isinstance(self.data, tuple):
|
||||
self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
|
||||
|
||||
def _infer_logical_shape(data: any):
|
||||
"""
|
||||
This function is used to infer the logical shape of the data.
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.shape
|
||||
elif isinstance(data, (tuple, list)):
|
||||
data_type = type(data)
|
||||
return data_type([_infer_logical_shape(d) for d in data])
|
||||
else:
|
||||
return None
|
||||
|
||||
self.logical_shape = _infer_logical_shape(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'OperationData(name={self.name}, type={self.type})'
|
||||
|
@ -216,8 +226,6 @@ class StrategiesVector(list):
|
|||
# fetch its input and output nodes
|
||||
# TODO: placeholder input nodes
|
||||
self.predecessor_nodes = list(node._input_nodes.keys())
|
||||
if self.node.op == 'output':
|
||||
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
|
||||
self.successor_nodes = list(node.users.keys())
|
||||
|
||||
def check_merge(self):
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
|
||||
|
||||
class CostGraph:
|
||||
'''
|
||||
A graph data structure to simplify the edge cost graph. It has two main functions:
|
||||
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
|
||||
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
|
||||
2. To reduce the searching space, we merge computationally-trivial operators, such as
|
||||
2. To reduce the searching space, we merge computationally-trivial operators, such as
|
||||
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
|
||||
be given by the StrategiesVector depending on the type of target node and following nodes.
|
||||
|
||||
|
@ -66,8 +67,6 @@ class CostGraph:
|
|||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
# self._remove_invalid_node(dst_node, 'parents')
|
||||
# self._remove_invalid_node(dst_node, 'children')
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
|
@ -79,14 +78,14 @@ class CostGraph:
|
|||
def merge_node(self, src_node, dst_node):
|
||||
'''
|
||||
To merge dst_node into src_node, we need to do it in following steps:
|
||||
|
||||
|
||||
1. For each strategy in dst_node, we need to pick an appropriate strategy
|
||||
of src_node to merge, it is important because the logical resharding costs
|
||||
between the parents node of src_node and merged node depend on the src_node
|
||||
of src_node to merge, it is important because the logical resharding costs
|
||||
between the parents node of src_node and merged node depend on the src_node
|
||||
strategies dispatching. For example, for the graph 0->1->2, after merging node 1
|
||||
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
|
||||
x represents the picking strategy of node 1 merged into node 2 strategy 0.
|
||||
|
||||
|
||||
2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
|
||||
contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
|
||||
another is the origin extra costs in src_node strategy.
|
||||
|
@ -98,10 +97,9 @@ class CostGraph:
|
|||
src_node(Node): The node will be merged into dst_node.
|
||||
dst_node(Node): The node to integrate src_node.
|
||||
'''
|
||||
src_node_index = dst_node.parents.index(src_node)
|
||||
# build merge_map
|
||||
merge_map = {}
|
||||
for src_index, strategy in enumerate(src_node.strategies_vector):
|
||||
for src_index, _ in enumerate(src_node.strategies_vector):
|
||||
min_cost = INFINITY_COST
|
||||
lowest_cost_index = -1
|
||||
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
|
||||
|
@ -139,7 +137,6 @@ class CostGraph:
|
|||
for i in range(self.node_lens[src_node]):
|
||||
for j in range(self.node_lens[child_node]):
|
||||
dst_strate_index = merge_map[i]
|
||||
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
|
||||
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
|
||||
if new_node_pair not in self.edge_costs:
|
||||
self.edge_costs[new_node_pair] = edge_cost
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import builtins
|
||||
import math
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
|
@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
|
|||
operator_registry,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from .options import DataloaderOption, SolverOptions
|
||||
|
@ -49,10 +51,6 @@ class StrategiesConstructor:
|
|||
name_checklist = []
|
||||
remove_list = []
|
||||
for strategy in strategies_vector:
|
||||
if strategy is None:
|
||||
print(strategies_vector.node.name)
|
||||
print(strategies_vector)
|
||||
assert False
|
||||
if strategy.name not in name_checklist:
|
||||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
|
@ -64,10 +62,33 @@ class StrategiesConstructor:
|
|||
"""
|
||||
This method is to build the strategy vector for each node in the computation graph.
|
||||
"""
|
||||
|
||||
def _check_no_strategy_for_node(node):
|
||||
if node.op in ('placeholder', 'get_attr', 'output'):
|
||||
return False
|
||||
|
||||
def _check_no_strategy_for_data(data):
|
||||
label = True
|
||||
if isinstance(data, torch.Tensor):
|
||||
return False
|
||||
elif isinstance(data, (tuple, list)):
|
||||
for d in data:
|
||||
label = label and _check_no_strategy_for_data(d)
|
||||
return label
|
||||
|
||||
return _check_no_strategy_for_data(node._meta_data)
|
||||
|
||||
no_strategy_node = []
|
||||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
|
||||
print(node)
|
||||
if _check_no_strategy_for_node(node):
|
||||
no_strategy_node.append(node)
|
||||
pass
|
||||
|
||||
# placeholder node
|
||||
if node.op == 'placeholder':
|
||||
elif node.op == 'placeholder':
|
||||
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
|
||||
placeholder_option = 'distributed'
|
||||
else:
|
||||
|
@ -80,7 +101,7 @@ class StrategiesConstructor:
|
|||
placeholder_handler.register_strategy()
|
||||
|
||||
# get_attr node
|
||||
if node.op == 'get_attr':
|
||||
elif node.op == 'get_attr':
|
||||
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
|
||||
getattr_handler.register_strategy()
|
||||
|
||||
|
@ -114,10 +135,19 @@ class StrategiesConstructor:
|
|||
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
|
||||
output_handler.register_strategy()
|
||||
|
||||
if len(strategies_vector) <= 0:
|
||||
print(node.name)
|
||||
assert len(strategies_vector) > 0
|
||||
self.remove_duplicated_strategy(strategies_vector)
|
||||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
# remove no strategy nodes
|
||||
remove_list = []
|
||||
for strategies_vector in self.leaf_strategies:
|
||||
if len(strategies_vector) == 0:
|
||||
remove_list.append(strategies_vector.node)
|
||||
|
||||
for node in remove_list:
|
||||
if node.strategies_vector in self.leaf_strategies:
|
||||
self.leaf_strategies.remove(node.strategies_vector)
|
||||
if node in self.strategy_map:
|
||||
self.strategy_map.pop(node)
|
||||
|
|
|
@ -6,7 +6,7 @@ from .broadcast import (
|
|||
recover_sharding_spec_for_broadcast_shape,
|
||||
)
|
||||
from .factory import generate_resharding_costs, generate_sharding_spec
|
||||
from .misc import check_sharding_spec_validity, ignore_sharding_exception
|
||||
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
|
||||
from .sharding import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
|
@ -19,5 +19,5 @@ __all__ = [
|
|||
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
|
||||
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
|
||||
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
|
||||
]
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import functools
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||
|
||||
__all__ = ['ignore_sharding_exception']
|
||||
__all__ = ['ignore_sharding_exception', 'pytree_map']
|
||||
|
||||
|
||||
def ignore_sharding_exception(func):
|
||||
|
@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
|
|||
# make sure the entire shape matches the physical tensor shape
|
||||
assert sharding_spec.entire_shape == tensor.shape, \
|
||||
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
|
||||
|
||||
|
||||
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
|
||||
"""process object recursively, like pytree
|
||||
|
||||
Args:
|
||||
obj (:class:`Any`): object to process
|
||||
fn (:class:`Callable`): a function to process subobject in obj
|
||||
process_types (:class: `type | tuple[type]`): types to determine the type to process
|
||||
map_all (:class: `bool`): if map_all is True, then any type of element will use fn
|
||||
|
||||
Returns:
|
||||
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
|
||||
elif isinstance(obj, list):
|
||||
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
|
||||
elif isinstance(obj, process_types):
|
||||
return fn(obj)
|
||||
else:
|
||||
return fn(obj) if map_all else obj
|
||||
|
|
Loading…
Reference in New Issue