mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add experimental view handler (#2011)
* [autoparallel] add experimental view handler
* polish
* polish
* polish code
* rename variables

parent d655eea515
commit 1438993113
@@ -0,0 +1,4 @@
from .view_generator import ViewGenerator
from .view_handler import ViewHandler

__all__ = ['ViewGenerator', 'ViewHandler']
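The experimental subpackage re-exports both classes, so downstream code can pull them in with a single import; this is the path the test at the end of this diff uses for `ViewHandler`:

```python
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewGenerator, ViewHandler
```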
@@ -0,0 +1,133 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    MemoryCost,
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['ViewGenerator']


class ViewGenerator(FollowingStrategyGenerator):
    """
    ViewGenerator deals with the sharding strategies of the view op.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            origin_shape = self.op_data['input'].data.shape
            tgt_shape = self.op_data['tgt_shape'].data

            reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)

            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)

            if keep_sharding_status:
                dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
                                                                                reshape_mapping_dict)
            else:
                dim_partition_dict_for_output = {}

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add the index into the name to pass the duplication check
            # we keep identical strategies under different names for node merging; this does not enlarge the search space,
            # because the solver merges this node into other nodes and does not create a new variable for it.
            if keep_sharding_status:
                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
            else:
                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'

                # add comm action for converting input to fully replicated
                total_mesh_dim_list = []
                for mesh_dim_list in dim_partition_dict_for_input.values():
                    total_mesh_dim_list.extend(mesh_dim_list)
                # if there is only one sharding dimension, we should use the value instead of a list as logical_process_axis.
                if len(total_mesh_dim_list) == 1:
                    total_mesh_dim_list = total_mesh_dim_list[0]
                    input_comm_action = self.get_communication_action(
                        sharding_spec=sharding_spec_mapping["input"],
                        communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                        logical_process_axis=total_mesh_dim_list,
                        comm_type=CommType.BEFORE,
                        arg_index=0)
                    input_comm_action.comm_spec.gather_dim = total_mesh_dim_list

                elif len(total_mesh_dim_list) >= 2:
                    source_spec = sharding_spec_mapping["input"]
                    target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                               entire_shape=source_spec.entire_shape,
                                               dim_partition_dict={})
                    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                    input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)

                else:
                    input_comm_action = None

                if input_comm_action is not None:
                    communication_action_mapping["input"] = input_comm_action

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
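The core of `collate_strategies` is the keep-or-replicate decision: for each predecessor strategy it detects the reshape mapping, checks whether the existing sharding survives the view, and either infers the output partition or falls back to a fully replicated output plus a communication action on the input. Below is a minimal sketch (not part of the commit) of that decision path, using only the reshape utilities this commit adds and exports from `colossalai.auto_parallel.tensor_shard.utils`; the shapes and the `{0: [0]}` sharding are taken from the test at the end of this diff, and the expected results are my reading of the keep-sharding rule documented in `reshape.py`.

```python
import torch

# reshape utilities added by this commit (export shown in the utils __init__ hunk below)
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)

# conv output sharded along its batch dim on mesh axis 0, then viewed to a 5D shape
origin_shape = torch.Size([8, 16, 64, 64])
tgt_shape = torch.Size([32, 4, 32, 32, 4])
dim_partition_dict_for_input = {0: [0]}    # i.e. [S0, R, R, R]

reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
if check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict):
    # sharding survives the view: [S0, R, R, R] -> [S0, R, R, R, R]
    dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
                                                                    reshape_mapping_dict)
else:
    # sharding cannot be preserved: the generator emits a fully replicated output
    # and prepends a gather/shape-consistency communication action on the input
    dim_partition_dict_for_output = {}

print(dim_partition_dict_for_output)    # expected: {0: [0]}
```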
@@ -0,0 +1,51 @@
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .view_generator import ViewGenerator

__all__ = ['ViewHandler']


@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
    """
    A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process

        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        target_shape = self.node._meta_data.shape
        physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "tgt_shape": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
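For orientation, here is a rough, hand-written picture (not generated by running the handler against a real fx graph) of the mapping `get_operation_data_mapping` builds for the `conv2d.view(32, 4, 32, 32, 4)` node used in the test at the end of this diff; the names, types, and shapes mirror that test's assertions, and the `OperationData` objects are constructed directly purely for illustration.

```python
import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType

mapping = {
    # the tensor being viewed, taken from node.args[0]._meta_data
    "input": OperationData(name="conv2d", type=OperationDataType.ARG,
                           data=torch.empty(8, 16, 64, 64, device='meta')),
    # the target shape, read from the view node's own meta data
    "tgt_shape": OperationData(name="tgt_shape", type=OperationDataType.ARG,
                               data=torch.Size([32, 4, 32, 32, 4])),
    # the view output, also taken from the node's meta data
    "output": OperationData(name="view", type=OperationDataType.OUTPUT,
                            data=torch.empty(32, 4, 32, 32, 4, device='meta')),
}
```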
@@ -51,6 +51,8 @@ class OperationData:
                 """
                 if isinstance(data, torch.Tensor):
                     return data.shape
+                elif isinstance(data, torch.Size):
+                    return None
                 elif isinstance(data, (tuple, list)):
                     data_type = type(data)
                     return data_type([_infer_logical_shape(d) for d in data])
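The two added lines give a `torch.Size` operand (such as the `tgt_shape` introduced by the handler above) no logical shape. A standalone sketch of the resulting dispatch, with the helper re-declared locally purely for illustration:

```python
import torch


def _infer_logical_shape(data):
    # mirrors the dispatch after this hunk
    if isinstance(data, torch.Tensor):
        return data.shape
    elif isinstance(data, torch.Size):
        # a shape operand carries no logical shape; this branch must come before the
        # tuple/list branch because torch.Size is a subclass of tuple
        return None
    elif isinstance(data, (tuple, list)):
        data_type = type(data)
        return data_type([_infer_logical_shape(d) for d in data])


print(_infer_logical_shape(torch.empty(2, 3)))       # torch.Size([2, 3])
print(_infer_logical_shape(torch.Size([32, 4])))     # None
```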
@@ -82,7 +82,6 @@ class StrategiesConstructor:
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)

-            print(node)
             if _check_no_strategy_for_node(node):
                 no_strategy_node.append(node)
                 pass
@@ -7,6 +7,7 @@ from .broadcast import (
 )
 from .factory import generate_resharding_costs, generate_sharding_spec
 from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
+from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
 from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
@@ -19,5 +20,6 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
+    'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
 ]
@@ -0,0 +1,168 @@
from enum import Enum
from typing import Dict, List, Tuple

import torch


class PreviousStatus(Enum):
    """
    This class records the result of the previous dimension-size comparison.
    """
    RESET = 0
    # ORIGIN means the dimension size of the original tensor was larger in the previous comparison.
    ORIGIN = 1
    # TGT means the dimension size of the target tensor was larger in the previous comparison.
    TGT = 2


def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
    """
    This method detects the reshape mapping between the original tensor and the target tensor.

    Returns:
        reshape_mapping_dict: A dictionary showing how each tuple of origin dims (keys) maps to the related
            target dims (values) during the reshape operation.

    Examples:
        import torch
        origin_shape = torch.Size([4, 4, 4])
        tgt_shape = torch.Size([2, 8, 2, 2])
        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
        print(reshape_mapping_dict)

    Output:
        {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
    """

    # reverse the shape object
    origin_shape = list(origin_shape)
    tgt_shape = list(tgt_shape)
    origin_shape.reverse()
    tgt_shape.reverse()

    # initialize arguments
    reshape_mapping_dict = {}
    origin_len = len(origin_shape)
    tgt_len = len(tgt_shape)
    origin_index = 0
    tgt_index = 0
    original_dimension_size = origin_shape[origin_index]
    tgt_dimension_size = tgt_shape[tgt_index]
    tgt_dims = [tgt_len - tgt_index - 1]
    origin_dims = [origin_len - origin_index - 1]
    previous_label = PreviousStatus.RESET

    while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
        if original_dimension_size == tgt_dimension_size:
            reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
            origin_index += 1
            tgt_index += 1
            # the last iteration always ends in this branch, so we skip the
            # preparation for the next iteration once the shapes are exhausted.
            if origin_index == len(origin_shape):
                continue
            original_dimension_size = origin_shape[origin_index]
            tgt_dimension_size = tgt_shape[tgt_index]
            origin_dims = [origin_len - origin_index - 1]
            tgt_dims = [tgt_len - tgt_index - 1]
            previous_label = PreviousStatus.RESET

        elif original_dimension_size > tgt_dimension_size:
            tgt_index += 1

            if previous_label == PreviousStatus.TGT:
                # if the target dimension size was larger in the previous comparison, the origin
                # dimension size has now accumulated past the target dimension size, so we offload
                # the origin dims and tgt dims into the reshape_mapping_dict.
                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
                original_dimension_size = original_dimension_size // tgt_dimension_size
                origin_dims = [origin_len - origin_index - 1]
                tgt_dimension_size = tgt_shape[tgt_index]
                tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
                # reset the previous_label after offloading the origin dims and tgt dims
                previous_label = PreviousStatus.RESET
            else:
                # accumulate tgt_dimension_size until it is larger than original_dimension_size
                tgt_dimension_size *= tgt_shape[tgt_index]
                tgt_dims.append(tgt_len - tgt_index - 1)
                previous_label = PreviousStatus.ORIGIN

        else:
            origin_index += 1

            if previous_label == PreviousStatus.ORIGIN:
                # if the origin dimension size was larger in the previous comparison, the target
                # dimension size has now accumulated past the origin dimension size, so we offload
                # the origin dims and tgt dims into the reshape_mapping_dict.
                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
                tgt_dimension_size = tgt_dimension_size // original_dimension_size
                tgt_dims = [tgt_len - tgt_index - 1]
                original_dimension_size = origin_shape[origin_index]
                origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
                # reset the previous_label after offloading the origin dims and tgt dims
                previous_label = PreviousStatus.RESET
            else:
                # accumulate original_dimension_size until it is larger than tgt_dimension_size
                original_dimension_size *= origin_shape[origin_index]
                origin_dims.append(origin_len - origin_index - 1)
                previous_label = PreviousStatus.TGT

    return reshape_mapping_dict


def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
                               reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
    """
    This method checks whether the reshape operation can be implemented without converting
    the input to the fully replicated status.

    Rule:
        For a sharded dimension of the input tensor, if it is not the minimum element of its input tuple,
        the function returns False.
        To illustrate this rule, there are two cases to analyse:
        1. no sharded dims in the input tuple: the reshape operation can be done safely, just like a normal
           operation on a non-distributed tensor.
        2. sharded dims in the input tuple: the sharded dim must be the minimum element; then, during the shape
           consistency process, torch.cat is applied on the sharded dim, and everything after the sharded
           dim gets recovered.

    Examples:
        # the second dimension of the input has been sharded.
        input_dim_partition_dict = {1: [1]}
        origin_shape = torch.Size([8, 4, 2])
        tgt_shape = torch.Size([2, 4, 8])
        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
        # {(2, 1): (2,), (0,): (1, 0)}
        # the sharded dim of the input is 1, which is the minimum element of the tuple (2, 1),
        # so we do not have to convert the input to the fully replicated status.
        print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))

    Output:
        True
    """
    sharded_dims = list(input_dim_partition_dict.keys())
    for input_dims in reshape_mapping_dict.keys():
        min_element = min(input_dims)
        for dim in input_dims:
            if dim in sharded_dims and dim != min_element:
                return False
    return True


def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
                                    reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[int, List[int]]:
    """
    This method infers the output dim partition dict for a reshape operation,
    given the input dim partition dict and the reshape mapping dict.
    """
    assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
        'the output dim partition dict is only inferred for reshape operations that can keep their sharding spec.'
    sharded_dims = list(input_dim_partition_dict.keys())
    output_dim_partition_dict = {}
    for input_dims, output_dims in reshape_mapping_dict.items():
        for dim in input_dims:
            if dim in sharded_dims:
                output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
                # we can break because the input dims cannot contain two sharded dims, otherwise
                # the keep-sharding-status check above would have failed.
                break
    return output_dim_partition_dict
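A short usage sketch (not part of the commit) chaining the three utilities on the shapes from the docstrings above; the printed mapping and the keep-sharding result are taken directly from those docstrings, and the inferred output partition follows from them. The import path assumes the export added to `colossalai.auto_parallel.tensor_shard.utils` earlier in this diff.

```python
import torch

from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)

origin_shape = torch.Size([8, 4, 2])
tgt_shape = torch.Size([2, 4, 8])
input_dim_partition_dict = {1: [1]}    # the second input dim is sharded on mesh axis 1

mapping = detect_reshape_mapping(origin_shape, tgt_shape)
print(mapping)    # {(2, 1): (2,), (0,): (1, 0)}

if check_keep_sharding_status(input_dim_partition_dict, mapping):    # True for this case
    # dim 1 is the minimum of its group (2, 1), so its sharding moves to min((2,)) = dim 2 of the output
    print(infer_output_dim_partition_dict(input_dim_partition_dict, mapping))    # expected: {2: [1]}
```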
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class ViewModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input, other):
        conv_node = nn.functional.conv2d(input, other)
        reshape_node = conv_node.view(32, 4, 32, 32, 4)
        return reshape_node


def test_view_handler():
    model = ViewModel()
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv2d, 32, 4, 32, 32, 4), kwargs = {})
    #     return view
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(8, 8, 66, 66).to('meta'),
                             "other": torch.rand(16, 8, 3, 3).to('meta'),
                         })
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    conv_mod_node = list(graph.nodes)[2]
    view_node = list(graph.nodes)[3]
    view_strategies_vector = StrategiesVector(view_node)
    conv_strategies_vector = StrategiesVector(conv_mod_node)

    # build handler
    conv_handler = ConvFunctionHandler(node=conv_mod_node,
                                       device_mesh=device_mesh,
                                       strategies_vector=conv_strategies_vector)
    conv_handler.register_strategy(compute_resharding_cost=False)
    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
    view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)

    view_handler.register_strategy(compute_resharding_cost=False)

    # check operation data mapping
    mapping = view_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    assert mapping['input'].name == "conv2d"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])

    assert mapping['output'].name == "view"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
    assert mapping['output'].type == OperationDataType.OUTPUT

    # the reshape handler is a following strategy handler, so its number of strategies equals the predecessor node's.
    assert len(view_strategies_vector) == len(conv_strategies_vector)
    strategy_name_list = [strategy.name for strategy in view_strategies_vector]
    assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
    assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
    assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
    assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
    assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
    assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
    assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
    assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
    assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
    assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
    assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
    assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
    assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
    assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
    assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
    assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list


if __name__ == '__main__':
    test_view_handler()