mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support distributed dataloader option (#1906)
* [autoparallel] support distributed dataloader option
* update output handler to support ddp dataloader
* polish code
parent 6630d45546
commit 0da1d00399
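The patch adds a dataloader_option field to SolverOptions and threads it through to the placeholder (input) and output node handlers. As a rough usage sketch, not part of the patch: the traced graph and device_mesh are assumed to be prepared as in the tests touched below, and the import path of DataloaderOption is inferred from the `from .options import DataloaderOption, SolverOptions` line in the diff.

from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption


def build_with_distributed_dataloader(graph, device_mesh):
    # REPLICATED remains the default; DISTRIBUTED makes the placeholder and
    # output handlers emit strategies that shard tensor dim 0 over the mesh.
    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    return strategies_constructor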
@@ -19,5 +19,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
 ]
@@ -2,7 +2,9 @@ from typing import Dict, List
 
 import torch
 
-from ..sharding_strategy import OperationData, OperationDataType
+from colossalai.device.device_mesh import DeviceMesh
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from .node_handler import NodeHandler
 from .strategy import OutputGenerator, StrategyGenerator
 
@@ -14,26 +16,37 @@ class OuputHandler(NodeHandler):
     A OuputHandler which deals with the sharding strategies for Output Node.
     """
 
+    def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+                 output_option: str) -> None:
+        super().__init__(node, device_mesh, strategies_vector)
+        self.output_option = output_option
+
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
+        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option))
         return generators
 
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
-        dummy_output = torch.empty(1,).to("meta")
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)
-
-        mapping = {"output": physical_output}
+        mapping = {}
+        output_meta_data = []
         for index, input_node in enumerate(self.predecessor_node):
-            if not hasattr(input_node, "_meta_data"):
-                print(input_node.name)
-            physical_inputs = OperationData(name=str(input_node),
-                                            type=OperationDataType.ARG,
-                                            data=input_node._meta_data)
+            input_meta_data = input_node._meta_data
+            physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
             name_key = f'input_{index}'
             mapping[name_key] = physical_inputs
+            output_meta_data.append(input_meta_data)
+
+        assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
+        if len(output_meta_data) == 1:
+            output_meta_data = output_meta_data[0]
+        else:
+            output_meta_data = tuple(output_meta_data)
+
+        self.node._meta_data = output_meta_data
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping["output"] = physical_output
         return mapping
@@ -1,6 +1,10 @@
 from typing import Dict, List
 
-from ..sharding_strategy import OperationData, OperationDataType
+from torch.fx.node import Node
+
+from colossalai.device.device_mesh import DeviceMesh
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from .node_handler import NodeHandler
 from .strategy import PlaceholderGenerator, StrategyGenerator
 
@@ -12,10 +16,16 @@ class PlacehodlerHandler(NodeHandler):
     A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
     """
 
+    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+                 placeholder_option: str) -> None:
+        super().__init__(node, device_mesh, strategies_vector)
+        self.placeholder_option = placeholder_option
+
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
+        generators.append(
+            PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
         return generators
 
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -1,6 +1,14 @@
-from typing import List
+from typing import Dict, List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from torch.fx import Node
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    MemoryCost,
+    OperationData,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.device.device_mesh import DeviceMesh
 
 from .strategy_generator import OutputStrategyGenerator
@@ -12,6 +20,11 @@ class OutputGenerator(OutputStrategyGenerator):
     OutputGenerator is a generic class to generate strategies for Output Node.
     """
 
+    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
+                 predecessor_nodes: List[Node], output_option: str):
+        super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
+        self.output_option = output_option
+
     def validate(self) -> bool:
         return super().validate()
 
@@ -32,7 +45,10 @@ class OutputGenerator(OutputStrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
-    def collate_strategies(self) -> List[ShardingStrategy]:
+    def replica_strategy(self) -> List[ShardingStrategy]:
+        """
+        Generate replica strategy for output node.
+        """
         dim_partition_dict_mapping = {
             "output": {},
         }
@@ -48,5 +64,47 @@ class OutputGenerator(OutputStrategyGenerator):
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
+        return strategy
 
-        return [strategy]
+    def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
+        """
+        Generate distributed strategy for output node.
+        """
+        # TODO: need to take care of the case when the first element of output only need to be sharded.
+        output_op_data = self.op_data['output']
+        if isinstance(output_op_data.data, tuple):
+            length = len(output_op_data.data)
+            dim_partition_dict_mapping = {
+                "output": [{
+                    0: mesh_list
+                }] * length,
+            }
+        else:
+            dim_partition_dict_mapping = {
+                "output": {
+                    0: mesh_list
+                },
+            }
+        for index, _ in enumerate(self.predecessor_nodes):
+            mapping_name = f"input_{index}"
+            dim_partition_dict_mapping[mapping_name] = {0: mesh_list}
+
+        communication_action_mapping = {}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = 'Distributed Output'
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+        return strategy
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        mesh_list = [0, 1]
+        if self.output_option == 'replicated':
+            strategy_list.append(self.replica_strategy())
+        elif self.output_option == 'distributed':
+            strategy_list.append(self.distributed_strategy(mesh_list))
+
+        return strategy_list
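A side note on the notation above (an illustration, not part of the patch): the hard-coded `mesh_list = [0, 1]` in `collate_strategies` names both axes of the device mesh, so a `dim_partition_dict` entry of `{0: mesh_list}` shards tensor dim 0 across the whole mesh.

# Illustration with assumed values: a 2x2 DeviceMesh whose axes are 0 and 1.
mesh_list = [0, 1]                   # device-mesh axes to shard over
dim_partition_dict = {0: mesh_list}  # tensor dim 0 -> split across axes 0 and 1
# For a (4, 4, 64, 64) placeholder this is rendered as the sharding sequence
# [S01, R, R, R]: the batch dimension is split over all four devices and the
# remaining dimensions stay replicated, matching the assertion in the updated
# placeholder handler test further below.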
@@ -1,6 +1,12 @@
-from typing import List
+from typing import Dict, List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    MemoryCost,
+    OperationData,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.device.device_mesh import DeviceMesh
 
 from .strategy_generator import StrategyGenerator
@@ -12,6 +18,11 @@ class PlaceholderGenerator(StrategyGenerator):
     PlaceholderGenerator is a generic class to generate strategies for placeholder node.
     """
 
+    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
+                 placeholder_option: str):
+        super().__init__(operation_data_mapping, device_mesh)
+        self.placeholder_option = placeholder_option
+
     def validate(self) -> bool:
         return super().validate()
 
@@ -37,7 +48,10 @@ class PlaceholderGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
-    def collate_strategies(self) -> List[ShardingStrategy]:
+    def replica_placeholder(self) -> ShardingStrategy:
+        """
+        Generate replica strategy for placeholder node.
+        """
         dim_partition_dict_mapping = {
             "output": {},
         }
@@ -50,4 +64,37 @@ class PlaceholderGenerator(StrategyGenerator):
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
 
-        return [strategy]
+        return strategy
+
+    def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
+        """
+        Generate distributed strategy for placeholder node.
+        """
+        dim_partition_dict_mapping = {
+            "output": {
+                0: mesh_list
+            },
+        }
+        communication_action_mapping = {}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = 'Distributed Placeholder'
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+
+        return strategy
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        if self.placeholder_option == 'distributed':
+            mesh_list = [0, 1]
+            distributed_strategy = self.distributed_placeholder(mesh_list)
+            strategy_list.append(distributed_strategy)
+        else:
+            assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
+            replicated_strategy = self.replica_placeholder()
+            strategy_list.append(replicated_strategy)
+
+        return strategy_list
@@ -73,7 +73,10 @@ class StrategyGenerator(ABC):
         for op_data_name, dim_partition_dict in mapping.items():
             if op_data_name in self.op_data:
                 op_data = self.op_data[op_data_name]
-                if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
+                if isinstance(op_data.data, tuple):
+                    for data in op_data.data:
+                        assert isinstance(
+                            data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
                     sharding_spec = []
                     for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                         dim_size = len(logical_shape)
@@ -82,6 +85,9 @@ class StrategyGenerator(ABC):
                                                              entire_shape=logical_shape,
                                                              dim_partition_dict=dim_partition_dict_element)
                 else:
+                    assert isinstance(
+                        op_data.data, torch.Tensor
+                    ), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}'
                     dim_size = len(op_data.logical_shape)
                     dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
                     sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
@@ -43,8 +43,11 @@ class OperationData:
 
     def __post_init__(self):
         # if no logical shape is specified, use the data shape as the logical shape
-        if self.logical_shape is None and isinstance(self.data, torch.Tensor):
-            self.logical_shape = self.data.shape
+        if self.logical_shape is None:
+            if isinstance(self.data, torch.Tensor):
+                self.logical_shape = self.data.shape
+            elif isinstance(self.data, tuple):
+                self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
 
     def __repr__(self) -> str:
         return f'OperationData(name={self.name}, type={self.type})'
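The __post_init__ change above is what lets an output node carry a tuple of tensors as its operation data. A small sketch (not from the patch) of the resulting behaviour, using meta tensors as the touched tests do:

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType

x = torch.empty(4, 8, device='meta')
y = torch.empty(4, 16, device='meta')
op_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=(x, y))
# With the patched __post_init__, logical_shape is now a tuple of shapes:
# (torch.Size([4, 8]), torch.Size([4, 16]))
print(op_data.logical_shape)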
@@ -1,11 +1,30 @@
 from dataclasses import dataclass
+from enum import Enum
 
 __all__ = ['SolverOptions']
 
 
+class SolverPerference(Enum):
+    """
+    This enum class is to define the solver preference.
+    """
+    STANDARD = 0
+    DP = 1
+    TP = 2
+
+
+class DataloaderOption(Enum):
+    """
+    This enum class is to define the dataloader option.
+    """
+    REPLICATED = 0
+    DISTRIBUTED = 1
+
+
 @dataclass
 class SolverOptions:
     """
     SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
     """
     fast: bool = False
+    solver_perference: SolverPerference = SolverPerference.STANDARD
+    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
@@ -6,15 +6,16 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node
 
-from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
-from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.node_handler import (
+    GetattrHandler,
+    OuputHandler,
+    PlacehodlerHandler,
+    operator_registry,
+)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .options import SolverOptions
+from .options import DataloaderOption, SolverOptions
 
 __all__ = ['StrategiesConstructor']
 
@@ -67,7 +68,15 @@ class StrategiesConstructor:
             strategies_vector = StrategiesVector(node)
             # placeholder node
             if node.op == 'placeholder':
-                placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
+                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
+                    placeholder_option = 'distributed'
+                else:
+                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
+                    placeholder_option = 'replicated'
+                placeholder_handler = PlacehodlerHandler(node,
+                                                         self.device_mesh,
+                                                         strategies_vector,
+                                                         placeholder_option=placeholder_option)
                 placeholder_handler.register_strategy()
 
             # get_attr node
@@ -97,7 +106,12 @@ class StrategiesConstructor:
 
             # output node
             elif node.op == 'output':
-                output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
+                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
+                    output_option = 'distributed'
+                else:
+                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
+                    output_option = 'replicated'
+                output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()
 
             if len(strategies_vector) <= 0:
@@ -84,7 +84,7 @@ def check_linear_module(rank, world_size, port):
     gm.recompile()
     node_list = list(graph.nodes)
 
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -138,7 +138,7 @@ def check_conv_module(rank, world_size, port):
 
     node_list = list(graph.nodes)
    conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -36,7 +36,7 @@ def mem_test_for_node_strategy(rank: int,
             input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
-    OuputHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
 
 
 class OutputModel(nn.Module):
@@ -18,7 +18,9 @@ class OutputModel(nn.Module):
         return x, y
 
 
-def test_output_handler():
+@parameterize('output_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_output_handler(output_option):
     model = OutputModel()
     tracer = ColoTracer()
     # graph():
@@ -37,7 +39,10 @@ def test_output_handler():
     output_strategies_vector = StrategiesVector(output_node)
 
     # build handler
-    otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector)
+    otuput_handler = OuputHandler(node=output_node,
+                                  device_mesh=device_mesh,
+                                  strategies_vector=output_strategies_vector,
+                                  output_option=output_option)
 
     otuput_handler.register_strategy(compute_resharding_cost=False)
     # check operation data mapping
@@ -49,10 +54,12 @@ def test_output_handler():
         assert op_data.data is not None
 
     assert mapping['output'].name == "output"
-    assert mapping['output'].data.is_meta
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
-    assert "Replica Output" in strategy_name_list
+    if output_option == 'distributed':
+        assert "Distributed Output" in strategy_name_list
+    else:
+        assert "Replica Output" in strategy_name_list
 
 
 if __name__ == '__main__':
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
-    PlacehodlerHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
 
 
 class PlaceholderModel(nn.Module):
@@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module):
         return input
 
 
-def test_placeholder_handler():
+@parameterize('placeholder_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_placeholder_handler(placeholder_option):
     model = PlaceholderModel()
     tracer = ColoTracer()
     # graph():
@@ -33,16 +35,25 @@ def test_placeholder_handler():
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     placeholder_node = list(graph.nodes)[0]
     placeholder_strategies_vector = StrategiesVector(placeholder_node)
 
     # build handler
     placeholder_handler = PlacehodlerHandler(node=placeholder_node,
                                              device_mesh=device_mesh,
-                                             strategies_vector=placeholder_strategies_vector)
+                                             strategies_vector=placeholder_strategies_vector,
+                                             placeholder_option=placeholder_option)
 
     placeholder_handler.register_strategy(compute_resharding_cost=False)
 
     # check operation data mapping
     mapping = placeholder_handler.get_operation_data_mapping()
 
+    strategy = placeholder_strategies_vector[0]
+    strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name)
+
+    if placeholder_option == 'distributed':
+        assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]'
+    else:
+        assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]'
+
     for name, op_data in mapping.items():
         op_data: OperationData
         # make sure they have valid values
@@ -53,7 +64,10 @@ def test_placeholder_handler():
     assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64))
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in placeholder_handler.strategies_vector]
-    assert "Replica Placeholder" in strategy_name_list
+    if placeholder_option == 'replicated':
+        assert "Replica Placeholder" in strategy_name_list
+    else:
+        assert "Distributed Placeholder" in strategy_name_list
 
 
 if __name__ == '__main__':
@@ -79,7 +79,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
             input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]
@@ -79,7 +79,7 @@ def test_linear_module():
     gm.recompile()
     node_list = list(graph.nodes)
 
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -117,7 +117,7 @@ def test_conv_module():
     gm.recompile()
     node_list = list(graph.nodes)
     conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     _param_resharding_cost_assertion(conv_node)
@@ -138,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -162,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port):
     output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
 
     assert output.shape == origin_output.shape
-    assert_close(output, origin_output)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
     print("*******************backward starting*******************")
     cuda_rng_state = torch.cuda.get_rng_state()
     output.sum().backward()
@@ -60,7 +60,7 @@ def check_apply(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -3,8 +3,13 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -53,7 +58,7 @@ def test_cost_graph():
     gm.recompile()
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 