mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support distributed dataloader option (#1906)
* [autoparallel] support distributed dataloader option * update output handler to support ddp dataloader * polish code
pull/1978/head^2
parent
6630d45546
commit
0da1d00399
|
@ -19,5 +19,5 @@ __all__ = [
|
||||||
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
|
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
|
||||||
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
||||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler'
|
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,7 +2,9 @@ from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..sharding_strategy import OperationData, OperationDataType
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from .node_handler import NodeHandler
|
from .node_handler import NodeHandler
|
||||||
from .strategy import OutputGenerator, StrategyGenerator
|
from .strategy import OutputGenerator, StrategyGenerator
|
||||||
|
|
||||||
|
@ -14,26 +16,37 @@ class OuputHandler(NodeHandler):
|
||||||
A OuputHandler which deals with the sharding strategies for Output Node.
|
A OuputHandler which deals with the sharding strategies for Output Node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
|
||||||
|
output_option: str) -> None:
|
||||||
|
super().__init__(node, device_mesh, strategies_vector)
|
||||||
|
self.output_option = output_option
|
||||||
|
|
||||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||||
op_data_mapping = self.get_operation_data_mapping()
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
generators = []
|
generators = []
|
||||||
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
|
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option))
|
||||||
return generators
|
return generators
|
||||||
|
|
||||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
# use transposed shape for strategies
|
# use transposed shape for strategies
|
||||||
# the strategies will be transformed back to its original shape in self.post_process
|
# the strategies will be transformed back to its original shape in self.post_process
|
||||||
dummy_output = torch.empty(1,).to("meta")
|
mapping = {}
|
||||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)
|
output_meta_data = []
|
||||||
|
|
||||||
mapping = {"output": physical_output}
|
|
||||||
for index, input_node in enumerate(self.predecessor_node):
|
for index, input_node in enumerate(self.predecessor_node):
|
||||||
if not hasattr(input_node, "_meta_data"):
|
input_meta_data = input_node._meta_data
|
||||||
print(input_node.name)
|
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
|
||||||
physical_inputs = OperationData(name=str(input_node),
|
|
||||||
type=OperationDataType.ARG,
|
|
||||||
data=input_node._meta_data)
|
|
||||||
name_key = f'input_{index}'
|
name_key = f'input_{index}'
|
||||||
mapping[name_key] = physical_inputs
|
mapping[name_key] = physical_inputs
|
||||||
|
output_meta_data.append(input_meta_data)
|
||||||
|
|
||||||
|
assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
|
||||||
|
if len(output_meta_data) == 1:
|
||||||
|
output_meta_data = output_meta_data[0]
|
||||||
|
else:
|
||||||
|
output_meta_data = tuple(output_meta_data)
|
||||||
|
|
||||||
|
self.node._meta_data = output_meta_data
|
||||||
|
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||||
|
|
||||||
|
mapping["output"] = physical_output
|
||||||
return mapping
|
return mapping
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from ..sharding_strategy import OperationData, OperationDataType
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from .node_handler import NodeHandler
|
from .node_handler import NodeHandler
|
||||||
from .strategy import PlaceholderGenerator, StrategyGenerator
|
from .strategy import PlaceholderGenerator, StrategyGenerator
|
||||||
|
|
||||||
|
@ -12,10 +16,16 @@ class PlacehodlerHandler(NodeHandler):
|
||||||
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
|
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
|
||||||
|
placeholder_option: str) -> None:
|
||||||
|
super().__init__(node, device_mesh, strategies_vector)
|
||||||
|
self.placeholder_option = placeholder_option
|
||||||
|
|
||||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||||
op_data_mapping = self.get_operation_data_mapping()
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
generators = []
|
generators = []
|
||||||
generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
|
generators.append(
|
||||||
|
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
|
||||||
return generators
|
return generators
|
||||||
|
|
||||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
from torch.fx import Node
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
|
MemoryCost,
|
||||||
|
OperationData,
|
||||||
|
ShardingStrategy,
|
||||||
|
TrainCycleItem,
|
||||||
|
)
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
from .strategy_generator import OutputStrategyGenerator
|
from .strategy_generator import OutputStrategyGenerator
|
||||||
|
|
||||||
|
@ -12,6 +20,11 @@ class OutputGenerator(OutputStrategyGenerator):
|
||||||
OutputGenerator is a generic class to generate strategies for Output Node.
|
OutputGenerator is a generic class to generate strategies for Output Node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
|
||||||
|
predecessor_nodes: List[Node], output_option: str):
|
||||||
|
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
|
||||||
|
self.output_option = output_option
|
||||||
|
|
||||||
def validate(self) -> bool:
|
def validate(self) -> bool:
|
||||||
return super().validate()
|
return super().validate()
|
||||||
|
|
||||||
|
@ -32,7 +45,10 @@ class OutputGenerator(OutputStrategyGenerator):
|
||||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||||
strategy.memory_cost = memory_cost
|
strategy.memory_cost = memory_cost
|
||||||
|
|
||||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
def replica_strategy(self) -> List[ShardingStrategy]:
|
||||||
|
"""
|
||||||
|
Generate replica strategy for output node.
|
||||||
|
"""
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
"output": {},
|
"output": {},
|
||||||
}
|
}
|
||||||
|
@ -48,5 +64,47 @@ class OutputGenerator(OutputStrategyGenerator):
|
||||||
strategy = self.get_sharding_strategy(name=name,
|
strategy = self.get_sharding_strategy(name=name,
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
return strategy
|
||||||
|
|
||||||
return [strategy]
|
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
|
||||||
|
"""
|
||||||
|
Generate distributed strategy for output node.
|
||||||
|
"""
|
||||||
|
# TODO: need to take care of the case when the first element of output only need to be sharded.
|
||||||
|
output_op_data = self.op_data['output']
|
||||||
|
if isinstance(output_op_data.data, tuple):
|
||||||
|
length = len(output_op_data.data)
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"output": [{
|
||||||
|
0: mesh_list
|
||||||
|
}] * length,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"output": {
|
||||||
|
0: mesh_list
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for index, _ in enumerate(self.predecessor_nodes):
|
||||||
|
mapping_name = f"input_{index}"
|
||||||
|
dim_partition_dict_mapping[mapping_name] = {0: mesh_list}
|
||||||
|
|
||||||
|
communication_action_mapping = {}
|
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||||
|
|
||||||
|
name = 'Distributed Output'
|
||||||
|
|
||||||
|
strategy = self.get_sharding_strategy(name=name,
|
||||||
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
|
strategy_list = []
|
||||||
|
mesh_list = [0, 1]
|
||||||
|
if self.output_option == 'replicated':
|
||||||
|
strategy_list.append(self.replica_strategy())
|
||||||
|
elif self.output_option == 'distributed':
|
||||||
|
strategy_list.append(self.distributed_strategy(mesh_list))
|
||||||
|
|
||||||
|
return strategy_list
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
|
MemoryCost,
|
||||||
|
OperationData,
|
||||||
|
ShardingStrategy,
|
||||||
|
TrainCycleItem,
|
||||||
|
)
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
from .strategy_generator import StrategyGenerator
|
from .strategy_generator import StrategyGenerator
|
||||||
|
|
||||||
|
@ -12,6 +18,11 @@ class PlaceholderGenerator(StrategyGenerator):
|
||||||
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
|
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
|
||||||
|
placeholder_option: str):
|
||||||
|
super().__init__(operation_data_mapping, device_mesh)
|
||||||
|
self.placeholder_option = placeholder_option
|
||||||
|
|
||||||
def validate(self) -> bool:
|
def validate(self) -> bool:
|
||||||
return super().validate()
|
return super().validate()
|
||||||
|
|
||||||
|
@ -37,7 +48,10 @@ class PlaceholderGenerator(StrategyGenerator):
|
||||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||||
strategy.memory_cost = memory_cost
|
strategy.memory_cost = memory_cost
|
||||||
|
|
||||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
def replica_placeholder(self) -> ShardingStrategy:
|
||||||
|
"""
|
||||||
|
Generate replica strategy for placeholder node.
|
||||||
|
"""
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
"output": {},
|
"output": {},
|
||||||
}
|
}
|
||||||
|
@ -50,4 +64,37 @@ class PlaceholderGenerator(StrategyGenerator):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
return [strategy]
|
return strategy
|
||||||
|
|
||||||
|
def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
|
||||||
|
"""
|
||||||
|
Generate distributed strategy for placeholder node.
|
||||||
|
"""
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"output": {
|
||||||
|
0: mesh_list
|
||||||
|
},
|
||||||
|
}
|
||||||
|
communication_action_mapping = {}
|
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||||
|
|
||||||
|
name = 'Distributed Placeholder'
|
||||||
|
|
||||||
|
strategy = self.get_sharding_strategy(name=name,
|
||||||
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
|
strategy_list = []
|
||||||
|
if self.placeholder_option == 'distributed':
|
||||||
|
mesh_list = [0, 1]
|
||||||
|
distributed_strategy = self.distributed_placeholder(mesh_list)
|
||||||
|
strategy_list.append(distributed_strategy)
|
||||||
|
else:
|
||||||
|
assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
|
||||||
|
replicated_strategy = self.replica_placeholder()
|
||||||
|
strategy_list.append(replicated_strategy)
|
||||||
|
|
||||||
|
return strategy_list
|
||||||
|
|
|
@ -73,7 +73,10 @@ class StrategyGenerator(ABC):
|
||||||
for op_data_name, dim_partition_dict in mapping.items():
|
for op_data_name, dim_partition_dict in mapping.items():
|
||||||
if op_data_name in self.op_data:
|
if op_data_name in self.op_data:
|
||||||
op_data = self.op_data[op_data_name]
|
op_data = self.op_data[op_data_name]
|
||||||
if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
|
if isinstance(op_data.data, tuple):
|
||||||
|
for data in op_data.data:
|
||||||
|
assert isinstance(
|
||||||
|
data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
|
||||||
sharding_spec = []
|
sharding_spec = []
|
||||||
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
|
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
|
||||||
dim_size = len(logical_shape)
|
dim_size = len(logical_shape)
|
||||||
|
@ -82,6 +85,9 @@ class StrategyGenerator(ABC):
|
||||||
entire_shape=logical_shape,
|
entire_shape=logical_shape,
|
||||||
dim_partition_dict=dim_partition_dict_element)
|
dim_partition_dict=dim_partition_dict_element)
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(
|
||||||
|
op_data.data, torch.Tensor
|
||||||
|
), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
|
||||||
dim_size = len(op_data.logical_shape)
|
dim_size = len(op_data.logical_shape)
|
||||||
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
|
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
|
||||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||||
|
|
|
@ -43,8 +43,11 @@ class OperationData:
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# if no logical shape is specified, use the data shape as the logical shape
|
# if no logical shape is specified, use the data shape as the logical shape
|
||||||
if self.logical_shape is None and isinstance(self.data, torch.Tensor):
|
if self.logical_shape is None:
|
||||||
self.logical_shape = self.data.shape
|
if isinstance(self.data, torch.Tensor):
|
||||||
|
self.logical_shape = self.data.shape
|
||||||
|
elif isinstance(self.data, tuple):
|
||||||
|
self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'OperationData(name={self.name}, type={self.type})'
|
return f'OperationData(name={self.name}, type={self.type})'
|
||||||
|
|
|
@ -1,11 +1,30 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
__all__ = ['SolverOptions']
|
__all__ = ['SolverOptions']
|
||||||
|
|
||||||
|
|
||||||
|
class SolverPerference(Enum):
|
||||||
|
"""
|
||||||
|
This enum class is to define the solver preference.
|
||||||
|
"""
|
||||||
|
STANDARD = 0
|
||||||
|
DP = 1
|
||||||
|
TP = 2
|
||||||
|
|
||||||
|
|
||||||
|
class DataloaderOption(Enum):
|
||||||
|
"""
|
||||||
|
This enum class is to define the dataloader option.
|
||||||
|
"""
|
||||||
|
REPLICATED = 0
|
||||||
|
DISTRIBUTED = 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SolverOptions:
|
class SolverOptions:
|
||||||
"""
|
"""
|
||||||
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
|
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
|
||||||
"""
|
"""
|
||||||
fast: bool = False
|
solver_perference: SolverPerference = SolverPerference.STANDARD
|
||||||
|
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
|
||||||
|
|
|
@ -6,15 +6,16 @@ from typing import Dict, List
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
|
from colossalai.auto_parallel.tensor_shard.node_handler import (
|
||||||
from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
|
GetattrHandler,
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
|
OuputHandler,
|
||||||
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
|
PlacehodlerHandler,
|
||||||
|
operator_registry,
|
||||||
|
)
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
|
||||||
|
|
||||||
from .options import SolverOptions
|
from .options import DataloaderOption, SolverOptions
|
||||||
|
|
||||||
__all__ = ['StrategiesConstructor']
|
__all__ = ['StrategiesConstructor']
|
||||||
|
|
||||||
|
@ -67,7 +68,15 @@ class StrategiesConstructor:
|
||||||
strategies_vector = StrategiesVector(node)
|
strategies_vector = StrategiesVector(node)
|
||||||
# placeholder node
|
# placeholder node
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
|
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
|
||||||
|
placeholder_option = 'distributed'
|
||||||
|
else:
|
||||||
|
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
|
||||||
|
placeholder_option = 'replicated'
|
||||||
|
placeholder_handler = PlacehodlerHandler(node,
|
||||||
|
self.device_mesh,
|
||||||
|
strategies_vector,
|
||||||
|
placeholder_option=placeholder_option)
|
||||||
placeholder_handler.register_strategy()
|
placeholder_handler.register_strategy()
|
||||||
|
|
||||||
# get_attr node
|
# get_attr node
|
||||||
|
@ -97,7 +106,12 @@ class StrategiesConstructor:
|
||||||
|
|
||||||
# output node
|
# output node
|
||||||
elif node.op == 'output':
|
elif node.op == 'output':
|
||||||
output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
|
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
|
||||||
|
output_option = 'distributed'
|
||||||
|
else:
|
||||||
|
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
|
||||||
|
output_option = 'replicated'
|
||||||
|
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
|
||||||
output_handler.register_strategy()
|
output_handler.register_strategy()
|
||||||
|
|
||||||
if len(strategies_vector) <= 0:
|
if len(strategies_vector) <= 0:
|
||||||
|
|
|
@ -84,7 +84,7 @@ def check_linear_module(rank, world_size, port):
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
node_list = list(graph.nodes)
|
node_list = list(graph.nodes)
|
||||||
|
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
linear_node = node_list[3]
|
linear_node = node_list[3]
|
||||||
|
@ -138,7 +138,7 @@ def check_conv_module(rank, world_size, port):
|
||||||
|
|
||||||
node_list = list(graph.nodes)
|
node_list = list(graph.nodes)
|
||||||
conv_node = node_list[3]
|
conv_node = node_list[3]
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ def mem_test_for_node_strategy(rank: int,
|
||||||
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
|
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
|
||||||
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
|
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
|
||||||
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
|
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
target_node = list(graph.nodes)[node_index]
|
target_node = list(graph.nodes)[node_index]
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
|
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
|
||||||
OuputHandler
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||||
|
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||||
|
|
||||||
|
|
||||||
class OutputModel(nn.Module):
|
class OutputModel(nn.Module):
|
||||||
|
@ -18,7 +18,9 @@ class OutputModel(nn.Module):
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
def test_output_handler():
|
@parameterize('output_option', ['distributed', 'replicated'])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_output_handler(output_option):
|
||||||
model = OutputModel()
|
model = OutputModel()
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
# graph():
|
# graph():
|
||||||
|
@ -37,7 +39,10 @@ def test_output_handler():
|
||||||
output_strategies_vector = StrategiesVector(output_node)
|
output_strategies_vector = StrategiesVector(output_node)
|
||||||
|
|
||||||
# build handler
|
# build handler
|
||||||
otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector)
|
otuput_handler = OuputHandler(node=output_node,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=output_strategies_vector,
|
||||||
|
output_option=output_option)
|
||||||
|
|
||||||
otuput_handler.register_strategy(compute_resharding_cost=False)
|
otuput_handler.register_strategy(compute_resharding_cost=False)
|
||||||
# check operation data mapping
|
# check operation data mapping
|
||||||
|
@ -49,10 +54,12 @@ def test_output_handler():
|
||||||
assert op_data.data is not None
|
assert op_data.data is not None
|
||||||
|
|
||||||
assert mapping['output'].name == "output"
|
assert mapping['output'].name == "output"
|
||||||
assert mapping['output'].data.is_meta
|
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
|
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
|
||||||
assert "Replica Output" in strategy_name_list
|
if output_option == 'distributed':
|
||||||
|
assert "Distributed Output" in strategy_name_list
|
||||||
|
else:
|
||||||
|
assert "Replica Output" in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
|
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
|
||||||
PlacehodlerHandler
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||||
|
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderModel(nn.Module):
|
class PlaceholderModel(nn.Module):
|
||||||
|
@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
def test_placeholder_handler():
|
@parameterize('placeholder_option', ['distributed', 'replicated'])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_placeholder_handler(placeholder_option):
|
||||||
model = PlaceholderModel()
|
model = PlaceholderModel()
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
# graph():
|
# graph():
|
||||||
|
@ -33,16 +35,25 @@ def test_placeholder_handler():
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
placeholder_node = list(graph.nodes)[0]
|
placeholder_node = list(graph.nodes)[0]
|
||||||
placeholder_strategies_vector = StrategiesVector(placeholder_node)
|
placeholder_strategies_vector = StrategiesVector(placeholder_node)
|
||||||
|
|
||||||
# build handler
|
# build handler
|
||||||
placeholder_handler = PlacehodlerHandler(node=placeholder_node,
|
placeholder_handler = PlacehodlerHandler(node=placeholder_node,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
strategies_vector=placeholder_strategies_vector)
|
strategies_vector=placeholder_strategies_vector,
|
||||||
|
placeholder_option=placeholder_option)
|
||||||
|
|
||||||
placeholder_handler.register_strategy(compute_resharding_cost=False)
|
placeholder_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
|
||||||
# check operation data mapping
|
# check operation data mapping
|
||||||
mapping = placeholder_handler.get_operation_data_mapping()
|
mapping = placeholder_handler.get_operation_data_mapping()
|
||||||
|
|
||||||
|
strategy = placeholder_strategies_vector[0]
|
||||||
|
strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name)
|
||||||
|
|
||||||
|
if placeholder_option == 'distributed':
|
||||||
|
assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]'
|
||||||
|
else:
|
||||||
|
assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]'
|
||||||
|
|
||||||
for name, op_data in mapping.items():
|
for name, op_data in mapping.items():
|
||||||
op_data: OperationData
|
op_data: OperationData
|
||||||
# make sure they have valid values
|
# make sure they have valid values
|
||||||
|
@ -53,7 +64,10 @@ def test_placeholder_handler():
|
||||||
assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64))
|
assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64))
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
strategy_name_list = [val.name for val in placeholder_handler.strategies_vector]
|
strategy_name_list = [val.name for val in placeholder_handler.strategies_vector]
|
||||||
assert "Replica Placeholder" in strategy_name_list
|
if placeholder_option == 'replicated':
|
||||||
|
assert "Replica Placeholder" in strategy_name_list
|
||||||
|
else:
|
||||||
|
assert "Distributed Placeholder" in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -79,7 +79,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
||||||
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
|
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
|
||||||
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
|
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
|
||||||
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
|
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
target_node = list(graph.nodes)[node_index]
|
target_node = list(graph.nodes)[node_index]
|
||||||
|
|
|
@ -79,7 +79,7 @@ def test_linear_module():
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
node_list = list(graph.nodes)
|
node_list = list(graph.nodes)
|
||||||
|
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
linear_node = node_list[3]
|
linear_node = node_list[3]
|
||||||
|
@ -117,7 +117,7 @@ def test_conv_module():
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
node_list = list(graph.nodes)
|
node_list = list(graph.nodes)
|
||||||
conv_node = node_list[3]
|
conv_node = node_list[3]
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
_param_resharding_cost_assertion(conv_node)
|
_param_resharding_cost_assertion(conv_node)
|
||||||
|
|
|
@ -138,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port):
|
||||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port):
|
||||||
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
||||||
|
|
||||||
assert output.shape == origin_output.shape
|
assert output.shape == origin_output.shape
|
||||||
assert_close(output, origin_output)
|
assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
|
||||||
print("*******************backward starting*******************")
|
print("*******************backward starting*******************")
|
||||||
cuda_rng_state = torch.cuda.get_rng_state()
|
cuda_rng_state = torch.cuda.get_rng_state()
|
||||||
output.sum().backward()
|
output.sum().backward()
|
||||||
|
|
|
@ -60,7 +60,7 @@ def check_apply(rank, world_size, port):
|
||||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,13 @@ from torch.fx import GraphModule
|
||||||
from torchvision.models import resnet50
|
from torchvision.models import resnet50
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||||
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
|
from colossalai.auto_parallel.tensor_shard.solver import (
|
||||||
StrategiesConstructor)
|
CostGraph,
|
||||||
|
GraphAnalyser,
|
||||||
|
Solver,
|
||||||
|
SolverOptions,
|
||||||
|
StrategiesConstructor,
|
||||||
|
)
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
@ -53,7 +58,7 @@ def test_cost_graph():
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
graph_analyser = GraphAnalyser(gm)
|
graph_analyser = GraphAnalyser(gm)
|
||||||
liveness_list = graph_analyser.liveness_analysis()
|
liveness_list = graph_analyser.liveness_analysis()
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions()
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue