mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] use pytree map style to process data (#1989)
parent 35e6b9ec82
commit 155891113e
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from torch.fx.node import Node
@@ -7,6 +7,7 @@ from torch.fx.node import Node
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
+    ShardingSpec,
     ShardingStrategy,
     StrategiesVector,
     TrainCycleItem,
@@ -52,12 +53,14 @@ class NodeHandler(ABC):
             node_name = str(node)
             # get the current sharding spec generated by this node handler

-            # TODO: we need to check this in future
-            if not isinstance(node._meta_data, torch.Tensor):
+            # we will not compute the resharding costs for the node not counted in the strategy.
+            # And the node with tuple or list output need to be handled below.
+            node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
+            if str(node) not in node_in_strategy:
                 continue

             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]

             # get the sharding specs for this node generated
             # in its own node handler
             assert hasattr(node, 'strategies_vector'), \
@@ -68,23 +71,64 @@ class NodeHandler(ABC):
             ]

             # create data structrure to store costs
-            if op_data not in resharding_costs:
+            if node not in resharding_costs:
                 resharding_costs[node] = []

+            def _compute_resharding_cost(
+                    prev_sharding_spec: Union[ShardingSpec,
+                                              List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
+                                                                                                List[ShardingSpec]],
+                    data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+                """
+                This is a helper function to compute the resharding cost for a specific strategy of a node.
+                """
+                if prev_sharding_spec is None:
+                    return TrainCycleItem(fwd=0, bwd=0, total=0)
+                elif isinstance(prev_sharding_spec, ShardingSpec):
+                    if isinstance(data, torch.nn.parameter.Parameter):
+                        # we won't compute the resharding cost for the parameters,
+                        # since the parameters will be sharded before runtime and
+                        # not converted during runtime.
+                        return TrainCycleItem(fwd=0, bwd=0, total=0)
+                    elif isinstance(data, torch.Tensor):
+                        dtype = data.dtype
+                        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+                        _, _, consistency_cost = shape_consistency_manager.shape_consistency(
+                            prev_sharding_spec, current_sharding_spec)
+
+                        resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
+                                                         bwd=consistency_cost["backward"] * size_per_elem_bytes,
+                                                         total=consistency_cost["total"] * size_per_elem_bytes)
+                        return resharding_cost
+                    else:
+                        # This raise is used to check if we have missed any type of data.
+                        # It could be merged into Parameter branch, which means we won't handle
+                        # non-tensor arguments.
+                        raise ValueError(f'Unsupported data type {type(data)}')
+                else:
+                    assert isinstance(prev_sharding_spec, (tuple, list)), \
+                        f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
+                            or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+
+                    fwd_cost = 0
+                    bwd_cost = 0
+                    total_cost = 0
+                    for index, (prev_sharding_spec_item,
+                                current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
+                                                                             current_sharding_spec)):
+                        item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
+                                                             data[index])
+                        fwd_cost += item_cost.fwd
+                        bwd_cost += item_cost.bwd
+                        total_cost += item_cost.total
+                    resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
+                    return resharding_cost
+
             # for each sharding spec generated by the predecessor's node handler
             # compute the resharding cost to switch to the sharding spec generated
             # by the current node handler
             for prev_sharding_spec in prev_sharding_specs:
-                if op_data.type == OperationDataType.PARAM:
-                    resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
-                else:
-                    dtype = op_data.data.dtype
-                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-                    _, _, resharding_cost = shape_consistency_manager.shape_consistency(
-                        prev_sharding_spec, current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
-                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
-                                                     total=resharding_cost["total"] * size_per_elem_bytes)
+                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
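The hunk above replaces the flat PARAM/tensor branch with a helper that recurses over tuple or list outputs and sums the per-element forward/backward/total costs. A minimal standalone sketch of that accumulation pattern follows; the Cost container and leaf_cost function are simplified stand-ins for TrainCycleItem and the shape-consistency query, not ColossalAI APIs.

# Standalone sketch of the fwd/bwd/total accumulation over nested sharding specs.
from dataclasses import dataclass
from typing import Any, Sequence, Union

@dataclass
class Cost:
    fwd: float = 0.0
    bwd: float = 0.0
    total: float = 0.0

def leaf_cost(prev_spec: Any, cur_spec: Any, numel: int, elem_bytes: int = 4) -> Cost:
    # placeholder for the shape-consistency query: charge a byte count only when specs differ
    comm = 0.0 if prev_spec == cur_spec else numel * elem_bytes
    return Cost(fwd=comm, bwd=comm, total=2 * comm)

def nested_cost(prev_spec: Union[Any, Sequence], cur_spec: Union[Any, Sequence],
                numel: Union[int, Sequence]) -> Cost:
    # mirror of the recursion: a single spec is a leaf, a tuple/list is folded element-wise
    if not isinstance(prev_spec, (tuple, list)):
        return leaf_cost(prev_spec, cur_spec, numel)
    acc = Cost()
    for p, c, n in zip(prev_spec, cur_spec, numel):
        item = nested_cost(p, c, n)
        acc.fwd += item.fwd
        acc.bwd += item.bwd
        acc.total += item.total
    return acc

print(nested_cost(("S0", "RR"), ("RR", "RR"), (1024, 2048)))    # only the first element pays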
@@ -68,32 +68,41 @@ class StrategyGenerator(ABC):

         Args:
             mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
+
+        Notes:
+            The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
+            However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as
+            list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned.
         """
         results = {}
         for op_data_name, dim_partition_dict in mapping.items():
             if op_data_name in self.op_data:
                 op_data = self.op_data[op_data_name]
-                if isinstance(op_data.data, tuple):
-                    for data in op_data.data:
-                        assert isinstance(
-                            data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
-                    sharding_spec = []
-                    for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
-                        dim_size = len(logical_shape)
-                        dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                        sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
-                                                             entire_shape=logical_shape,
-                                                             dim_partition_dict=dim_partition_dict_element)
-                        sharding_spec.append(sharding_spec_element)
-                else:
-                    assert isinstance(
-                        op_data.data, torch.Tensor
-                    ), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
-                    dim_size = len(op_data.logical_shape)
-                    dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
-                    sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                                 entire_shape=op_data.logical_shape,
-                                                 dim_partition_dict=dim_partition_dict)
+
+                def _to_sharding_spec(
+                        data: any, logical_shape: any,
+                        dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+                    """
+                    This is a recursive function to convert the dim partition dict to a ShardingSpec object.
+                    """
+                    if isinstance(data, torch.Tensor):
+                        dim_size = len(logical_shape)
+                        dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
+                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                                     entire_shape=logical_shape,
+                                                     dim_partition_dict=dim_partition_dict)
+                        return sharding_spec
+                    elif isinstance(data, (list, tuple)):
+                        sharding_spec = []
+                        for data_element, logical_shape_element, dim_partition_dict_element in zip(
+                                data, logical_shape, dim_partition_dict):
+                            sharding_spec.append(
+                                _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
+                        return sharding_spec
+                    else:
+                        return None
+
+                sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
                 results[op_data_name] = sharding_spec
         return results
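The same pytree-style recursion is used when turning a dim-partition mapping into sharding specs: a tensor operand yields one spec, a tuple or list operand yields a list of specs, and anything else yields None. A rough sketch under those assumptions; make_spec below is a plain-dict stand-in for ShardingSpec, not the real class.

# Sketch of the recursive spec construction for nested operands.
import torch

def make_spec(shape, dim_partition_dict):
    return {'entire_shape': tuple(shape), 'dim_partition_dict': dim_partition_dict}

def to_spec(data, logical_shape, dim_partition_dict):
    if isinstance(data, torch.Tensor):
        return make_spec(logical_shape, dim_partition_dict)
    elif isinstance(data, (list, tuple)):
        return [to_spec(d, s, p) for d, s, p in zip(data, logical_shape, dim_partition_dict)]
    else:
        return None    # scalars and other non-tensor operands carry no sharding spec

data = (torch.empty(4, 8), torch.empty(8, 2))
shapes = (torch.Size([4, 8]), torch.Size([8, 2]))
partitions = ({0: [0]}, {})
print(to_spec(data, shapes, partitions))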
@@ -285,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator):

     def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
                  predecessor_nodes: List[Node]):
-        self.op_data = operation_data_mapping
-        self.device_mesh = device_mesh
+        super().__init__(operation_data_mapping, device_mesh)
         self.predecessor_nodes = predecessor_nodes
@@ -44,10 +44,20 @@ class OperationData:
     def __post_init__(self):
         # if no logical shape is specified, use the data shape as the logical shape
         if self.logical_shape is None:
-            if isinstance(self.data, torch.Tensor):
-                self.logical_shape = self.data.shape
-            elif isinstance(self.data, tuple):
-                self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
+
+            def _infer_logical_shape(data: any):
+                """
+                This function is used to infer the logical shape of the data.
+                """
+                if isinstance(data, torch.Tensor):
+                    return data.shape
+                elif isinstance(data, (tuple, list)):
+                    data_type = type(data)
+                    return data_type([_infer_logical_shape(d) for d in data])
+                else:
+                    return None
+
+            self.logical_shape = _infer_logical_shape(self.data)

     def __repr__(self) -> str:
         return f'OperationData(name={self.name}, type={self.type})'
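With the change above, OperationData.logical_shape mirrors the nesting of the data itself. A short illustration of the inferred shapes (a standalone restatement of _infer_logical_shape, shown only to document the expected output):

import torch

def infer_logical_shape(data):
    if isinstance(data, torch.Tensor):
        return data.shape
    elif isinstance(data, (tuple, list)):
        return type(data)([infer_logical_shape(d) for d in data])
    else:
        return None    # e.g. int/float arguments

print(infer_logical_shape(torch.empty(2, 3)))                         # torch.Size([2, 3])
print(infer_logical_shape((torch.empty(2, 3), [torch.empty(4)], 1)))  # (torch.Size([2, 3]), [torch.Size([4])], None)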
@@ -216,8 +226,6 @@ class StrategiesVector(list):
         # fetch its input and output nodes
         # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
-        if self.node.op == 'output':
-            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
         self.successor_nodes = list(node.users.keys())

     def check_merge(self):
@@ -1,13 +1,14 @@
-from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
 import torch

+from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
+

 class CostGraph:
     '''
     A graph data structure to simplify the edge cost graph. It has two main functions:
     1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
     CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
     2. To reduce the searching space, we merge computationally-trivial operators, such as
     element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
     be given by the StrategiesVector depending on the type of target node and following nodes.
@@ -66,8 +67,6 @@ class CostGraph:
             children_nodes = [node for node in strategies_vector.successor_nodes]
             setattr(dst_node, 'parents', parent_nodes)
             setattr(dst_node, 'children', children_nodes)
-            # self._remove_invalid_node(dst_node, 'parents')
-            # self._remove_invalid_node(dst_node, 'children')

             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
@@ -79,14 +78,14 @@
     def merge_node(self, src_node, dst_node):
         '''
         To merge dst_node into src_node, we need to do it in following steps:

         1. For each strategy in dst_node, we need to pick an appropriate strategy
         of src_node to merge, it is important because the logical resharding costs
         between the parents node of src_node and merged node depend on the src_node
         strategies dispatching. For example, for the graph 0->1->2, after merging node 1
         into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
         x represents the picking strategy of node 1 merged into node 2 strategy 0.

         2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
         contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
         another is the origin extra costs in src_node strategy.
@@ -98,10 +97,9 @@
             src_node(Node): The node will be merged into dst_node.
             dst_node(Node): The node to integrate src_node.
         '''
-        src_node_index = dst_node.parents.index(src_node)
         # build merge_map
         merge_map = {}
-        for src_index, strategy in enumerate(src_node.strategies_vector):
+        for src_index, _ in enumerate(src_node.strategies_vector):
             min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
@@ -139,7 +137,6 @@
                 for i in range(self.node_lens[src_node]):
                     for j in range(self.node_lens[child_node]):
                         dst_strate_index = merge_map[i]
-                        # dst_strategy = dst_node.strategies_vector[dst_strate_index]
                         edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
             if new_node_pair not in self.edge_costs:
                 self.edge_costs[new_node_pair] = edge_cost
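A toy numeric illustration of the edge-cost relabeling described in merge_node's docstring for the chain 0 -> 1 -> 2, where node 1 is merged into node 2. The dictionary layout and the hypothetical merge_map below are simplifications of CostGraph's internals, and the extra-cost accumulation mentioned in step 2 is left out.

# Toy relabeling: new edge (0, 2) reuses old edge (0, 1) costs via the picked node-1 strategy.
old_edge_costs = {(0, 1): {(i, x): float(i + 10 * x) for i in range(2) for x in range(3)}}
merge_map = {0: 2, 1: 0}    # node-2 strategy j -> node-1 strategy x picked during the merge

new_edge_costs = {(0, 2): {}}
for i in range(2):          # strategies of node 0
    for j in range(2):      # strategies of node 2 (the merged node)
        x = merge_map[j]
        new_edge_costs[(0, 2)][(i, j)] = old_edge_costs[(0, 1)][(i, x)]

print(new_edge_costs[(0, 2)])   # {(0, 0): 20.0, (0, 1): 0.0, (1, 0): 21.0, (1, 1): 1.0}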
@@ -1,3 +1,4 @@
+import builtins
 import math
 import operator
 from copy import deepcopy
@@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh

 from .options import DataloaderOption, SolverOptions
@@ -49,10 +51,6 @@ class StrategiesConstructor:
         name_checklist = []
         remove_list = []
         for strategy in strategies_vector:
-            if strategy is None:
-                print(strategies_vector.node.name)
-                print(strategies_vector)
-                assert False
             if strategy.name not in name_checklist:
                 name_checklist.append(strategy.name)
             else:
@@ -64,10 +62,33 @@ class StrategiesConstructor:
         """
         This method is to build the strategy vector for each node in the computation graph.
         """
+
+        def _check_no_strategy_for_node(node):
+            if node.op in ('placeholder', 'get_attr', 'output'):
+                return False
+
+            def _check_no_strategy_for_data(data):
+                label = True
+                if isinstance(data, torch.Tensor):
+                    return False
+                elif isinstance(data, (tuple, list)):
+                    for d in data:
+                        label = label and _check_no_strategy_for_data(d)
+                return label
+
+            return _check_no_strategy_for_data(node._meta_data)
+
+        no_strategy_node = []
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
+
+            print(node)
+            if _check_no_strategy_for_node(node):
+                no_strategy_node.append(node)
+                pass
+
             # placeholder node
-            if node.op == 'placeholder':
+            elif node.op == 'placeholder':
                 if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
                     placeholder_option = 'distributed'
                 else:
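The nested _check_no_strategy_for_data helper added above is the pytree-style test that decides whether a node is skipped. A standalone restatement with the expected results on a few meta-data values:

import torch

def check_no_strategy_for_data(data):
    label = True
    if isinstance(data, torch.Tensor):
        return False
    elif isinstance(data, (tuple, list)):
        for d in data:
            label = label and check_no_strategy_for_data(d)
    return label

print(check_no_strategy_for_data(torch.empty(2, 2)))      # False: tensor output gets strategies
print(check_no_strategy_for_data((torch.empty(2), 3)))    # False: at least one tensor in the tuple
print(check_no_strategy_for_data(3))                      # True: plain int, node is skipped
print(check_no_strategy_for_data([]))                     # True: nothing shardable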
@@ -80,7 +101,7 @@ class StrategiesConstructor:
                 placeholder_handler.register_strategy()

             # get_attr node
-            if node.op == 'get_attr':
+            elif node.op == 'get_attr':
                 getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
                 getattr_handler.register_strategy()

@@ -114,10 +135,19 @@ class StrategiesConstructor:
                 output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()

-            if len(strategies_vector) <= 0:
-                print(node.name)
-            assert len(strategies_vector) > 0
             self.remove_duplicated_strategy(strategies_vector)
             setattr(node, 'strategies_vector', strategies_vector)
             self.leaf_strategies.append(strategies_vector)
             self.strategy_map[node] = strategies_vector
+
+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
@@ -6,7 +6,7 @@ from .broadcast import (
     recover_sharding_spec_for_broadcast_shape,
 )
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import check_sharding_spec_validity, ignore_sharding_exception
+from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
 from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
@@ -19,5 +19,5 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
 ]
@@ -1,11 +1,12 @@
 import functools
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

 import torch

 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

-__all__ = ['ignore_sharding_exception']
+__all__ = ['ignore_sharding_exception', 'pytree_map']


 def ignore_sharding_exception(func):
@@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure the entire shape matches the physical tensor shape
     assert sharding_spec.entire_shape == tensor.shape, \
         f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+
+
+def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
+    """process object recursively, like pytree
+
+    Args:
+        obj (:class:`Any`): object to process
+        fn (:class:`Callable`): a function to process subobject in obj
+        process_types (:class: `type | tuple[type]`): types to determine the type to process
+        map_all (:class: `bool`): if map_all is True, then any type of element will use fn
+
+    Returns:
+        :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
+    """
+    if isinstance(obj, dict):
+        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
+    elif isinstance(obj, tuple):
+        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, list):
+        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, process_types):
+        return fn(obj)
+    else:
+        return fn(obj) if map_all else obj
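A small usage example of the new pytree_map helper exported from colossalai.auto_parallel.tensor_shard.utils, assuming this patch is installed; the sample dictionary below is illustrative.

import torch
from colossalai.auto_parallel.tensor_shard.utils import pytree_map

meta = {'x': torch.empty(4, 8), 'y': (torch.empty(2), 5)}

# apply fn only to leaves of the given types ...
shapes = pytree_map(meta, fn=lambda t: t.shape, process_types=torch.Tensor)
# {'x': torch.Size([4, 8]), 'y': (torch.Size([2]), 5)}

# ... or to every leaf when map_all=True
as_str = pytree_map(meta, fn=str, map_all=True)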